diff --git a/index.html b/index.html index f4ee942..b0803f8 100644 --- a/index.html +++ b/index.html @@ -267,8 +267,7 @@

Prompt Engineering

Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics. -Useful resources: - OpenAI Cookbook has many in-depth examples for how to utilize LLM efficiently. Prompt Engineering Guide repo contains a pretty comprehensive collection of education materials on prompt engineering....

+This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models....

diff --git a/index.json b/index.json index 85c5fba..eec469b 100644 --- a/index.json +++ b/index.json @@ -1 +1 @@ -[{"content":"Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics.\nUseful resources:\n OpenAI Cookbook has many in-depth examples for how to utilize LLM efficiently. Prompt Engineering Guide repo contains a pretty comprehensive collection of education materials on prompt engineering. This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models. At its core, the goal of prompt engineering is about alignment and model steerability. Check my previous post on controllable text generation.\n[My personal spicy take] In my opinion, some prompt engineering papers are not worthy 8 pages long, since those tricks can be explained in one or a few sentences and the rest is all about benchmarking. An easy-to-use and shared benchmark infrastructure should be more beneficial to the community. Iterative prompting or external tool use would not be trivial to set up. Also non-trivial to align the whole research community to adopt it.\nBasic Prompting Zero-shot and few-shot learning are two most basic approaches for prompting the model, pioneered by many LLM papers and commonly used for benchmarking LLM performance.\nZero-Shot Zero-shot learning is to simply feed the task text to the model and ask for results.\n(All the sentiment analysis examples are from SST-2)\nText: i'll bet the video game is a lot more fun than the film. Sentiment: Few-shot Few-shot learning presents a set of high-quality demonstrations, each consisting of both input and desired output, on the target task. As the model first sees good examples, it can better understand human intention and criteria for what kinds of answers are wanted. Therefore, few-shot learning often leads to better performance than zero-shot. However, it comes at the cost of more token consumption and may hit the context length limit when input and output text are long.\nText: (lawrence bounces) all over the stage, dancing, running, sweating, mopping his face and generally displaying the wacky talent that brought him fame in the first place. Sentiment: positive Text: despite all evidence to the contrary, this clunker has somehow managed to pose as an actual feature movie, the kind that charges full admission and gets hyped on tv and purports to amuse small children and ostensible adults. Sentiment: negative Text: for the first time in years, de niro digs deep emotionally, perhaps because he's been stirred by the powerful work of his co-stars. Sentiment: positive Text: i'll bet the video game is a lot more fun than the film. Sentiment: Many studies looked into how to construct in-context examples to maximize the performance and observed that choice of prompt format, training examples, and the order of the examples can lead to dramatically different performance, from near random guess to near SoTA.\nZhao et al. (2021) investigated the case of few-shot classification and proposed that several biases with LLM (they use GPT-3 in the experiments) contribute to such high variance: (1) Majority label bias exists if distribution of labels among the examples is unbalanced; (2) Recency bias refers to the tendency where the model may repeat the label at the end; (3) Common token bias indicates that LLM tends to produce common tokens more often than rare tokens. To conquer such bias, they proposed a method to calibrate the label probabilities output by the model to be uniform when the input string is N/A.\nTips for Example Selection Choose examples that are semantically similar to the test example using $k$-NN clustering in the embedding space (Liu et al., 2021)\n To select a diverse and representative set of examples, Su et al. (2022) proposed to use a graph-based approach: (1) First, construct a directed graph $G=(V, E)$ based on the embedding (e.g. by SBERT or other embedding models) cosine similarity between samples, where each node points to its $k$ nearest neighbors; (2) Start with a set of selected samples $\\mathcal{L}=\\emptyset$ and a set of remaining samples $\\mathcal{U}$. Each sample $u \\in \\mathcal{U}$ is scored by $$ \\text{score}(u) = \\sum_{v \\in \\{v \\mid (u, v) \\in E, v\\in \\mathcal{U}\\}} s(v)\\quad\\text{where }s(v)=\\rho^{- \\vert \\{\\ell \\in \\mathcal{L} \\vert (v, \\ell)\\in E \\}\\vert},\\quad\\rho \u0026gt; 1 $$ such that $s(v)$ is low if many of $v$\u0026rsquo;s neighbors are selected and thus the scoring encourages to pick diverse samples.\n Rubin et al. (2022) proposed to train embeddings via contrastive learning specific to one training dataset for in-context learning sample selection. Given each training pair $(x, y)$, the quality of one example $e_i$ (formatted input-output pair) can be measured by a conditioned probability assigned by LM: $\\text{score}(e_i) = P_\\text{LM}(y \\mid e_i, x)$. We can identify other examples with top-$k$ and bottom-$k$ scores as positive and negative sets of candidates for every training pair and use that for contrastive learning.\n Some researchers tried Q-Learning to do sample selection. (Zhang et al. 2022)\n Motivated by uncertainty-based active learning, Diao et al. (2023) suggested to identify examples with high disagreement or entropy among multiple sampling trials. Then annotate these examples to be used in few-shot prompts.\n Tips for Example Ordering A general suggestion is to keep the selection of examples diverse, relevant to the test sample and in random order to avoid majority label bias and recency bias. Increasing model sizes or including more training examples does not reduce variance among different permutations of in-context examples. Same order may work well for one model but badly for another. When the validation set is limited, consider choosing the order such that the model does not produce extremely unbalanced predictions or being overconfident about its predictions. (Lu et al. 2022) Instruction Prompting The purpose of presenting few-shot examples in the prompt is to explain our intent to the model; in other words, describe the task instruction to the model in the form of demonstrations. However, few-shot can be expensive in terms of token usage and restricts the input length due to limited context length. So, why not just give the instruction directly?\nInstructed LM (e.g. InstructGPT, natural instruction) finetunes a pretrained model with high-quality tuples of (task instruction, input, ground truth output) to make LM better understand user intention and follow instruction. RLHF (Reinforcement Learning from Human Feedback) is a common method to do so. The benefit of instruction following style fine-tuning improves the model to be more aligned with human intention and greatly reduces the cost of communication.\nWhen interacting with instruction models, we should describe the task requirement in details, trying to be specific and precise and avoiding say \u0026ldquo;not do something\u0026rdquo; but rather specify what to do.\nPlease label the sentiment towards the movie of the given movie review. The sentiment label should be \u0026quot;positive\u0026quot; or \u0026quot;negative\u0026quot;. Text: i'll bet the video game is a lot more fun than the film. Sentiment: Explaining the desired audience is another smart way to give instructions\n For example to produce education materials for kids, Describe what is quantum physics to a 6-year-old. And safe content, ... in language that is safe for work. In-context instruction learning (Ye et al. 2023) combines few-shot learning with instruction prompting. It incorporates multiple demonstration examples across different tasks in the prompt, each demonstration consisting of instruction, task input and output. Note that their experiments were only on classification tasks and the instruction prompt contains all label options.\nDefinition: Determine the speaker of the dialogue, \u0026quot;agent\u0026quot; or \u0026quot;customer\u0026quot;. Input: I have successfully booked your tickets. Ouput: agent Definition: Determine which category the question asks for, \u0026quot;Quantity\u0026quot; or \u0026quot;Location\u0026quot;. Input: What's the oldest building in US? Ouput: Location Definition: Classify the sentiment of the given movie review, \u0026quot;positive\u0026quot; or \u0026quot;negative\u0026quot;. Input: i'll bet the video game is a lot more fun than the film. Output: Self-Consistency Sampling Self-consistency sampling (Wang et al. 2022a) is to sample multiple outputs with temperature \u0026gt; 0 and then selecting the best one out of these candidates. The criteria for selecting the best candidate can vary from task to task. A general solution is to pick majority vote. For tasks that are easy to validate such as a programming question with unit tests, we can simply run through the interpreter and verify the correctness with unit tests.\nChain-of-Thought (CoT) Chain-of-thought (CoT) prompting (Wei et al. 2022) generates a sequence of short sentences to describe reasoning logics step by step, known as reasoning chains or rationales, to eventually lead to the final answer. The benefit of CoT is more pronounced for complicated reasoning tasks, while using large models (e.g. with more than 50B parameters). Simple tasks only benefit slightly from CoT prompting.\nTypes of CoT prompts Two main types of CoT prompting:\n Few-shot CoT. It is to prompt the model with a few demonstrations, each containing manually written (or model-generated) high-quality reasoning chains. (All the math reasoning examples are from GSM8k)\nQuestion: Tom and Elizabeth have a competition to climb a hill. Elizabeth takes 30 minutes to climb the hill. Tom takes four times as long as Elizabeth does to climb the hill. How many hours does it take Tom to climb up the hill? Answer: It takes Tom 30*4 = \u0026lt;\u0026lt;30*4=120\u0026gt;\u0026gt;120 minutes to climb the hill. It takes Tom 120/60 = \u0026lt;\u0026lt;120/60=2\u0026gt;\u0026gt;2 hours to climb the hill. So the answer is 2. === Question: Jack is a soccer player. He needs to buy two pairs of socks and a pair of soccer shoes. Each pair of socks cost $9.50, and the shoes cost $92. Jack has $40. How much more money does Jack need? Answer: The total cost of two pairs of socks is $9.50 x 2 = $\u0026lt;\u0026lt;9.5*2=19\u0026gt;\u0026gt;19. The total cost of the socks and the shoes is $19 + $92 = $\u0026lt;\u0026lt;19+92=111\u0026gt;\u0026gt;111. Jack need $111 - $40 = $\u0026lt;\u0026lt;111-40=71\u0026gt;\u0026gt;71 more. So the answer is 71. === Question: Marty has 100 centimeters of ribbon that he must cut into 4 equal parts. Each of the cut parts must be divided into 5 equal parts. How long will each final cut be? Answer: Zero-shot CoT. Use natural language statement like Let's think step by step to explicitly encourage the model to first generate reasoning chains and then to prompt with Therefore, the answer is to produce answers (Kojima et al. 2022 ). Or a similar statement Let's work this out it a step by step to be sure we have the right answer (Zhou et al. 2022). Question: Marty has 100 centimeters of ribbon that he must cut into 4 equal parts. Each of the cut parts must be divided into 5 equal parts. How long will each final cut be? Answer: Let's think step by step. Tips and Extensions Self-consistency sampling can improve reasoning accuracy by sampling a number of diverse answers and then taking the majority vote. (Wang et al. 2022a)\n Another approach for ensemble learning is to alter the example order or use model generated rationales to replace human-written ones to introduce randomness during multiple sample trials. Then aggregate model outputs with a majority vote to get final answer. (Wang et al. 2022b)\n If training examples are only associated with true answers (easy to verify!) but no rationales, we can follow the STaR (Self-Taught Reasoner; Zelikman et al. 2022) method : (1) Ask LLM to generate reasoning chains and only keep those leading to correct answers; (2) Then fine-tune the model with generated rationales and repeat the process until convergence. Note that higher temperature is more likely to generate incorrect rationales with correct answers. If training examples do not have ground truth answers, maybe consider using majority votes as the \u0026ldquo;correct\u0026rdquo; answers.\n Prompts with demonstrations of higher reasoning complexity can achieve better performance, where complexity is measured by the number of reasoning steps in the chains. When separating reasoning steps, newline \\n symbol works better than step i, period . or semicolon ;. (Fu et al. 2023)\n Complexity-based consistency is to explicitly prefer complex chains among all the generations by taking majority vote among only top $k$ complex chains. (Fu et al. 2023)\n Later, Shum et al. (2023) found that in their experiments CoT prompts with only complex examples can improve the accuracy of complex questions, but perform poorly in simple questions; evidence shown on GSM8k.\n Changing Q: to Question: is found to be helpful. (Fu et al. 2023)\n Ye \u0026amp; Durrett (2022) found that the benefit of including explanations in the prompt is small to moderate for NLP tasks that involve reasoning over text (i.e. QA and NLI) and the effects vary by models. They observed that explanations are more likely to be nonfactual than be inconsistent (i.e. whether explanation entails prediction). Nonfactual explanations most likely lead to incorrect predictions.\n Self-Ask (Press et al. 2022) is a method to repeatedly prompt the model to ask following-up questions to construct the thought process iteratively. Follow-up questions can be answered by search engine results. Similarly, IRCoT (Interleaving Retrieval CoT; Trivedi et al. 2022) and ReAct (Reason + Act; Yao et al. 2023) combines iterative CoT prompting with queries to Wikipedia APIs to search for relevant entities and content and then add it back into the context.\n Fig. 1. How Self-Ask works with external search queries.(Image source: Press et al. 2022). Automatic Prompt Design Prompt is a sequence of prefix tokens that increase the probability of getting desired output given input. Therefore we can treat them as trainable parameters and optimize them directly on the embedding space via gradient descent, such as AutoPrompt (Shin et al., 2020, Prefix-Tuning (Li \u0026amp; Liang (2021)), P-tuning (Liu et al. 2021) and Prompt-Tuning (Lester et al. 2021). This section in my \u0026ldquo;Controllable Neural Text Generation\u0026rdquo; post has a good coverage of them. The trend from AutoPrompt to Prompt-Tuning is that the setup gets gradually simplified.\nAPE (Automatic Prompt Engineer; Zhou et al. 2022) is a method to search over a pool of model-generated instruction candidates and then filters the candidate set according to a chosen score function to ultimately choose the best candidate with highest score.\n Prompt LLM to generate instruction candidates based on a small set of demonstrations in the form of input-output pairs. E.g. {{Given desired input-output pairs}}\\n\\nThe instruction is.\n Given a dataset of $\\mathcal{D}_\\text{train} = \\{(x, y)\\}$, we would like to find an instruction $\\rho$ such that $\\rho^* = \\arg\\max_\\rho \\mathbb{E}_{(x, y) \\in \\mathcal{D}_\\text{train}} [f(\\rho, x, y)]$, where $f(.)$ is a per-sample score function, such as execution accuracy $\\mathbb{1}[\\text{LM}(.\\vert \\rho, x)=y]$ or log probability: $p_\\text{LM}(y \\mid \\rho, x)$.\n Use an iterative Monte Carlo search method to improve the best candidates by proposing semantically similar variants via prompts like Generate a variation of the following instruction while keeping the semantic meaning.\\n\\nInput: ...\\n\\nOutput:...\n To construct chain-of-thought prompts automatically, Shum et al. (2023) suggested augment-prune-select, a three-step process:\n Augment: Generate multiple pseudo-chains of thought given question using few-shot or zero-shot CoT prompts; Prune: Prune pseudo chains based on whether generated answers match ground truths. Select: Apply a variance-reduced policy gradient strategy to learn the probability distribution over selected examples, while considering the probability distribution over examples as policy and the validation set accuracy as reward. Zhang et al. (2023) instead adopted clustering techniques to sample questions and then generates chains. They observed that LLMs tend to make certain types of mistakes. One type of errors can be similar in the emebedding space and thus get grouped together. By only sampling one or a few from frequent-error clusters, we can prevent too many wrong demonstrations of one error type and collect a diverse set of examples.\n Question clustering: Embed questions and run $k$-means for clustering. Demonstration selection: Select a set of representative questions from each cluster; i.e. one demonstration from one cluster. Samples in each cluster are sorted by distance to the cluster centroid and those closer to the centroid are selected first. Rationale generation: Use zero-shot CoT to generate reasoning chains for selected questions and construct few-shot prompt to run inference. Augmented Language Models A survey on augmented language models by Mialon et al. (2023) has great coverage over multiple categories of language models augmented with reasoning skills and the ability of using external tools. Recommend it.\nRetrieval Often we need to complete tasks that require latest knowledge after the model pretraining time cutoff or internal/private knowledge base. In that case, the model would not know the context if we don’t explicitly provide it in the prompt. Many methods for Open Domain Question Answering depend on first doing retrieval over a knowledge base and then incorporating the retrieved content as part of the prompt. The accuracy of such a process depends on the quality of both retrieval and generation steps.\nLazaridou et al. (2022) studied how to use Google Search for document retrieval to augment LLMs. Given a question $q$, clean text is extracted out of 20 URLs returned by Google, resulting in a set of documents. Because these documents are long, each document is split into paragraphs of 6 sentences, $\\{p\\}$. Paragraphs are ranked by TF-IDF based cosine similarity between evidence paragraphs and the query. Only the most relevant paragraph is used in the prompt to produce an answer $a$.\nFor closed-book QA, each demonstration is formatted as follows to construct few-shot prompts. Swapping the question with the evidence (longer distance between questions and answers) is found to consistently yield lower results across all datasets.\nEvidence: ... Question: ... Answer: ... The answer probability is computed in three ways:\n RAG style, $p(a_i \\mid q) = \\sum_{i=1}^n p_\\text{tf-idf} (p_i \\mid q) \\cdot p_\\text{LM}(a_i \\mid q, p_i)$, where $p_\\text{tf-idf} (p_i \\mid q)$ is the normalized cosine similarities between the TF-IDF passage and question representations. Noisy channel inference, $p(a_i\\mid q) = \\frac{p_\\text{LM}(q \\mid a_i, p_i) \\cdot p_\\text{LM}(a_i \\mid p_i)}{p_\\text{LM}(q \\mid p_i)}$ Product-of-Experts (PoE), combines all probabilities used above in addition to $p_\\text{LM}(p_i \\mid q)$. According to their experiments on generation and classification tasks, among three answer reranking scores - PoE \u0026gt; Noisy channel \u0026gt; RAG. Among individual probabilities, $p_\\text{LM}(a \\mid q, p_i)$ and $p_\\text{LM}(q \\mid p_i, a)$ are found to be most informative. $p_\\text{LM}(q \\mid p_i, a)$ captures how well the question can be explained by LM given evidence paragraph and answer and can reliably be used for reranking answer candidates.\nOne observation with SituatedQA dataset for questions grounded in different dates is that despite LM (pretraining cutoff is year 2020) has access to latest information via Google Search, its performance on post-2020 questions are still a lot worse than on pre-2020 questions. This suggests the existence of some discrepencies or conflicting parametric between contextual information and model internal knowledge.\nInterestingly it is found to be beneficial even with only \u0026ldquo;internal retrieval\u0026rdquo;, that is, to generate knowledge about a topic before answering the question (Liu et al. 2022). First we can use the following template to extract knowledge:\nGenerate some knowledge about the input. Examples: Input: What type of water formation is formed by clouds? Knowledge: Clouds are made of water vapor. Input: {question} Knowledge: And then with model-generated knowledge, prompt the LM further to get the answer.\nProgramming Language Both PAL (Program-aided language models); Gao et al. 2022) and PoT (Program of Thoughts prompting; Chen et al. 2022) ask LLM to generate programming language statements to resolve natural language reasoning problems, hence offloading the solution step to a runtime such as a Python interpreter. Such setup decouples complex computation and reasoning. It relies on a LM with good enough coding skills.\nFig. 2. Comparing CoT and PoT. (Image source: Chen et al. 2022). External APIs TALM (Tool Augmented Language Models; Parisi et al. 2022) is a language model augmented with text-to-text API calls. LM is guided to generate |tool-call and tool input text conditioned on task input text to construct API call requests. When |result shows up, the specified tool API is called and the returned result gets appended to the text sequence. The final output is generated following |output token.\nFig. 3. The format of API calls in TALM. (Image source: Parisi et al. 2022). TALM adopts a self-play approach to iteratively bootstrap the dataset of tool use examples and finetune LM with it. This iterative self-play pipeline mimics a RL process where LM is the policy network and it is trained by policy gradient with a binary reward signal.\nFig. 4. Self-play iterations help boost the model performance.(Image source: Parisi et al. 2022). Toolformer (Schick et al. 2023) is a LM that can use external tools via simple APIs, which is built in a self-supervised manner and only requires a handful of demonstrations for each API. The toolbox of Toolformer includes:\n Calculator to help LM with the lack of precise math skills; Q\u0026amp;A system to help with unfaithful content and hallucination; Search engine to provide up-to-date information after pretraining cut off time; Translation system to improve performance on low resource language; Calendar to make LM be aware of time progression. Fig. 5. Illustration of how to build Toolformer.(Image source: Schick et al. 2023). Toolformer is trained as follows:\n Prompting to annotate potential API calls. Ask a pre-trained LM to annotate a dataset via few-shot learning with API call usage examples. Formatting example:\nFig. 6. How dataset is annotated to do API calls.(Image source: Schick et al. 2023). Each API call is represented as a tuple of (API name, corresponding input), $c=(a_c, i_c)$ and its corresponding result is denoted as $r$. The API call sequences with and without results are labeled as follows, respectively:\n $$ \\begin{aligned} e(c) \u0026= \\langle\\texttt{API}\\rangle a_c(i_c) \\langle\\texttt{/API}\\rangle \\\\ e(c, r) \u0026= \\langle\\texttt{API}\\rangle a_c(i_c) \\to r \\langle\\texttt{/API}\\rangle \\end{aligned} $$ Sample API calls based on the probabilities $p_\\text{LM}(\\langle\\texttt{API}\\rangle \\mid \\text{prompt}(\\mathbf{x}), \\mathbf{x}_{1:i})$ and select top $k$ candidate positions for doing API calls at position $i$ if the probability is larger than a threshold.\n Then we sample potential API calls from the LM given the sequence $[\\text{prompt}(\\mathbf{x}), x_1, \\dots, x_{i-1}, \\langle\\texttt{API}\\rangle]$ as prefix and $\\langle\\texttt{/API}\\rangle$ as suffix.\n Filter annotations based on whether API calls help model predict future tokens. Use a self-supervised loss to decide which API calls are actually helpful.\n Execute each API call $c_i$ to get corresponding result $r_i$.\n Compute weighted cross entropy loss for the LM over tokens $x_i, \\dots, x_n$ when the model is prefixed with the prompt. Two versions are computed, one with API result and the other with empty sequence $\\varepsilon$.\n $$ \\begin{aligned} L^+_i \u0026= L_i(e(c_i, r_i)) \\\\ L^-_i \u0026= \\min(L_i(\\varepsilon), L_i(e(c_i, \\varepsilon))) \\\\ \\end{aligned} $$ Only API calls with $L^-_i - L^+_i$ larger than a threshold are kept, meaning that adding this API call and its results help the model predict future tokens.\n Fine-tune LM on this annotated dataset. The new training sequences are constructed as $\\mathbf{x}^* = x_{1:i-1}, e(c_i, r_i), x_{i:n}$ . The training data is a combination of the original dataset (e.g. a subset of CCNet, as in the paper) and its augmented version.\n At inference time, decoding runs until the model produces \u0026ldquo;$\\to$ \u0026quot; token, indicating that it is expecting response from an API call next.\nToolformer currently does not support tool use in a chain (i.e. using the output of one tool as an input for another tool) or in an interactive way (i.e. adopt API response after human selection). Both are interesting future directions to expand the model for.\nCitation Cited as:\n Weng, Lilian. (Mar 2023). Prompt Engineering. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/.\n Or\n@article{weng2023prompt, title = \u0026quot;Prompt Engineering\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Mar\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/\u0026quot; } References [1] Zhao et al. \u0026ldquo;Calibrate Before Use: Improving Few-shot Performance of Language Models.\u0026quot; ICML 2021\n[2] Liu et al. \u0026ldquo;What Makes Good In-Context Examples for GPT-3?\u0026quot; arXiv preprint arXiv:2101.06804 (2021).\n[3] Lu et al. \u0026ldquo;Fantastically Ordered Prompts and Where to Find Them: Overcoming Few-Shot Prompt Order Sensitivity.\u0026quot; ACL 2022\n[4] Ye et al. \u0026ldquo;In-Context Instruction Learning.\u0026quot; arXiv preprint arXiv:2302.14691 (2023).\n[5] Su et al. \u0026ldquo;Selective annotation makes language models better few-shot learners.\u0026quot; arXiv preprint arXiv:2209.01975 (2022).\n[6] Rubin et al. \u0026ldquo;Learning to retrieve prompts for in-context learning.\u0026quot; NAACL-HLT 2022\n[7] Wei et al. \u0026ldquo;Chain of thought prompting elicits reasoning in large language models.\u0026quot; NeurIPS 2022\n[8] Wang et al. \u0026ldquo;Self-Consistency Improves Chain of Thought Reasoning in Language Models.\u0026quot; ICLR 2023.\n[9] Diao et al. \u0026ldquo;Active Prompting with Chain-of-Thought for Large Language Models.\u0026quot; arXiv preprint arXiv:2302.12246 (2023).\n[10] Zelikman et al. \u0026ldquo;STaR: Bootstrapping Reasoning With Reasoning.\u0026quot; arXiv preprint arXiv:2203.14465 (2022).\n[11] Ye \u0026amp; Durrett. \u0026ldquo;The unreliability of explanations in few-shot in-context learning.\u0026quot; arXiv preprint arXiv:2205.03401 (2022).\n[12] Trivedi et al. \u0026ldquo;Interleaving retrieval with chain-of-thought reasoning for knowledge-intensive multi-step questions.\u0026quot; arXiv preprint arXiv:2212.10509 (2022).\n[13] Press et al. \u0026ldquo;Measuring and narrowing the compositionality gap in language models.\u0026quot; arXiv preprint arXiv:2210.03350 (2022).\n[14] Yao et al. \u0026ldquo;ReAct: Synergizing reasoning and acting in language models.\u0026quot; ICLR 2023.\n[15] Fu et al. \u0026ldquo;Complexity-based prompting for multi-step reasoning.\u0026quot; arXiv preprint arXiv:2210.00720 (2022).\n[16] Wang et al. \u0026ldquo;Rationale-augmented ensembles in language models.\u0026quot; arXiv preprint arXiv:2207.00747 (2022).\n[17] Zhang et al. \u0026ldquo;Automatic chain of thought prompting in large language models.\u0026quot; arXiv preprint arXiv:2210.03493 (2022).\n[18] Shum et al. \u0026ldquo;Automatic Prompt Augmentation and Selection with Chain-of-Thought from Labeled Data.\u0026quot; arXiv preprint arXiv:2302.12822 (2023).\n[19] Zhou et al. \u0026ldquo;Large Language Models Are Human-Level Prompt Engineers.\u0026quot; ICLR 2023.\n[20] Lazaridou et al. \u0026ldquo;Internet augmented language models through few-shot prompting for open-domain question answering.\u0026quot; arXiv preprint arXiv:2203.05115 (2022).\n[21] Chen et al. \u0026ldquo;Program of Thoughts Prompting: Disentangling Computation from Reasoning for Numerical Reasoning Tasks.\u0026quot; arXiv preprint arXiv:2211.12588 (2022).\n[22] Gao et al. \u0026ldquo;PAL: Program-aided language models.\u0026quot; arXiv preprint arXiv:2211.10435 (2022).\n[23] Parisi et al. \u0026ldquo;TALM: Tool Augmented Language Models\u0026rdquo; arXiv preprint arXiv:2205.12255 (2022).\n[24] Schick et al. \u0026ldquo;Toolformer: Language Models Can Teach Themselves to Use Tools.\u0026quot; arXiv preprint arXiv:2302.04761 (2023).\n[25] Mialon et al. \u0026ldquo;Augmented Language Models: a Survey\u0026rdquo; arXiv preprint arXiv:2302.07842 (2023).\n","permalink":"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/","summary":"Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics.\nUseful resources:\n OpenAI Cookbook has many in-depth examples for how to utilize LLM efficiently. Prompt Engineering Guide repo contains a pretty comprehensive collection of education materials on prompt engineering.","title":"Prompt Engineering"},{"content":"Many new Transformer architecture improvements have been proposed since my last post on \u0026ldquo;The Transformer Family\u0026rdquo; about three years ago. Here I did a big refactoring and enrichment of that 2020 post \u0026mdash; restructure the hierarchy of sections and improve many sections with more recent papers. Version 2.0 is a superset of the old version, about twice the length.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size. $h$ The number of heads in multi-head attention layer. $L$ The segment length of input sequence. $N$ The total number of attention layers in the model; not considering MoE. $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$ The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size. $\\mathbf{W}^k \\in \\mathbb{R}^{d \\times d_k}$ The key weight matrix. $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_k}$ The query weight matrix. $\\mathbf{W}^v \\in \\mathbb{R}^{d \\times d_v}$ The value weight matrix. Often we have $d_k = d_v = d$. $\\mathbf{W}^k_i, \\mathbf{W}^q_i \\in \\mathbb{R}^{d \\times d_k/h}; \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ The weight matrices per head. $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ The output weight matrix. $\\mathbf{Q} = \\mathbf{X}\\mathbf{W}^q \\in \\mathbb{R}^{L \\times d_k}$ The query embedding inputs. $\\mathbf{K} = \\mathbf{X}\\mathbf{W}^k \\in \\mathbb{R}^{L \\times d_k}$ The key embedding inputs. $\\mathbf{V} = \\mathbf{X}\\mathbf{W}^v \\in \\mathbb{R}^{L \\times d_v}$ The value embedding inputs. $\\mathbf{q}_i, \\mathbf{k}_i \\in \\mathbb{R}^{d_k}, \\mathbf{v}_i \\in \\mathbb{R}^{d_v}$ Row vectors in query, key, value matrices, $\\mathbf{Q}$, $\\mathbf{K}$ and $\\mathbf{V}$. $S_i$ A collection of key positions for the $i$-th query $\\mathbf{q}_i$ to attend to. $\\mathbf{A} \\in \\mathbb{R}^{L \\times L}$ The self-attention matrix between a input sequence of lenght $L$ and itself. $\\mathbf{A} = \\text{softmax}(\\mathbf{Q}\\mathbf{K}^\\top / \\sqrt{d_k})$. $a_{ij} \\in \\mathbf{A}$ The scalar attention score between query $\\mathbf{q}_i$ and key $\\mathbf{k}_j$. $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ position encoding matrix, where the $i$-th row $\\mathbf{p}_i$ is the positional encoding for input $\\mathbf{x}_i$. Transformer Basics The Transformer (which will be referred to as \u0026ldquo;vanilla Transformer\u0026rdquo; to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later simplified Transformer was shown to achieve great performance in language modeling tasks, like in encoder-only BERT or decoder-only GPT.\nAttention and Self-Attention Attention is a mechanism in neural network that a model can learn to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.\nSelf-attention is a type of attention mechanism where the model makes prediction for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.\nThere are various forms of attention / self-attention, Transformer (Vaswani et al., 2017) relies on the scaled dot-product attention: given a query matrix $\\mathbf{Q}$, a key matrix $\\mathbf{K}$ and a value matrix $\\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:\n $$ \\text{attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q} {\\mathbf{K}}^\\top}{\\sqrt{d_k}})\\mathbf{V} $$ And for a query and a key vector $\\mathbf{q}_i, \\mathbf{k}_j \\in \\mathbb{R}^d$ (row vectors in query and key matrices), we have a scalar score:\n $$ a_{ij} = \\text{softmax}(\\frac{\\mathbf{q}_i {\\mathbf{k}_j}^\\top}{\\sqrt{d_k}}) = \\frac{\\exp(\\mathbf{q}_i {\\mathbf{k}_j}^\\top)}{ \\sqrt{d_k} \\sum_{r \\in \\mathcal{S}_i} \\exp(\\mathbf{q}_i {\\mathbf{k}_r}^\\top) } $$ where $\\mathcal{S}_i$ is a collection of key positions for the $i$-th query to attend to.\nSee my old post for other types of attention if interested.\nMulti-Head Self-Attention The multi-head self-attention module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.\n $$ \\begin{aligned} \\text{MultiHeadAttn}(\\mathbf{X}_q, \\mathbf{X}_k, \\mathbf{X}_v) \u0026= [\\text{head}_1; \\dots; \\text{head}_h] \\mathbf{W}^o \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{X}_v\\mathbf{W}^v_i) \\end{aligned} $$ where $[.;.]$ is a concatenation operation. $\\mathbf{W}^q_i, \\mathbf{W}^k_i \\in \\mathbb{R}^{d \\times d_k/h}, \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ are weight matrices to map input embeddings of size $L \\times d$ into query, key and value matrices. And $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ is the output linear transformation. All the weights should be learned during training.\nFig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in Vaswani, et al., 2017) Encoder-Decoder Architecture The encoder generates an attention-based representation with capability to locate a specific piece of information from a large context. It consists of a stack of 6 identity modules, each containing two submodules, a multi-head self-attention layer and a point-wise fully connected feed-forward network. By point-wise, it means that it applies the same linear transformation (with same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.\nThe function of Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future.\nFig. 2. The architecture of the vanilla Transformer model. (Image source: Figure 17) Positional Encoding Because self-attention operation is permutation invariant, it is important to use proper positional encoding to provide order information to the model. The positional encoding $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ has the same dimension as the input embedding, so it can be added on the input directly. The vanilla Transformer considered two types of encodings:\nSinusoidal Positional Encoding Sinusoidal positional encoding is defined as follows, given the token position $i=1,\\dots,L$ and the dimension $\\delta=1,\\dots,d$:\n $$ \\text{PE}(i,\\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ In this way each dimension of the positional encoding corresponds to a sinusoid of different wavelengths in different dimensions, from $2\\pi$ to $10000 \\cdot 2\\pi$.\nFig. 3. Sinusoidal positional encoding with $L=32$ and $d=128$. The value is between -1 (black) and 1 (white) and the value 0 is in gray. Learned Positional Encoding Learned positional encoding assigns each element with a learned column vector which encodes its absolute position (Gehring, et al. 2017) and furthermroe this encoding can be learned differently per layer (Al-Rfou et al. 2018).\nRelative Position Encoding Shaw et al. (2018)) incorporated relative positional information into $\\mathbf{W}^k$ and $\\mathbf{W}^v$. Maximum relative position is clipped to a maximum absolute value of $k$ and this clipping operation enables the model to generalize to unseen sequence lengths. Therefore, $2k + 1$ unique edge labels are considered and let us denote $\\mathbf{P}^k, \\mathbf{P}^v \\in \\mathbb{R}^{2k+1}$ as learnable relative position representations.\n $$ A_{ij}^k = P^k_{\\text{clip}(j - i, k)} \\quad A_{ij}^v = P^v_{\\text{clip}(j - i, k)} \\quad \\text{where }\\text{clip}(x, k) = \\text{clip}(x, -k, k) $$ Transformer-XL (Dai et al., 2019) proposed a type of relative positional encoding based on reparametrization of dot-product of keys and queries. To keep the positional information flow coherently across segments, Transformer-XL encodes the relative position instead, as it could be sufficient enough to know the position offset for making good predictions, i.e. $i-j$, between one key vector $\\mathbf{k}_{\\tau, j}$ and its query $\\mathbf{q}_{\\tau, i}$.\nIf omitting the scalar $1/\\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:\n $$ \\begin{aligned} a_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top = (\\mathbf{x}_i + \\mathbf{p}_i)\\mathbf{W}^q ((\\mathbf{x}_j + \\mathbf{p}_j)\\mathbf{W}^k)^\\top \\\\ \u0026= \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top \\end{aligned} $$ Transformer-XL reparameterizes the above four terms as follows:\n $$ a_{ij}^\\text{rel} = \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{content-based addressing} + \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{content-dependent positional bias} + \\underbrace{ \\color{red}{\\mathbf{u}} \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{global content bias} + \\underbrace{ \\color{red}{\\mathbf{v}} \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{global positional bias} $$ Replace $\\mathbf{p}_j$ with relative positional encoding $\\mathbf{r}_{i-j} \\in \\mathbf{R}^{d}$; Replace $\\mathbf{p}_i\\mathbf{W}^q$ with two trainable parameters $\\mathbf{u}$ (for content) and $\\mathbf{v}$ (for location) in two different terms; Split $\\mathbf{W}^k$ into two matrices, $\\mathbf{W}^k_E$ for content information and $\\mathbf{W}^k_R$ for location information. Rotary Position Embedding Rotary position embedding (RoPE; Su et al. 2021) encodes the absolution position with a rotation matrix and multiplies key and value matrices of every attention layer with it to inject relative positional information at every layer.\nWhen encoding relative positional information into the inner product of the $i$-th key and the $j$-th query, we would like to formulate the function in a way that the inner product is only about the relative position $i-j$. Rotary Position Embedding (RoPE) makes use of the rotation operation in Euclidean space and frames the relative position embedding as simply rotating feature matrix by an angle proportional to its position index.\nGiven a vector $\\mathbf{z}$, if we want to rotate it counterclockwise by $\\theta$, we can multiply it by a rotation matrix to get $R\\mathbf{z}$ where the rotation matrix $R$ is defined as:\n $$ R = \\begin{bmatrix} \\cos\\theta \u0026 -\\sin\\theta \\\\ \\sin\\theta \u0026 \\cos\\theta \\end{bmatrix} $$ When generalizing to higher dimensional space, RoPE divide the $d$-dimensional space into $d/2$ subspaces and constructs a rotation matrix $R$ of size $d \\times d$ for token at position $i$:\n $$ R^d_{\\Theta, i} = \\begin{bmatrix} \\cos i\\theta_1 \u0026 -\\sin i\\theta_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \u0026 0 \\\\ \\sin i\\theta_1 \u0026 \\cos i\\theta_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \u0026 0 \\\\ 0 \u0026 0 \u0026 \\cos i\\theta_2 \u0026 -\\sin i\\theta_2 \u0026 \\dots \u0026 0 \u0026 0 \\\\ 0 \u0026 0 \u0026 \\sin i\\theta_1 \u0026 \\cos i\\theta_1 \u0026 \\dots \u0026 0 \u0026 0 \\\\ \\vdots \u0026 \\vdots \u0026 \\vdots \u0026 \\vdots \u0026 \\ddots \u0026 \\vdots \u0026 \\vdots \\\\ 0 \u0026 0 \u0026 0 \u0026 0 \u0026 \\dots \u0026 \\cos i\\theta_{d/2} \u0026 -\\sin i\\theta_{d/2} \\\\ 0 \u0026 0 \u0026 0 \u0026 0 \u0026 \\dots \u0026 \\sin i\\theta_{d/2} \u0026 \\cos i\\theta_{d/2} \\\\ \\end{bmatrix} $$ where in the paper we have $\\Theta = {\\theta_i = 10000^{-2(i−1)/d}, i \\in [1, 2, \u0026hellip;, d/2]}$. Note that this is essentially equivalent to sinusoidal positional encoding but formulated as a rotation matrix.\nThen both key and query matrices incorporates the positional information by multiplying with this rotation matrix:\n $$ \\begin{aligned} \u0026 \\mathbf{q}_i^\\top \\mathbf{k}_j = (R^d_{\\Theta, i} \\mathbf{W}^q\\mathbf{x}_i)^\\top (R^d_{\\Theta, j} \\mathbf{W}^k\\mathbf{x}_j) = \\mathbf{x}_i^\\top\\mathbf{W}^q R^d_{\\Theta, j-i}\\mathbf{W}^k\\mathbf{x}_j \\\\ \u0026 \\text{ where } R^d_{\\Theta, j-i} = (R^d_{\\Theta, i})^\\top R^d_{\\Theta, j} \\end{aligned} $$ Fig. 4. Visual illustration of how rotary position embedding is implemented.(Image source: Su et al., 2021) Longer Context The length of an input sequence for transformer models at inference time is upper-bounded by the context length used for training. Naively increasing context length leads to high consumption in both time ($\\mathcal{O}(L^2d)$) and memory ($\\mathcal{O}(L^2)$) and may not be supported due to hardware constraints.\nThis section introduces several improvements in transformer architecture to better support long context at inference; E.g. using additional memory, design for better context extrapolation, or recurrency mechanism.\nContext Memory The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments. This context segmentation causes several issues:\n The model cannot capture very long term dependencies. It is hard to predict the first few tokens in each segment given no or thin context. The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens. Transformer-XL (Dai et al., 2019; \u0026ldquo;XL\u0026rdquo; means \u0026ldquo;extra long\u0026rdquo;) modifies the architecture to reuse hidden states between segments with an additional memory. The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.\nFig. 5. A comparison between the training phrase of vanilla Transformer \u0026 Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019). Let\u0026rsquo;s label the hidden state of the $n$-th layer for the $(\\tau + 1)$-th segment in the model as $\\mathbf{h}_{\\tau+1}^{(n)} \\in \\mathbb{R}^{L \\times d}$. In addition to the hidden state of the last layer for the same segment $\\mathbf{h}_{\\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment $\\mathbf{h}_{\\tau}^{(n)}$. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.\n $$ \\begin{aligned} \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \u0026= [\\text{stop-gradient}(\\mathbf{h}_{\\tau}^{(n-1)}) \\circ \\mathbf{h}_{\\tau+1}^{(n-1)}] \\\\ \\mathbf{Q}_{\\tau+1}^{(n)} \u0026= \\mathbf{h}_{\\tau+1}^{(n-1)}\\mathbf{W}^q \\\\ \\mathbf{K}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^k \\\\ \\mathbf{V}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^v \\\\ \\mathbf{h}_{\\tau+1}^{(n)} \u0026= \\text{transformer-layer}(\\mathbf{Q}_{\\tau+1}^{(n)}, \\mathbf{K}_{\\tau+1}^{(n)}, \\mathbf{V}_{\\tau+1}^{(n)}) \\end{aligned} $$ Note that both keys and values rely on extended hidden states, while queries only consume hidden states at the current step. The concatenation operation $[. \\circ .]$ is along the sequence length dimension. And Transformer-XL needs to use relative positional encoding because previous and current segments would be assigned with the same encoding if we encode absolute positions, which is undesired.\nCompressive Transformer (Rae et al. 2019) extends Transformer-XL by compressing past memories to support longer sequences. It explicitly adds memory slots of size $m_m$ per layer for storing past activations of this layer to preserve long context. When some past activations become old enough, they are compressed and saved in an additional compressed memory of size $m_{cm}$ per layer.\nFig. 6. Compressive transformer maintains two types of memory slots, memory and compressed memory, to support long context. (Image source: Rae et al. 2019). Both memory and compressed memory are FIFO queues. Given the model context length $L$, the compression function of compression rate $c$ is defined as $f_c: \\mathbb{R}^{L \\times d} \\to \\mathbb{R}^{[\\frac{L}{c}] \\times d}$, mapping $L$ oldest activations to $[\\frac{L}{c}]$ compressed memory elements. There are several choices of compression functions:\n Max/mean pooling of kernel and stride size $c$; 1D convolution with kernel and stride size $c$ (need to learn additional parameters); Dilated convolution (need to learn additional parameters). In their experiments, convolution compression works out the best on EnWik8 dataset; Most used memories. Compressive transformer has two additional training losses:\n Auto-encoding loss (lossless compression objective) measures how well we can reconstruct the original memories from compressed memories\n $$ \\mathcal{L}_{ac} = \\| \\textbf{old_mem}^{(i)} - g(\\textbf{new_cm}^{(i)}) \\|_2 $$ where $g: \\mathbb{R}^{[\\frac{L}{c}] \\times d} \\to \\mathbb{R}^{L \\times d}$ reverses the compression function $f$. Attention-reconstruction loss (lossy objective) reconstructs content-based attention over memory vs compressed memory and minimize the difference:\n $$ \\mathcal{L}_{ar} = \\|\\text{attn}(\\mathbf{h}^{(i)}, \\textbf{old_mem}^{(i)}) − \\text{attn}(\\mathbf{h}^{(i)}, \\textbf{new_cm}^{(i)})\\|_2 $$ Transformer-XL with a memory of size $m$ has a maximum temporal range of $m \\times N$, where $N$ is the number of layers in the model, and attention cost $\\mathcal{O}(L^2 + Lm)$. In comparison, compressed transformer has a temporal range of $(m_m + c \\cdot m_{cm}) \\times N$ and attention cost $\\mathcal{O}(L^2 + L(m_m + m_{cm}))$. A larger compression rate $c$ gives better tradeoff between temporal range length and attention cost.\nAttention weights, from oldest to newest, are stored in three locations: compressed memory → memory → causally masked sequence. In the experiments, they observed an increase in attention weights from oldest activations stored in the regular memory, to activations stored in the compressed memory, implying that the network is learning to preserve salient information.\nFig. 7. Attention weights with one standard deviation as error bars versus memory positions, from oldest (left) to newest (right). (Image source: Rae et al. 2019). Non-Differentiable External Memory $k$NN-LM (Khandelwal et al. 2020) enhances a pretrained LM with a separate $k$NN model by linearly interpolating the next token probabilities predicted by both models. The $k$NN model is built upon an external key-value store which can store any large pre-training dataset or OOD new dataset. This datastore is preprocessed to save a large number of pairs, (LM embedding representation of context, next token) and the nearest neighbor retrieval happens in the LM embedding space. Because the datastore can be gigantic, we need to rely on libraries for fast dense vector search such as FAISS or ScaNN. The indexing process only happens once and parallelism is easy to implement at inference time.\nAt inference time, the next token probability is a weighted sum of two predictions:\n $$ \\begin{aligned} p(y \\vert \\mathbf{x}) \u0026= \\lambda \\; p_\\text{kNN}(y \\vert \\mathbf{x}) + (1- \\lambda) \\; p_\\text{LM}(y \\vert \\mathbf{x}) \\\\ p_\\text{kNN}(y \\vert \\mathbf{x}) \u0026\\propto \\sum_{(k_i, w_i) \\in \\mathcal{N}} \\mathbb{1}[y = w_i] \\exp(-d(k_i, f(\\mathbf{x}))) \\end{aligned} $$ where $\\mathcal{N}$ contains a set of nearest neighbor data points retrieved by $k$NN; $d(., .)$ is a distance function such as L2 distance.\nAccording to the experiments, larger datastore size or larger $k$ is correlated with better perplexity. The weighting scalar $\\lambda$ should be tuned, but in general it is expected to be larger for out-of-domain data compared to in-domain data and larger datastore can afford a larger $\\lambda$.\nSPALM (Adaptive semiparametric language models; Yogatama et al. 2021) incorporates both (1) Transformer-XL style memory for hidden states from external context as short-term memory and (2) $k$NN-LM style key-value store as long memory.\nFig. 8. Illustration of how SPALM combines context memory of past hidden states (short term memory) with an external key-value datastore (long term memory) to support longer context. (Image source: Yogatama et al. 2021). SPALM runs $k$NN search to fetch $k$ tokens with most relevant context. For each token we can get the same embedding representation provided by a pretrained LM, denoted as $\\{\\mathbf{y}_i\\}_{i=1}^k$. The gating mechanism first aggregates the retrieved token embeddings with a simple attention layer using $\\mathbf{h}^R_t$ (the hidden state for token $x_t$ at layer $R$) as a query and then learns a gating parameter $\\mathbf{g}_t$ to balance between local information $\\mathbf{h}^R_t$ and long-term information $\\mathbf{m}_t$.\n $$ \\begin{aligned} \\mathbf{m}_t \u0026= \\sum_{i=1}^k \\frac{\\exp(\\mathbf{y}_i^\\top \\mathbf{h}^R_t)}{\\sum_{j=1}^k \\exp(\\mathbf{y}_j^\\top \\mathbf{h}^R_t)} \\cdot \\mathbf{y}_i \\\\ \\mathbf{g}_t \u0026= \\sigma(\\mathbf{w}_g^\\top \\mathbf{h}_t^R) \\\\ \\mathbf{z}_t \u0026= (1 - \\mathbf{g}_t) \\odot \\mathbf{m}_t + \\mathbf{g}_t \\odot \\mathbf{h}^R_t \\\\ p(x_{t+1}\\mid \\mathbf{x}_{\\leq t}) \u0026= \\text{softmax}(\\mathbf{z}_t; \\mathbf{W}) \\end{aligned} $$ where $\\mathbf{w}_g$ is a parameter vector to learn; $\\sigma(.)$ is sigmoid; $\\mathbf{W}$ is the word embedding matrix shared between both input and output tokens. Different from $k$NN-LM, they didn\u0026rsquo;t find the nearest neighbor distance to be helpful in the aggregation of retrieved tokens.\nDuring training, the key representations in the long-term memory stay constant, produced by a pretrained LM, but the value encoder, aka the word embedding matrix, gets updated.\nMemorizing Transformer (Wu et al. 2022) adds a $k$NN-augmented attention layer near the top stack of a decoder-only Transformer. This special layer maintains a Transformer-XL style FIFO cache of past key-value pairs.\nThe same QKV values are used for both local attention and $k$NN mechanisms. The $k$NN lookup returns top-$k$ (key, value) pairs for each query in the input sequence and then they are processed through the self-attention stack to compute a weighted average of retrieved values. Two types of attention are combined with a learnable per-head gating parameter. To prevent large distributional shifts in value magnitude, both keys and values in the cache are normalized.\nWhat they found during experiments with Memorizing Transformer:\n It is observed in some experiments that training models with a small memory and then finetuned with a larger memory works better than training with a large memory from scratch. The smaller Memorizing Transformer with just 8k tokens in memory can match the perplexity of a larger vanilla Transformer with 5X more trainable parameters. Increasing the size of external memory provided consistent gains up to a size of 262K. A non-memory transformer can be finetuned to use memory. Fig. 9. Fine-tuning a vanilla Transformer with a key-value memory can achieve similar performance as training a memorizing transformer from scratch. (Image source: Wu et al. 2022). Distance-Enhanced Attention Scores Distance Aware Transformer(DA-Transformer; Wu, et al. 2021) and Attention with Linear Biases (ALiBi; Press et al. 2022) are motivated by similar ideas \u0026mdash; in order to encourage the model to extrapolate over longer context than what the model is trained on, we can explicitly attach the positional information to every pair of attention score based on the distance between key and query tokens.\nNote that the default positional encoding in vanilla Transformer only adds positional information to the input sequence, while later improved encoding mechanisms alter attention scores of every layer, such as rotary position embedding, and they take on form very similar to distance enhanced attention scores.\nDA-Transformer (Wu, et al. 2021) multiplies attention scores at each layer by a learnable bias that is formulated as a function of the distance between key and query. Different attention heads use different parameters to distinguish diverse preferences to short-term vs long-term context. Given two positions, $i, j$, DA-Transformer uses the following weighting function to alter the self-attention score:\n $$ \\begin{aligned} \\mathbf{R}^{(i)} \u0026= \\alpha_i \\mathbf{R} \\quad \\text{where }R_{ij} = \\vert i-j \\vert\\\\ f(\\mathbf{R}^{(i)}; \\beta_i) \u0026= \\frac{1 + \\exp(\\beta_i)}{1 + \\exp(\\beta_i - \\mathbf{R}^{(i)})} \\\\ \\text{attn}(\\mathbf{Q}^{(i)}, \\mathbf{K}^{(i)}, \\mathbf{V}^{(i)}) \u0026= \\text{row-softmax}\\Big(\\frac{\\text{ReLU}(\\mathbf{Q}^{(i)}\\mathbf{K}^{(i)\\top})f(\\mathbf{R}^{(i)})}{\\sqrt{d}}\\Big) \\mathbf{V}^{(i)} \\end{aligned} $$ where $\\alpha_i$ is a learnable parameters to weight relative distance differently per head where the head is indexed by superscript $^{(i)}$; $\\beta_i$ is a learnable parameter to control the upper bound and ascending slope wrt the distance for the $i$-th attention head. The weighting function $f(.)$ is designed in a way that: (1) $f(0)=1$; (2) $f(\\mathbf{R}^{(i)}) = 0$ when $\\mathbf{R}^{(i)} \\to -\\infty$; (3) $f(\\mathbf{R}^{(i)})$ is bounded when $\\mathbf{R}^{(i)} \\to +\\infty$; (4) the scale is tunable; (5) and the function is monotonic. The extra time complexity brought by $f(\\mathbf{R}^{(i)})$ is $\\mathcal{O}(L^2)$ and it is small relative to the self attention time complexity $\\mathcal{O}(L^2 d)$. The extra memory consumption is minimal, ~$\\mathcal{O}(2h)$.\nInstead of multipliers, ALiBi (Press et al. 2022) adds a constant bias term on query-key attention scores, proportional to pairwise distances. The bias introduces a strong recency preference and penalizes keys that are too far away. The penalties are increased at different rates within different heads. $$ \\text{softmax}(\\mathbf{q}_i \\mathbf{K}^\\top + \\alpha_i \\cdot [0, -1, -2, \\dots, -(i-1)]) $$ where $\\alpha_i$ is a head-specific weighting scalar. Different from DA-transformer, $\\alpha_i$ is not learned but fixed as a geometric sequence; for example, for 8 heads, ${\\alpha_i} = {\\frac{1}{2}, \\frac{1}{2^2}, \\dots, \\frac{1}{2^8}}$. The overall idea is very much similar to what relative positional encoding aims to solve.\nFig. 10. Illustration of how ALiBi enhances attention scores with a positional bias term. (Image source: Press et al. 2021). With ALiBi, Press et al. (2022) trained a 1.3B model on context length 1024 during training and extrapolated to 2046 at inference time.\nFig. 11. Extrapolation experiments for running inference with Transformers of different configs, including sinusoidal positional encoding, rotary positional encoding, simplified relative positional encoding in T5 and ALiBi. All models were trained with small context length but inference ran for much longer context. (Image source: Press et al. 2021). Make it Recurrent Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN. Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, an Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.\nOn a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.\nFig. 12. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019). Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\\mathbf{h}^t \\in \\mathbb{R}^{L \\times d}$ at step $t$ for an adjustable number of steps. At step 0, $\\mathbf{h}^0$ is initialized to be same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.\n $$ \\begin{aligned} \\mathbf{A}^t \u0026= \\text{LayerNorm}(\\mathbf{h}^{t-1} + \\text{MultiHeadAttention}(\\mathbf{h}^{t-1} + \\mathbf{P}^t) \\\\ \\mathbf{h}^t \u0026= \\text{LayerNorm}(\\mathbf{A}^{t-1} + \\text{Transition}(\\mathbf{A}^t)) \\end{aligned} $$ where $\\text{Transition}(.)$ is either a separable convolution or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of $\\mathbf{A}^t$ individually) affine transformation + one ReLU.\nThe positional encoding $\\mathbf{P}^t$ uses sinusoidal position signal but with an additional time dimension:\n $$ \\text{PE}(i, t, \\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\sin(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\cos(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ Fig. 13. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation $\\mathbf{h}^T$. (Image source: Figure 2 in Dehghani, et al. 2019) In the adaptive version of Universal Transformer, the number of recurrent steps $T$ is dynamically determined by ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.\nAdaptive Modeling Adaptive modeling refers to a mechanism that can adjust the amount of computation according to different inputs. For example, some tokens may only need local information and thus demand a shorter attention span; Or some tokens are relatively easier to predict and do not need to be processed through the entire attention stack.\nAdaptive Attention Span One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further sometime than others; or one attention head may had different attention pattern from the other. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.\nThis is the motivation for Adaptive Attention Span. Sukhbaatar et al (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 14) and thus the optimal span would be trained separately per head.\nFig. 14. Two attention heads in the same model, A \u0026 B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B look further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019) Given the $i$-th token, we need to compute the attention weights between this token and other keys within its attention span of size $s$:\n $$ \\begin{aligned} e_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top \\\\ a_{ij} \u0026= \\text{softmax}(e_{ij}) = \\frac{\\exp(e_{ij})}{\\sum_{r=i-s}^{i-1} \\exp(e_{ir})} \\\\ \\mathbf{y}_i \u0026= \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{v}_r = \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{x}_r\\mathbf{W}^v \\end{aligned} $$ A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. $m_z$ is parameterized by $z \\in [0, s]$ and $z$ is to be learned:\n $$ m_z(x) = \\text{clip}(\\frac{1}{R}(R+z-x), 0, 1) $$ where $R$ is a hyper-parameter which defines the softness of $m_z$.\nFig. 15. The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.) The soft mask function is applied to the softmax elements in the attention weights:\n $$ a_{ij} = \\frac{m_z(i-j)\\exp(s_{ij})}{\\sum_{r=i-s}^{i-1}m_z(i-r) \\exp(s_{ir})} $$ In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i=1, \\dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\\sum_{i=1}^h z^{(i)}$.\nUsing Adaptive Computation Time, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S \\sigma(\\mathbf{v} \\cdot \\mathbf{x}_t +b)$, where the vector $\\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.\nIn the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.\nDepth-Adaptive Transformer At inference time, it is natural to assume that some tokens are easier to predict and thus do not require as much computation as others. Therefore we may only process its prediction through a limited number of layers to achieve a good balance between speed and performance.\nBoth Depth-Adaptive Transformer (Elabyad et al. 2020) and Confident Adaptive Language Model (CALM; Schuster et al. 2022) are motivated by this idea and learn to predict optimal numbers of layers needed for different input tokens.\nDepth-adaptive transformer (Elabyad et al. 2020) attaches an output classifier to every layer to produce exit predictions based on activations of that layer. The classifier weight matrices can be different per layer or shared across layers. During training, the model sample different sequences of exits such that the model is optimized with hidden states of different layers. The learning objective incorporates likelihood probabilities predicted at different layers, $n=1, \\dots, N$:\n $$ \\text{LL}^n_t = \\log p(y_t \\vert \\mathbf{h}^n_{t-1}) \\quad \\text{LL}^n = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} LL^n_t $$ Adaptive depth classifiers outputs a parametric distribution $q_t$. It is trained with cross entropy loss against an oracle distribution $q^*_t$. The paper explored three confiurations for how to learn such a classifier $q_t$.\nFig. 16. Illustration of three types of adaptive depth classifiers. (Image source: Elabyad et al. 2020). Sequence-specific depth classifier: All tokens of the same sequence share the same exit block. It depends on the average of the encoder representation of the sequence. Given an input sequence $\\mathbf{x}$ of length $L$, the classifier takes $\\bar{\\mathbf{x}} = \\frac{1}{L} \\sum_{t=1}^L \\mathbf{x}_t$ as input and outputs a multinomial distribution of $N$ dimensions, corresponding to $N$ layers.\n $$ \\begin{aligned} q(n \\vert \\mathbf{x}) \u0026=\\text{softmax}(\\mathbf{W}_n \\bar{\\mathbf{x}} + b_n) \\in \\mathbb{R}^N \\\\ q_\\text{lik}^*(\\mathbf{x}, \\mathbf{y}) \u0026= \\delta(\\arg\\max_n \\text{LL}^n - \\lambda n) \\\\ \\text{or }q_\\text{corr}^*(\\mathbf{x}, \\mathbf{y}) \u0026= \\delta(\\arg\\max_n C^n - \\lambda n) \\text{ where }C^n = \\vert\\{t \\vert y_t = \\arg\\max_y p(y \\vert \\mathbf{h}^n_{t-1})\\}\\vert \\\\ \\end{aligned} $$ where $\\delta$ is dirac delta (unit impulse) function and $-\\lambda n$ is a regularization term to encourage lower layer exits. The ground truth $q^*$ can be prepared in two way, based on maximum likelihood $q_\\text{lik}^*$ or correctness $q_\\text{corr}^*$. \n Token-specific depth classifier (multinomial): Each token is decoded with different exit block, predicted conditioned on the first decoder hidden state $\\mathbf{h}^1_t$:\n $$ q_t(n \\vert \\mathbf{x}, \\mathbf{y}_{ Token-specific depth classifier (geometric-like): A binary exit prediction distribution is made per layer per token, $\\mathcal{X}^n_t$. The RBF kernel $\\kappa(t, t’) = \\exp(\\frac{\\vert t - t’ \\vert^2}{\\sigma})$ is used to smooth the predictions to incorporate the impact of current decision on future time steps.\n $$ \\begin{aligned} \\mathcal{X}^n_t \u0026= \\text{sigmoid}(\\mathbf{w}_n^\\top \\mathbf{h}^n_t + b_n)\\quad \\forall n \\in [1, \\dots, N-1] \\\\ q_t(n \\vert \\mathbf{x}, \\mathbf{y}_{ At inference time, the confidence threshold for making an exit decision needs to be calibrated. Depth-adaptive transformer finds such a threshold on a validation set via grid search. CALM (Schuster et al. 2022) applied the Learn then Test (LTT) framework (Angelopoulos et al. 2021) to identify a subset of valid thresholds and chose the minimum value as the threshold for inference. Except for training per-layer exit classifier, CALM also explored other methods for adaptive depth prediction, including the softmax responses (i.e. difference between top two softmax outputs) and hidden state saturation (i.e. $\\cos(\\mathbf{h}^n_t, \\mathbf{h}^{n+1}_t)$) as confidence scores for exit decisions. They found softmax responses result in best inference speedup.\nEfficient Attention The computation and memory cost of the vanilla Transformer grows quadratically with sequence length and hence it is hard to be applied on very long sequences. Many efficiency improvements for Transformer architecture have something to do with the self-attention module - making it cheaper, smaller or faster to run. See the survey paper on Efficient Transformers (Tay et al. 2020).\nSparse Attention Patterns Fixed Local Context A simple alternation to make self-attention less expensive is to restrict the attention span of each token to local context only, so that self-attention grows linearly with the sequence length.\nThe idea was introduced by Image Transformer (Parmer, et al 2018), which formulates image generation as sequence modeling using an encoder-decoder transformer architecture:\n The encoder generates a contextualized, per-pixel-channel representation of the source image; Then the decoder autoregressively generates an output image, one channel per pixel at each time step. Let\u0026rsquo;s label the representation of the current pixel to be generated as the query $\\mathbf{q}$. Other positions whose representations will be used for computing $\\mathbf{q}$ are key vector $\\mathbf{k}_1, \\mathbf{k}_2, \\dots$ and they together form a memory matrix $\\mathbf{M}$. The scope of $\\mathbf{M}$ defines the context window for pixel query $\\mathbf{q}$.\nImage Transformer introduced two types of localized $\\mathbf{M}$, as illustrated below.\nFig. 17. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in Parmer et al, 2018) 1D Local Attention: The input image is flattened in the raster scanning order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as $\\mathbf{q}$ and a fixed number of additional pixels generated before this query block.\n 2D Local Attention: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.\n Strided Context Sparse Transformer (Child et al., 2019) introduced factorized self-attention, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.\nGiven a set of attention connectivity pattern $\\mathcal{S} = \\{S_1, \\dots, S_n\\}$, where each $S_i$ records a set of key positions that the $i$-th query vector attends to.\n $$ \\begin{aligned} \\text{Attend}(\\mathbf{X}, \\mathcal{S}) \u0026= \\Big( a(\\mathbf{x}_i, S_i) \\Big)_{i \\in \\{1, \\dots, L\\}} \\\\ \\text{ where } a(\\mathbf{x}_i, S_i) \u0026= \\text{softmax}\\Big(\\frac{(\\mathbf{x}_i \\mathbf{W}^q)(\\mathbf{x}_j \\mathbf{W}^k)_{j \\in S_i}^\\top}{\\sqrt{d_k}}\\Big) (\\mathbf{x}_j \\mathbf{W}^v)_{j \\in S_i} \\end{aligned} $$ Note that although the size of $S_i$ is not fixed, $a(\\mathbf{x}_i, S_i)$ is always of size $d_v$ and thus $\\text{Attend}(\\mathbf{X}, \\mathcal{S}) \\in \\mathbb{R}^{L \\times d_v}$.\nIn anto-regressive models, one attention span is defined as $S_i = \\{j: j \\leq i\\}$ as it allows each token to attend to all the positions in the past.\nIn factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \\leq i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.\nPrecisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A^{(m)}_i \\subset S_i, m = 1,\\dots, p$. Therefore the path between the output position $i$ and any $j$ has a maximum length $p + 1$. For example, if $(j, a, b, c, \\dots, i)$ is a path of indices between $i$ and $j$, we would have $j \\in A_a^{(1)}, a \\in A_b^{(2)}, b \\in A_c^{(3)}, \\dots$, so on and so forth.\nSparse Factorized Attention\nSparse Transformer proposed two types of fractorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.\nFig. 18. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few of extra annotations.) Strided attention with stride $\\ell \\sim \\sqrt{n}$. This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\\ell$ pixels in the raster scanning order (naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{ t, t+1, \\dots, i\\} \\text{, where } t = \\max(0, i - \\ell) \\\\ A_i^{(2)} \u0026= \\{j: (i-j) \\mod \\ell = 0\\} \\end{aligned} $$ Fixed attention. A small set of tokens summarize previous locations and propagate that information to all future locations.\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{j: \\lfloor \\frac{j}{\\ell} \\rfloor = \\lfloor \\frac{i}{\\ell} \\rfloor \\} \\\\ A_i^{(2)} \u0026= \\{j: j \\mod \\ell \\in \\{\\ell-c, \\dots, \\ell-1\\} \\} \\end{aligned} $$ where $c$ is a hyperparameter. If $c=1$, it restricts the representation whereas many depend on a few positions. The paper chose $c\\in \\{ 8, 16, 32 \\}$ for $\\ell \\in \\{ 128, 256 \\}$.\n Use Factorized Self-Attention in Transformer\nThere are three ways to use sparse factorized attention patterns in Transformer architecture:\n One attention type per residual block and then interleave them, $\\text{attn}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, A^{(n \\mod p)}) \\mathbf{W}^o$, where $n$ is the index of the current residual block. Set up a single head which attends to locations that all the factorized heads attend to, $\\text{attn}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, \\cup_{m=1}^p A^{(m)}) \\mathbf{W}^o $. Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt a pattern presented above, 1 or 2. $\\rightarrow$ This option often performs the best. Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention \u0026amp; FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the paper for more details or my previous post on techniques for scaling up model training.\nBlockwise Attention (Qiu et al. 2019) introduces a sparse block matrix to only allow each token to attend to a small set of other tokens. Each attention matrix of size $L \\times L$ is partitioned into $n \\times n$ smaller blocks of size $\\frac{L}{n}\\times\\frac{L}{n}$ and a sparse block matrix $\\mathbf{M} \\in \\{0, 1\\}^{L \\times L}$ is defined by a permutation $\\pi$ of ${1, \\dots, n}$, which records the column index per row in the block matrix.\n $$ \\begin{aligned} \\text{attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}, \\mathbf{M}) \u0026= \\text{softmax}\\Big(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d}} \\odot \\mathbf{M}\\Big)\\mathbf{V} \\\\ (\\mathbf{A} \\odot \\mathbf{M})_{ij} \u0026= \\begin{cases} A_{ij} \u0026 \\text{if }M_{ij} = 1 \\\\ -\\infty \u0026 \\text{if }M_{ij} = 0 \\\\ \\end{cases} \\\\ \\text{where } M_{ij} \u0026= \\begin{cases} 1 \u0026 \\text{if }\\pi\\big(\\lfloor\\frac{(i-1)n}{L} + 1\\rfloor\\big) = \\lfloor\\frac{(j-1)n}{L} + 1\\rfloor \\\\ 0 \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ The actual implementation of Blockwise Attention only stores QKV as block matrices, each of size $n\\times n$:\n $$ \\text{Blockwise-attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}, \\mathbf{M}) = \\begin{bmatrix} \\text{softmax}\\big(\\frac{\\hat{\\mathbf{q}}_1\\hat{\\mathbf{k}}_{\\pi(1)}^\\top}{\\sqrt{d}} \\Big)\\hat{\\mathbf{v}}_{\\pi(1)} \\\\ \\vdots \\\\ \\text{softmax}\\big(\\frac{\\hat{\\mathbf{q}}_n\\hat{\\mathbf{k}}_{\\pi(n)}^\\top}{\\sqrt{d}} \\odot \\Big)\\hat{\\mathbf{v}}_{\\pi(n)} \\\\ \\end{bmatrix} $$ where $\\hat{\\mathbf{q}}_i$, $\\hat{\\mathbf{k}}_i$ and $\\hat{\\mathbf{v}}_i$ are the $i$-the row in the QKV block matrix respectively. Each $\\mathbf{q}_i\\mathbf{k}_{\\pi(i)}^\\top, \\forall i = 1, \\dots, n$ is of size $\\frac{N}{n}\\times\\frac{N}{n}$ and therefore Blockwise Attention is able to reduce the memory complexity of attention matrix from $\\mathcal{O}(L^2)$ to $\\mathcal{O}(\\frac{L}{n}\\times\\frac{L}{n} \\times n) = \\mathcal{O}(L^2/n)$.\nCombination of Local and Global Context ETC (Extended Transformer Construction; Ainslie et al. 2019), Longformer (Beltagy et al. 2020) and Big Bird (Zaheer et al. 2020) models combine both local and global context when building an attention matrix. All these models can be initialized from existing pretrained models.\nGlobal-Local Attention of ETC (Ainslie et al. 2019) takes two inputs, (1) the long input $\\mathbf{x}^l$ of size $n_l$ which is the regular input sequence and (2) the global input $\\mathbf{x}^g$ of size $n_g$ which contains a smaller number of auxiliary tokens, $n_g \\ll n_l$. Attention is thus split into four components based on directional attention across these two inputs: g2g, g2l, l2g and l2l. Because the l2l attention piece can be very large, it is restricted to a fixed size attention span of radius $w$ (i.e. local attention span) and the l2l matrix can be reshaped to $n_l \\times (2w+1)$.\nETC utilizes four binary matrices to handle structured inputs, $\\mathbf{M}^{g2g}$, $\\mathbf{M}^{g2l}$, $\\mathbf{M}^{l2g}$ and $\\mathbf{M}^{l2l}$. For example, each element $z^g_i \\in \\mathbb{R}^d$ in the attention output $z^g = (z^g_1, \\dots, z^g_{n_g})$ for g2g attention piece is formatted as:\n $$ \\begin{aligned} a^{g2g}_{ij} = \\frac{1}{\\sqrt{d}} x^g_i \\mathbf{W}^Q (x^g_j \\mathbf{W}^K + P^K_{ij})^\\top - (1- M^{g2g}_{ij})C \\\\ A^{g2g}_{ij} = \\frac{\\exp(a^{g2g}_{ij})}{\\sum_{k=1}^{n_g} \\exp(a^{g2g}_{ik})} \\quad z^g_i = \\sum^{n_g}_{j=1} A^{g2g}_{ij} x^g_j \\mathbf{W}^V \\end{aligned} $$ where $P^K_{ij}$ is a learnable vector for relative position encoding and $C$ is a very large constant ($C=10000$ in the paper) to offset any attention weights when mask is off.\nFig. 19. Attention patterns of ETC, Longformer and Big Bird. One more update in ETC is to incorporate a CPC (contrastive predictive coding) task using NCE loss into the pretraining stage, besides the MLM task: The representation of one sentence should be similar to the representation of context around it when this sentence is masked.\nThe global input $\\mathbf{x}^g$ for ETC is constructed as follows: Assuming there are some segments within the long inputs (e.g. by sentence), each segment is attached with one auxiliary token to learn global inputs. Relative position encoding is used to mark the global segment tokens with the token position. Hard masking in one direction (i.e., tokens before vs after are labeled differently) is found to bring performance gains in some datasets.\nAttention pattern in Longformer contains three components:\n Local attention: Similar to ETC, local attention is controlled by a sliding window of fixed size $w$; Global attention of preselected tokens: Longformer has a few pre-selected tokens (e.g. [CLS] token) assigned with global attention span, that is, attending to all other tokens in the input sequence. Dilated attention: Dilated sliding window of fixed size $r$ and gaps of dilation size $d$, similar to Sparse Transformer; Big Bird is quite similar to Longformer, equipped with both local attention and a few preselected tokens with global attention span, but Big Bird replaces dilated attention with a new mechanism where all tokens attend to a set of random tokens. The design is motivated by the fact that attention pattern can be viewed as a directed graph and a random graph has the property that information is able to rapidly flow between any pair of nodes.\nLongformer uses smaller window size at lower layers and larger window sizes at higher layers. Ablation studies showed that this setup works better than reversed or fixed size config. Lower layers do not have dilated sliding windows to better learn to use immediate local context. Longformer also has a staged training procedure where initially the model is trained with small window size to learn from local context and then subsequent stages of training have window sizes increased and learning rate decreased.\nContent-based Attention The improvements proposed by Reformer (Kitaev, et al. 2020) aim to solve the following pain points in vanilla Transformer:\n Quadratic time and memory complexity within self-attention module. Memory in a model with $N$ layers is $N$-times larger than in a single-layer model because we need to store activations for back-propagation. The intermediate FF layers are often quite large. Reformer proposed two main changes:\n Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $\\mathcal{O}(L^2)$ to $\\mathcal{O}(L\\log L)$. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers). Locality-Sensitive Hashing Attention\nIn $\\mathbf{Q} \\mathbf{K}^\\top$ part of the attention formula, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query $\\mathbf{q}_i \\in \\mathbf{Q}$, we are looking for row vectors in $\\mathbf{K}$ closest to $\\mathbf{q}_i$. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.\nA hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. The Reformer adopts a hashing scheme as such, given a fixed random matrix $\\mathbf{R} \\in \\mathbb{R}^{d \\times b/2}$ (where $b$ is a hyperparam), the hash function is $h(x) = \\arg\\max([xR; −xR])$.\n$$ \\mathbf{o}_i = \\sum_{j \\in S_i} \\exp(\\mathbf{q}_i \\cdot \\mathbf{k}_j - Z(i, S_i)) \\mathbf{v}_j \\text{, where } S_i = \\{j: j \\leq i\\} $$ -- Fig. 20. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in Kitaev, et al. 2020). In LSH attention, a query can only attend to positions in the same hashing bucket, $S_i = \\{j: h(\\mathbf{q}_i) = h(\\mathbf{k}_j)\\}$. It is carried out in the following process, as illustrated in Fig. 20:\n (a) The attention matrix for full attention is often sparse. (b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets. (c) Set $\\mathbf{Q} = \\mathbf{K}$ (precisely $\\mathbf{k}_j = \\mathbf{q}_j / |\\mathbf{q}_j|$), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this \u0026ldquo;shared-QK\u0026rdquo; config does not affect the performance of the Transformer. (d) Apply batching where chunks of $m$ consecutive queries are grouped together. Fig. 21. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020). Reversible Residual Network\nAnother improvement by Reformer is to use reversible residual layers (Gomez et al. 2017). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.\nGiven a layer $x \\mapsto y$, the normal residual layer does $y = x + F(x)$, but the reversible layer splits both input and output into pairs $(x_1, x_2) \\mapsto (y_1, y_2)$ and then executes the following:\n $$ y_1 = x_1 + F(x_2),\\; y_2 = x_2 + G(y_1) $$ and reversing is easy:\n $$ x_2 = y_2 - G(y_1), \\; x_1 = y_1 − F(x_2) $$ Reformer applies the same idea to Transformer by combination attention ($F$) and feed-forward layers ($G$) within a reversible net block:\n $$ Y_1 = X_1 + \\text{Attention}(X_2), \\; Y_2 = X_2 + \\text{FeedForward}(Y_1) $$ The memory can be further reduced by chunking the feed-forward computation:\n $$ Y_2 = [Y_2^{(1)}; \\dots; Y_2^{(c)}] = [X_2^{(1)} + \\text{FeedForward}(Y_1^{(1)}); \\dots; X_2^{(c)} + \\text{FeedForward}(Y_1^{(c)})] $$ The resulting reversible Transformer does not need to store activation in every layer.\nRouting Transformer (Roy et al. 2021) is also built on content-based clustering of keys and queries. Instead of using a static hashing function like LSH, it utilizes online $k$-means clustering and combines it with local, temporal sparse attention to reduce the attention complexity from $O(L^2)$ to $O(L^{1.5})$.\nWithin routing attention, both keys and queries are clustered with $k$-means clustering method and the same set of centroids $\\boldsymbol{\\mu} = (\\mu_1, \\dots, \\mu_k) \\in \\mathbb{R}^{k \\times d}$. Queries are routed to keys that get assigned to the same centroid. The total complexity is $O(Lkd + L^2d/k)$, where $O(Lkd)$ is for running clustering assignments and $O(L^2d/k)$ is for attention computation. The cluster centroids are updated by EMA (exponential moving average) using all associated keys and queries.\nIn the experiments for Routing Transformer, some best config only has routing attention enabled in the last two layers of the model and half of the attention heads, while the other half utilizing local attention. They also observed that local attention is a pretty strong baseline and larger attention window always leads to better results.\nLow-Rank Attention Linformer (Wang et al. 2020) approximates the full attention matrix with a low rank matrix, reducing the time \u0026amp; space complexity to be linear. Instead of using expensive SVD to identify low rank decomposition, Linformer adds two linear projections $\\mathbf{E}_i, \\mathbf{F}_i \\in \\mathbb{R}^{L \\times k}$ for key and value matrices, respectively, reducing their dimensions from $L \\times d$ to $k \\times d$. As long as $k \\ll L$, the attention memory can be greatly reduced.\n $$ \\begin{aligned} \\overline{\\text{head}}_i \u0026= \\text{attn}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{E}_i\\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{F}_i\\mathbf{X}_v\\mathbf{W}^v_i) \\\\ \u0026= \\underbrace{\\text{softmax}\\Big( \\frac{\\mathbf{X}_q\\mathbf{W}^q_i (\\mathbf{E}_i \\mathbf{X}_k\\mathbf{W}^k_i)^\\top}{\\sqrt{d}} \\Big)}_{\\text{low rank attention matrix }\\bar{A} \\in \\mathbb{R}^{k \\times d}} \\mathbf{F}_i \\mathbf{X}_v\\mathbf{W}^v_i \\end{aligned} $$ Additional techniques can be applied to further improve efficiency of Linformer:\n Parameter sharing between projection layers, such as head-wise, key-value and layer-wise (across all layers) sharing. Use different $k$ at different layers, as heads in higher layers tend to have a more skewed distribution (lower rank) and thus we can use smaller $k$ at higher layers. Use different types of projections; e.g. mean/max pooling, convolution layer with kernel and stride $L/k$. Fig. 22. (Left) Informer has two projection layers added for keys and values. (Right) Plot of inference time as a function of sequence length. (Image source: Wang et al. 2020). Random Feature Attention (RFA; Peng et al. 2021) relies on random feature methods (Rahimi \u0026amp; Recht, 2007) to approximate softmax operation in self-attention with low rank feature maps in order to achieve linear time and space complexity. Performers (Choromanski et al. 2021) also adopts random feature attention with improvements on the kernel construction to further reduce the kernel approximation error.\nThe main theorem behind RFA is from Rahimi \u0026amp; Recht, 2007:\n Let $\\phi: \\mathbb{R}^d \\to \\mathbb{R}^{2D}$ be a nonlinear transformation:\n $$ \\phi(\\mathbf{x}) = \\frac{1}{\\sqrt{D}}[\\sin(\\mathbf{w}_1^\\top \\mathbf{x}), \\dots, \\sin(\\mathbf{w}_D^\\top \\mathbf{x}), \\cos(\\mathbf{w}_1^\\top \\mathbf{x}), \\dots, \\cos(\\mathbf{w}_D^\\top \\mathbf{x})]^\\top $$ When $d$-dimensional random vectors $\\mathbf{w}_i$ are i.i.d. from $\\mathcal{N}(\\mathbf{0}, \\sigma^2\\mathbf{I}_d)$, $$ \\mathbb{E}_{\\mathbf{w}_i} [\\phi(\\mathbf{x}) \\cdot \\phi(\\mathbf{y})] = \\exp(-\\frac{\\| \\mathbf{x} - \\mathbf{y} \\|^2}{2\\sigma^2}) $$ An unbiased estimation of $\\exp(\\mathbf{x} \\cdot \\mathbf{y})$ is:\n $$ \\begin{aligned} \\exp(\\mathbf{x} \\cdot \\mathbf{y} / \\sigma^2) \u0026= \\exp(\\frac{1}{2\\sigma^2}(\\|\\mathbf{x}\\|^2 + \\|\\mathbf{y}\\|^2 - \\|\\mathbf{x} - \\mathbf{y}\\|^2) \\\\ \u0026= \\exp(\\frac{\\|\\mathbf{x}\\|^2}{2\\sigma^2}) \\exp(\\frac{\\|\\mathbf{y}\\|^2}{2\\sigma^2}) ( - \\frac{\\|\\mathbf{x} - \\mathbf{y}\\|^2}{2\\sigma^2}) \\\\ \u0026\\approx \\exp(\\frac{\\|\\mathbf{x}\\|^2}{2\\sigma^2}) \\exp(\\frac{\\|\\mathbf{y}\\|^2}{2\\sigma^2})\\;\\phi(\\mathbf{x})\\cdot\\phi(\\mathbf{y}) \\\\ \u0026= \\exp(\\frac{1}{\\sigma^2})\\;\\phi(\\mathbf{x})\\cdot\\phi(\\mathbf{y}) \u0026 \\text{; unit vectors} \\end{aligned} $$ Then we can write the attention function as follows, where $\\otimes$ is outer product operation and $\\sigma^2$ is the temperature:\n $$ \\begin{aligned} \\text{attn}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}, \\{\\mathbf{v}_i\\}) \u0026= \\sum_i \\frac{\\exp(\\mathbf{q}_t\\cdot\\mathbf{k}_i/\\sigma^2)}{\\sum_j \\exp(\\mathbf{q}_t\\cdot\\mathbf{k}_j/\\sigma^2)}\\mathbf{v}_i^\\top \\approx \\sum_i \\frac{\\phi(\\mathbf{q}_t)\\phi(\\mathbf{k}_i)\\mathbf{v}_i^\\top}{\\sum_j \\phi(\\mathbf{q}_t)\\phi(\\mathbf{k}_j)} \\\\ \u0026= \\color{green}{\\frac{\\phi(\\mathbf{q}_t)^\\top \\sum_i \\phi(\\mathbf{k}_i)\\otimes\\mathbf{v}_i}{\\phi(\\mathbf{q}_t)^\\top \\sum_j \\phi(\\mathbf{k}_j)} = \\text{RFA}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}, \\{\\mathbf{v}_i\\})} \\end{aligned} $$ Fig. 23. (Left) The order of computation for default softmax operation. (Right) The order of computation when using random feature attention, a lot cheaper than default softmax. (Image source: Peng et al. 2021). Causal Attention RFA has token at time step $t$ only attend to earlier keys and values $\\{\\mathbf{k}_i\\}_{i \\leq t}, \\{\\mathbf{v}_i\\}_{i \\leq t}$. Let us use a tuple of variables, $(\\mathbf{S}_t \\in \\mathbb{R}^{2D \\times d}, \\mathbf{z} \\in \\mathbb{R}^{2D})$, to track the hidden state history at time step $t$, similar to RNNs:\n $$ \\begin{aligned} \u0026\\text{causal-RFA}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}_{i \\leq t}, \\{\\mathbf{v}_i\\}_{i \\leq t}) = \\frac{\\phi(\\mathbf{q}_t)^\\top \\mathbf{S}_t}{\\phi(\\mathbf{q}_t) \\cdot \\mathbf{z}_t} \\\\ \u0026\\text{where } \\mathbf{S}_t = \\mathbf{S}_{t-1} + \\phi(\\mathbf{k}_t)\\otimes\\mathbf{v}_t, \\quad \\mathbf{z}_t = \\mathbf{z}_{t-1} + \\phi(\\mathbf{k}_t) \\end{aligned} $$ where $2D$ is the size of $\\phi(.)$ and $D$ should be no less than the model size $d$ for reasonable approximation.\nRFA leads to significant speedup in autoregressive decoding and the memory complexity mainly depends on the choice of $D$ when constructing the kernel $\\phi(.)$.\nPerformer modifies the random feature attention with positive random feature maps to reduce the estimation error. It also keeps the randomly sampled $\\mathbf{w}_1, \\dots, \\mathbf{w}_D$ to be orthogonal to further reduce the variance of the estimator.\nFig. 24. Comparison of approximation error when using (Left) i.i.d vs orthogonal features and (Right) sin/cos vs positive random features. (Image source: Choromanski et al. 2021). Transformers for Reinforcement Learning The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can for sure benefit from these traits. However, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. It could be quite challenging to stabilize and train a LSTM agent by itself, after all.\nThe Gated Transformer-XL (GTrXL; Parisotto, et al. 2019) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:\n The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism. $$ \\begin{aligned} r \u0026= \\sigma(W_r^{(l)} y + U_r^{(l)} x) \\\\ z \u0026= \\sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\\\ \\hat{h} \u0026= \\tanh(W_g^{(l)} y + U_g^{(l)} (r \\odot x)) \\\\ g^{(l)}(x, y) \u0026= (1-z)\\odot x + z\\odot \\hat{h} \\end{aligned} $$ The gating function parameters are explicitly initialized to be close to an identity map - this is why there is a $b_g$ term. A $b_g \u0026gt; 0$ greatly helps with the learning speedup.\nFig. 25. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in Parisotto, et al. 2019) Decision Transformer (DT; Chen et al 2021) formulates Reinforcement Learning problems as a process of conditional sequence modeling, outputting the optimal actions conditioned on the desired return, past states and actions. It therefore becomes straightforward to use Transformer architecture. Decision Transformer is for off-policy RL, where the model only has access to a fixed collection of trajectories collected by other policies.\nTo encourage the model to learn how to act in order to achieve a desired return, it feeds the model with desired future return $\\hat{R} = \\sum_{t'=t}^T r_{t'}$ instead of the current reward. The trajectory consists of a list of triplets, (return-to-go $\\hat{R}_t, state $s_t$, action $a_t$), and it is used as an input sequence for Transformer:\n $$ \\tau = (\\hat{R}_1, s_1, a_1, \\hat{R}_2, s_2, a_2, \\dots, \\hat{R}_T, s_T, a_T) $$ Three linear layers are added and trained for return-to-go, state and action respectively to extract token embeddings. The prediction head learns to predict $a_t$ corresponding to the input token $s_t$. The training uses cross-entropy loss for discrete actions or MSE for continuous actions. Predicting the states or return-to-go was not found to help improve the performance in their experiments.\nThe experiments compared DT with several model-free RL algorithm baselines and showed that:\n DT is more efficient than behavior cloning in low data regime; DT can model the distribution of returns very well; Having a long context is crucial for obtaining good results; DT can work with sparse rewards. Citation Cited as:\n Weng, Lilian. (Jan 2023). The transformer family version 2.0. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/.\n Or\n@article{weng2023transformer, title = \u0026quot;The Transformer Family Version 2.0\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/\u0026quot; } References [1] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[2] Rami Al-Rfou, et al. \u0026ldquo;Character-level language modeling with deeper self-attention.\u0026quot; AAAI 2019.\n[3] Olah \u0026amp; Carter, \u0026ldquo;Attention and Augmented Recurrent Neural Networks\u0026rdquo;, Distill, 2016.\n[4] Sainbayar Sukhbaatar, et al. \u0026ldquo;Adaptive Attention Span in Transformers\u0026rdquo;. ACL 2019.\n[5] Rewon Child, et al. \u0026ldquo;Generating Long Sequences with Sparse Transformers\u0026rdquo; arXiv:1904.10509 (2019).\n[6] Nikita Kitaev, et al. \u0026ldquo;Reformer: The Efficient Transformer\u0026rdquo; ICLR 2020.\n[7] Alex Graves. (\u0026ldquo;Adaptive Computation Time for Recurrent Neural Networks\u0026rdquo;)[https://arxiv.org/abs/1603.08983]\n[8] Niki Parmar, et al. \u0026ldquo;Image Transformer\u0026rdquo; ICML 2018.\n[9] Zihang Dai, et al. \u0026ldquo;Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.\u0026quot; ACL 2019.\n[10] Aidan N. Gomez, et al. \u0026ldquo;The Reversible Residual Network: Backpropagation Without Storing Activations\u0026rdquo; NIPS 2017.\n[11] Mostafa Dehghani, et al. \u0026ldquo;Universal Transformers\u0026rdquo; ICLR 2019.\n[12] Emilio Parisotto, et al. \u0026ldquo;Stabilizing Transformers for Reinforcement Learning\u0026rdquo; arXiv:1910.06764 (2019).\n[13] Rae et al. “Compressive Transformers for Long-Range Sequence Modelling.” 2019.\n[14] Press et al. “Train Short, Test Long: Attention With Linear Biases Enables Input Length Extrapolation.” ICLR 2022.\n[15] Wu, et al. “DA-Transformer: Distance Aware Transformer” 2021.\n[16] Elabyad et al. “Depth-Adaptive Transformer.” ICLR 2020.\n[17] Schuster et al. “Confident Adaptive Language Modeling” 2022.\n[18] Qiu et al. “Blockwise self-attention for long document understanding” 2019\n[19] Roy et al. “Efficient Content-Based Sparse Attention with Routing Transformers.” 2021.\n[20] Ainslie et al. “ETC: Encoding Long and Structured Inputs in Transformers.” EMNLP 2019.\n[21] Beltagy et al. “Longformer: The long-document transformer.” 2020.\n[22] Zaheer et al. “Big Bird: Transformers for Longer Sequences.” 2020.\n[23] Wang et al. “Linformer: Self-Attention with Linear Complexity.” arXiv preprint arXiv:2006.04768 (2020).\n[24] Tay et al. 2020 “Sparse Sinkhorn Attention.” ICML 2020.\n[25] Peng et al. “Random Feature Attention.” ICLR 2021.\n[26] Choromanski et al. “Rethinking Attention with Performers.” ICLR 2021.\n[27] Khandelwal et al. “Generalization through memorization: Nearest neighbor language models.” ICLR 2020.\n[28] Yogatama et al. “Adaptive semiparametric language models.” ACL 2021.\n[29] Wu et al. “Memorizing Transformers.” ICLR 2022.\n[30] Su et al. “Roformer: Enhanced transformer with rotary position embedding.” arXiv preprint arXiv:2104.09864 (2021).\n[31] Shaw et al. “Self-attention with relative position representations.” arXiv preprint arXiv:1803.02155 (2018).\n[32] Tay et al. \u0026ldquo;Efficient Transformers: A Survey.\u0026quot; ACM Computing Surveys 55.6 (2022): 1-28.\n[33] Chen et al., \u0026ldquo;Decision Transformer: Reinforcement Learning via Sequence Modeling\u0026rdquo; arXiv preprint arXiv:2106.01345 (2021).\n","permalink":"https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/","summary":"Many new Transformer architecture improvements have been proposed since my last post on \u0026ldquo;The Transformer Family\u0026rdquo; about three years ago. Here I did a big refactoring and enrichment of that 2020 post \u0026mdash; restructure the hierarchy of sections and improve many sections with more recent papers. Version 2.0 is a superset of the old version, about twice the length.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size.","title":"The Transformer Family Version 2.0"},{"content":"[Updated on 2023-01-24: add a small section on Distillation.]\nLarge transformer models are mainstream nowadays, creating SoTA results for a variety of tasks. They are powerful but very expensive to train and use. The extremely high inference cost, in both time and memory, is a big bottleneck for adopting a powerful transformer for solving real-world tasks at scale.\nWhy is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al. 2022):\n Large memory footprint. Both model parameters and intermediate states are needed in memory at inference time. For example, The KV cache should be stored in memory during decoding time; E.g. For a batch size of 512 and context length of 2048, the KV cache totals 3TB, that is 3x the model size (!). Inference cost from the attention mechanism scales quadratically with input sequence length. Low parallelizability. Inference generation is executed in an autoregressive fashion, making the decoding process hard to parallel. In this post, we will look into several approaches for making transformer inference more efficient. Some are general network compression methods, while others are specific to transformer architecture.\nMethods Overview We in general consider the following as goals for model inference optimization:\n Reduce the memory footprint of the model by using fewer GPU devices and less GPU memory; Reduce the desired computation complexity by lowering the number of FLOPs needed; Reduce the inference latency and make things run faster. Several methods can be used to make inference cheaper in memory or/and faster in time.\n Apply various parallelism to scale up the model across a large number of GPUs. Smart parallelism of model components and data makes it possible to run a model of trillions of parameters. Memory offloading to offload temporarily unused data to the CPU and read them back when needed later. This helps with memory usage but causes higher latency. Smart batching strategy; E.g. EffectiveTransformer packs consecutive sequences together to remove padding within one batch. Network compression techniques, such as pruning, quantization, distillation. A model of smaller size, in terms of parameter count or bitwidth, should demand less memory and run faster. Improvement specific to a target model architecture. Many architectural changes, especially those for attention layers, help with transformer decoding speed. Check the previous post on large model training on different types of training parallelism and memory saving designs including CPU memory offloading. This post focuses on network compression techniques and architecture-specific improvement for transformer models.\nDistillation Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (\u0026ldquo;student model\u0026rdquo;) to speed up inference by transferring skills from a pre-trained expensive model (\u0026ldquo;teacher model\u0026rdquo;) into the student. There is no much restriction on how the student architecture should be constructed, except for a matched output space with the teacher in order to construct a proper learning objective.\nFig. 1. The generic framework of teacher-student knowledge distillation training. (Image source: Gou et al. 2020) Given a dataset, a student model is trained to mimic outputs of a teacher via distillation loss. Usually a neural network has a softmax layer; For example, a LLM outputs a probability distribution over tokens. Let\u0026rsquo;s denote the logits layer right before softmax as $\\mathbf{z}_t$ and $\\mathbf{z}_s$ for teacher and student models, respectively. The distillation loss minimizes the difference between two softmax outputs with a high temperature $T$. When ground truth labels $\\mathbf{y}$ are known, we can combine it with a supervised learning objective between ground truth and the student\u0026rsquo;s soft logits using e.g. cross-entropy.\n $$ \\mathcal{L}_\\text{KD} = \\mathcal{L}_\\text{distll}(\\text{softmax}(\\mathbf{z}_t, T), \\text{softmax}(\\mathbf{z}_s, T)) + \\lambda\\mathcal{L}_\\text{CE}(\\mathbf{y}, \\mathbf{z}_s) $$ where $\\lambda$ is a hyperparameter to balance between soft and hard learning objectives. A common choice for $\\mathcal{L}_\\text{distll}$ is KL divergence / cross entropy.\nA successful early trial is DistilBERT (Sanh et al. 2019) that is able to reduce the parameters of a BERT by 40% while maintaining 97% performance of BERT on fine-tuned downstream tasks and running 71% faster. The loss of pre-training DistilBERT is a combination of soft distillation loss, supervised training loss (i.e. Masked language modeling loss $\\mathcal{L}_\\text{MLM}$ in the case of BERT) and a special cosine embedding loss to align the hidden state vectors between teacher and student.\nDistillation can be easily combined with quantization, pruning or sparsification techniques, where the teacher model is the original full-precision, dense model and the student is quantized, pruned, or trimmed to have higher sparsity level.\nQuantization There are two common approaches for applying quantization on a deep neural network:\n Post-Training Quantization (PTQ): A model is first trained to convergence and then we convert its weights to lower precision without more training. It is usually quite cheap to implement, in comparison to training. Quantization-Aware Training (QAT): Quantization is applied during pre-training or further fine-tuning. QAT is able to attain better performance but requires extra computation resources and access to representative training data. We should be aware of the gap between theoretical optimal quantization strategy and the hardware kernel support. Due to the lack of GPU kernel support for certain types of matrix multiplication (e.g. INT4 x FP16), not all the methods below result in speedup for the actual inference.\nChallenges for Transformer Quantization Many studies on Transformer model quantization have the same observation: A simple low-precision (e.g. 8-bit) post-training quantization leads to significant performance drop mainly due to the high dynamic ranges of activation and a naive activation quantization strategy fails to maintain the capacity.\nFig. 2. Only quantizing model weights to 8-bit while keeping activation at full precision (`W8A32`) achieves much better results when activations are quantized to 8-bit irrespective of whether weights are in lower precision (`W8A8` and `W32A8`). (Image source: Bondarenko et al. 2021) Bondarenko et al. (2021) observed in a small BERT model that FFN’s input and output have very different dynamic ranges due to strong outliers in the output tensor. Therefore per-tensor quantization for the FFN’s residual sum is likely to cause a notable error.\nAs the model size continues to grow to billions of parameters, outlier features of high magnitude start to emerge in all transformer layers, causing failure of simple low-bit quantization. Dettmers et al. (2022) observed such a phenomenon for OPT models larger than 6.7B parameters. Larger models have more layers with extreme outliers and these outlier features have a significant impact on the model performance. The scale of activation outliers in a few dimensions can be ~100× larger than most of the other values.\nFig. 3. The mean zero-shot accuracy over a set of language tasks (WinoGrande, HellaSwag, PIQA, LAMBADA) of OPT models of increasing sizes. (Image source: Dettmers et al. 2022) Post-training quantization (PTQ) Mixed-precision quantization The most straightforward approach for resolving the above quantization challenge is to implement quantization at different precision for weights vs activation.\nGOBO (Zadeh et al. 2020) is one of the first models to apply post-training quantization on transformers (i.e. a small BERT model). It assumes that model weights of each layer follow a Gaussian distribution and therefore detects outliers by tracking mean and standard deviation per layer. Outlier features remain in original form, while other values are split into multiple bins and only corresponding bin indices of weights and the centroid values are stored.\nBased on the observation that only certain activation layers (e.g. residual connections after FFN) in BERT cause big performance drop, Bondarenko et al. (2021) adopted mixed-precision quantization by using 16-bit quantization on problematic activations but 8-bit on others.\nMixed-precision quantization in LLM.int8() (Dettmers et al. 2022) is implemented via two mixed-precision decompositions:\n Because matrix multiplication contains a set of independent inner products between row and column vectors, we can impose independent quantization per inner product: Each row and column are scaled by the absolution maximum values and then quantized to INT8. Outlier activation features (e.g. 20x larger than other dimensions) remain in FP16 but they represent only a tiny fraction of total weights. How to identify outliers is empirical. Fig. 4. Two mixed-precision decompositions of `LLM.int8()`. (Image source: Dettmers et al. 2022) Quantization at fine-grained granularity Fig. 5. Comparison of quantization at different granularity. $d$ is the model size / hidden state dimension and $h$ is the number of heads in one MHSA (multi-head self-attention) component. Naively quantizing the entire weight matrix in one layer (\u0026ldquo;per-tensor\u0026rdquo; or \u0026ldquo;per-layer\u0026rdquo; quantization) is easiest to implement but does not lead to good granularity of quantization.\nQ-BERT (Shen, Dong \u0026amp; Ye, et al. 2020) applied group-wise quantization to a fine-tuned BERT model, treating an individual matrix $W$ with respect to each head in MHSA (multi-head self-attention) as one group and then applies Hessian based mixed precision quantization.\nPer-embedding group (PEG) activation quantization was motivated by the observation that outlier values only appear in a few out of $d$ (hidden state / model size) dimensions (Bondarenko et al. 2021). Per-embedding is pretty computationally expensive. In comparison, PEG quantization splits the activation tensor into several evenly sized groups along the embedding dimension where elements in the same group share quantization parameters. To ensure all outliers are grouped together, they apply a deterministic range-based permutation of embedding dimensions, where dimensions are sorted by their value ranges.\nZeroQuant (Yao et al. 2022) uses group-wise quantization for weights, same as in Q-BERT, and token-wise quantization for activation. To avoid expensive quantization and de-quantization computation, ZeroQuant built customized kernel to fuse quantization operation with its previous operator.\nSecond order information for quantization Q-BERT (Shen, Dong \u0026amp; Ye, et al. 2020) developed Hessian AWare Quantization (HAWQ) for its mixed-precision quantization. The motivation is that parameters with higher Hessian spectrum (i.e., larger top eigenvalues) are more sensitive to quantization and thus require higher precision. It is essentially a way to identify outliers.\nIn another viewpoint, the problem of quantization is an optimization problem. Given a weight matrix $\\mathbf{W}$ and an input matrix $\\mathbf{X}$ , we want to find a quantized weight matrix $\\hat{\\mathbf{W}}$ to minimize the MSE:\n$$ \\hat{\\mathbf{W}}^* = {\\arg\\min}_{\\hat{\\mathbf{W}}} | \\mathbf{W}\\mathbf{X} - \\hat{\\mathbf{W}}\\mathbf{X}| $$\nGPTQ (Frantar et al. 2022) treats the weight matrix $\\mathbf{W}$ as a collection of row vectors ${\\mathbf{w}}$ and applies quantization to each row independently. GPTQ iteratively quantizes more weights that are selected greedily to minimize the quantization error. The update on selected weights has a closed-form formula, utilizing Hessian matrices. Read more details in the paper and the OBQ (Optimal Brain Quantization; Frantar \u0026amp; Alistarh 2022) method if interested. GPTQ can reduce the bitwidth of weights in OPT-175B down to 3 or 4 bits without much performance loss, but it only applies to model weights not activation.\nOutlier smoothing It is known that activations are harder to quantize than weights in transformer models. SmoothQuant (Xiao \u0026amp; Lin 2022) proposed a smart solution to smooth outlier features from activations to weights via mathematically equivalent transformation and then enable quantization on both weights and activations (W8A8). Because of this, SmoothQuant has better hardware efficiency than mixed-precision quantization.\nFig. 6. SmoothQuant migrates the scale variance from activations to weights offline to reduce the difficulty of activation quantization. Both the resulting new weight and activation matrices are easy to quantize. (Image source: Xiao \u0026 Lin 2022) Considering a per-channel smooth factor $\\mathbf{s}$, SmoothQuant scales the weights according to:\n$$ \\mathbf{Y} = (\\mathbf{X} \\text{diag}(\\mathbf{s})^{-1}) \\cdot (\\text{diag}(\\mathbf{s})\\mathbf{W}) = \\hat{\\mathbf{X}}\\hat{\\mathbf{W}} $$\nThe smoothing factor can be easily fused into previous layers' parameters offline. A hyperparameter $\\alpha$ controls how much we migrate the quantization difficulty from activations to weights: $\\mathbf{s} = \\max (\\vert \\mathbf{X}_j \\vert)^\\alpha / \\max( \\vert \\mathbf{W}_j \\vert )^{1-\\alpha}$. The paper found that $\\alpha=0.5$ is a sweet spot for many LLMs in the experiments. For models with more significant outliers in activation, $\\alpha$ can be adjusted to be larger.\nQuantization-aware training (QAT) Quantization-aware training fuses the quantization operation into the pre-training or fine-tuning process. It learns model weights in low-bit representation directly and leads to better performance at the cost of additional training time and computation.\nThe most straightforward approach is to fine-tune the model after quantization on a training dataset that is the same as or representative of the pre-training dataset. The training objective can be the same as the one for pre-training (e.g. NLL/MLM in general language model training) or specific to a downstream task that we care about (e.g. Cross entropy for classification).\nAnother approach is to consider the full-precision model as the teacher and the lower-precision model as the student, and then optimize the low-precision model with distillation loss. Distillation usually doesn\u0026rsquo;t need to use the original dataset; E.g. Wikipedia dataset is a good choice and even random tokens can give decent performance gain. The Layer-by-layer Knowledge Distillation (LKD; Yao et al. 2022) method quantizes the network layer by layer and uses its original, unquantized version as the teacher model. Given the same inputs, LKD minimizes the MSE between the multiplication with layer weights and the multiplication of quantized layer weights.\nPruning Network pruning is to reduce the model size by trimming unimportant model weights or connections while the model capacity remains. It may or may not require re-training. Pruning can be unstructured or structured.\n Unstructured pruning is allowed to drop any weight or connection, so it does not retain the original network architecture. Unstructured pruning often does not work well with modern hardware and doesn\u0026rsquo;t lead to actual inference speedup. Structured pruning aims to maintain the dense matrix multiplication form where some elements are zeros. They may need to follow certain pattern restrictions to work with what hardware kernel supports. Here we focus on structured pruning to achieve high sparsity in transformer models. A routine workflow to construct a pruned network has three steps:\n Train a dense network until convergence; Prune the network to remove unwanted structure; Optionally retrain the network to recover the performance with new weights. The idea of discovering a sparse structure within a dense model via network pruning while the sparse network can still maintain similar performance is motivated by Lottery Ticket Hypothesis (LTH): A randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset (a sparse network) are \u0026ldquo;winning tickets\u0026rdquo; which can achieve the optimal performance when trained in isolation.\nHow to prune? Magnitude pruning is simplest yet quite effective pruning method - weights with smallest absolute values are trimmed. In fact, some studies (Gale et al. 2019) found that simple magnitude pruning approaches can achieve comparable or better results than complicated pruning methods, such as variational dropout (Molchanov et al. 2017) and $l_0$ regularization (Louizos et al. 2017). Magnitude pruning is simple to apply to large models and achieves reasonably consistent performance across a wide range of hyperparameters.\nZhu \u0026amp; Gupta (2017) found that large sparse models were able to achieve better performance than their small but dense counterparts. They proposed Gradual Magnitude Pruning (GMP) algorithm that increases the sparsity of a network gradually over the course of training. At each training step, weights with smallest absolute values are masked to be zeros to achieve a desired sparsity level $s$ and masked weights do not get gradient update during back-propagation. The desired sparsity level $s$ goes up with more training steps. The process of GMP is sensitive to the learning rate schedule, which should be higher than what\u0026rsquo;s used in dense network training, but not too high to prevent convergence.\nIterative pruning (Renda et al. 2020) iterates step 2 (prune) \u0026amp; step 3 (retrain) multiple times: Only a small fraction of weights are pruned and the model is retrained in each iteration. The process repeats until a desired sparsity level is reached.\nHow to retrain? The retraining step can be simple fine-tuning using the same pre-training data or other task-specific datasets.\nLottery Ticket Hypothesis proposed a weight rewinding retraining technique: After pruning, the unpruned weights are reinitialized back to original values earlier in the training and then retrain with the same learning rate schedule.\nLearning rate rewinding (Renda et al. 2020) only resets the learning rate back to its early value, while the unpruned weights stay unchanged since the end of the last train stage. They observed that (1) retraining with weight rewinding outperforms retraining with fine-tuning across networks and datasets and (2) learning rate rewinding matches or outperforms weight rewinding in all tested scenarios.\nSparsity Sparsity is an effective way to scale up model capacity while keeping model inference computationally efficient. Here we consider two types of sparsity for transformers:\n Sparsified dense layers, including both self-attention and FFN layers. Sparse model architecture; i.e. via incorporating the Mixture-of-Experts (MoE) component. N:M Sparsity via Pruning N:M sparsity is a structured sparsity pattern that works well with modern GPU hardware optimization, in which $N$ out of every $M$ consecutive elements are zeros. For example, the sparse tensor core of Nvidia A100 GPU has support for 2:4 sparsity for faster inference (Nvidia 2020).\nFig. 7. A matrix of 2:4 structured sparsity and its compressed representation. (Image source: Nvidia blog) To sparsify a dense neural network to follow a N:M structured sparsity pattern, Nvidia (2020) suggested using the three-step routine workflow for training a pruned network: train \u0026ndash;\u0026gt; prune to satisfy 2:4 sparsity \u0026ndash;\u0026gt; retrain.\nPermuting columns can provide more options in the pruning process to maintain parameters of large magnitude or to satisfy a special restriction like N:M sparsity (Pool \u0026amp; Yu 2021). As long as paired axes of two matrices are permuted in the same order, the results of matrix multiplication would not change. For example,\n(1) Within the self-attention module, if the same permutation order is applied on the axis 1 of query embedding matrix $\\mathbf{Q}$ and the axis 0 of key embedding matrix $\\mathbf{K}^\\top$, the final result of matrix multiplication of $\\mathbf{Q}\\mathbf{K}^\\top$ would stay the same.\nFig. 8. Illustration of same permutation on $\\mathbf{Q}$ (axis 1) and $\\mathbf{K}^\\top$ (axis 0) to keep the results of a self-attention module unchanged. (2) Within the FFN layer that contains two MLP layers and one ReLU non-linear layer, we can permute the first linear weight matrix $\\mathbf{W}_1$ along the axis 1 and the second linear weight matrix $\\mathbf{W}_2$ along the axis 0 in the same order.\nFig. 9. Illustration of the same permutation on $\\mathbf{W}_1$ (axis 1) and $\\mathbf{W}_2$ (axis 0) to keep the FFN layer's output unchanged. For simplicity, the bias terms are skipped but the same permutation should be applied on them too. To enforce N:M structured sparsity, let\u0026rsquo;s split the columns of one matrix into multiple slides of $M$ columns (named \u0026ldquo;stripe\u0026rdquo;) and we can easily observe that both the order of columns within each stripe and the order of stripes have no effect on the N:M sparsity restriction.\nPool \u0026amp; Yu (2021) proposed an iterative greedy algorithm to find optimal permutation that maximizes the weight magnitude for N:M sparsity. All pairs of channels are speculatively swapped and only the swap that leads to the greatest increase in magnitude is adopted, generating a new permutation and concluding a single iteration. Greedy algorithm may only find local minima, so they introduced two techniques to escape local minima:\n Bounded regressions: In practice two random channels are swapped, up to a fixed number of times. The solution search is limited to a depth of only one channel swap to keep the search space broad and shallow. Narrow, deep search: Choose multiple stripes and optimize them at the same time. Fig. 10. Algorithm of finding the best permutation for N:M sparsity greedily and iteratively. (Image source: Pool \u0026 Yu 2021) The network can achieve better performance if it was permuted before pruning, compared to pruning the network in its default channel order.\nTo train a model with N:M sparsity from scratch, Zhou \u0026amp; Ma, et al. (2021) extended STE (Straight-Through Estimator; Bengio et al. 2013), which is commonly used for back-propagation update in model quantization, to work for magnitude pruning and sparse parameter update.\nSTE computes the gradients of dense parameters wrt the pruned network $\\widetilde{W}$, $\\partial \\mathcal{L}/\\partial \\widetilde{W}$, and applies that to the dense network $W$ as an approximation:\n$$ W_{t+1} \\gets W_t - \\gamma \\frac{\\partial\\mathcal{L}}{\\partial\\widetilde{W}} $$\nThe extended version, SR-STE (Sparse-refined STE), updates the dense weights $W$ by:\n$$ W_{t+1} \\gets W_t - \\gamma \\frac{\\partial\\mathcal{L}}{\\partial\\widetilde{W}} + \\lambda_W (\\bar{\\mathcal{E}} \\odot W_t) $$ where $\\bar{\\mathcal{E}}$ is the mask matrix for $\\widetilde{W}$ and $\\odot$ is element-wise multiplication. SR-STE is proposed to prevent large change in the binary mask by (1) restricting the values of weights pruned in $\\widetilde{W}_t$, and (2) promoting the non-pruned weights in $\\widetilde{W}_t$.\nFig. 11. Comparison of STE and SR-STE. $\\odot$ is element-wise product; $\\otimes$ is matrix multiplication. (Image source: Zhou \u0026 Ma, et al. 2021) Different from STE or SR-STE, the Top-KAST (Jayakumar et al. 2021) method can preserve constant sparsity throughout training in both the forward and backward-passes but does not require forward passes with dense parameters or dense gradients.\nAt one training step $t$, Top-KAST processes as follows:\n Sparse forward pass: Select a subset of parameters $A^t \\subset \\Theta$, containing top-$K$ parameters by magnitude by each layer, restricted to top $D$-proportion of weights. The parameterization $\\alpha^t$ at time $t$ has parameters zeroed out if it is not in $A^t$ (active weights). $$ \\alpha^t_i = \\begin{cases} \\theta^t_i \u0026 \\text{ if } i \\in A^t = \\{i \\mid \\theta^t_i \\in \\text{TopK}(\\theta^t, D) \\}\\\\ 0 \u0026 \\text{ otherwise} \\end{cases} $$ where $\\text{TopK}(\\theta, x)$ selected top $x$ proportion of weights from $\\theta$ based on magnitude.\nSparse backward pass: Then apply gradients to a larger parameter subset $B \\subset \\Theta$ where $B$ contains $(D+M)$-proportion of weights and $A \\subset B$. Updating a larger proportion of weights enables more effective exploration of different pruning masks, making it more likely to cause permutations in the top $D$-proportion active weights. $$ \\Delta_{\\theta^t_i} = \\begin{cases} -\\eta \\nabla_{\\alpha_t} \\mathcal{L}(y, x, \\alpha^t)_i \u0026 \\text{ if } i\\in B^t = \\{i \\mid \\theta^t_i \\in \\text{TopK}(\\theta^t, D+M) \\} \\\\ 0 \u0026 \\text{ otherwise } \\end{cases} $$ Training is split into two stages and the additional coordinates in the set $B \\setminus A$ controls how much exploration is brought in. The amount of exploration is expected to diminish gradually through the training process and the mask eventually stabilizes.\nFig. 12. The pruning mask of Top-KAST stabilizes in time. (Image source: Jayakumar et al. 2021) To prevent rich-get-richer phenomenon, Top-KAST penalizes the magnitude of active weights via a L2 regularization loss to encourage more exploration of new items. Parameters in $B \\setminus A$ are penalized more than $A$ for a higher selection bar during updates to stabilize the mask.\n $$ L_\\text{penalty}(\\alpha^t_i) = \\begin{cases} \\vert \\theta^t_i\\vert \u0026 \\text{ if } i \\in A^t \\\\ \\vert \\theta^t_i\\vert / D \u0026 \\text{ if } i \\in B^t \\setminus A^t \\\\ 0 \u0026 \\text{ otherwise} \\end{cases} $$ Sparsified Transformer Scaling Transformer (Jaszczur et al. 2021) sparsifies both self-attention and FFN layers in transformer architecture, achieving 37x speedup for single-example inference.\nFig. 13. The speed of decoding a single token (unbatched inference) by a transformer model when sparsification is applied on different layers. (Image source: Jaszczur et al. 2021) Sparse FFN layer: Each FFN layer contains 2 MLP and one ReLU in-between. Because ReLU will introduce a lot of zeros, they implement a fixed structure on activations to enforce only 1 non-zero value in one block of $N$ elements. The sparsity pattern is dynamic, different for each token.\n $$ \\begin{aligned} Y_\\text{sparse} \u0026= \\max(0, xW_1 + b_1) \\odot \\text{Controller}(x) \\\\ \\text{SparseFFN}(x) \u0026= Y_\\text{sparse} W_2 + b_2 \\\\ \\text{Controller}(x) \u0026= \\arg\\max(\\text{Reshape}(x C_1 C_2, (-1, N))) \\end{aligned} $$ where each activation in $Y_\\text{sparse}$ corresponds to one column in $W_1$ and one row in $W_2$. The controller is implemented as a low-rank bottleneck dense layer, $C_1 \\in \\mathbb{R}^{d_\\text{model} \\times d_\\text{lowrank}}, C_2 \\in \\mathbb{R}^{d_\\text{lowrank} \\times d_\\text{ff}}$ and $d_\\text{lowrank} = d_\\text{model} / N$. It uses $\\arg\\max$ for inference to select which columns should be non-zero and Gumbel-softmax trick (Jang et al. 2016) during training. Because we can compute $\\text{Controller}(x)$ before loading FFN weight matrices, we know which columns will be zeroed out and thus choose not to load them into memory for inference speedup.\nFig. 14. (a) Sparse FFN layer; columns in red are not loaded in memory for faster inference. (b) Sparse FFN controller for 1:4 sparsity. (Image source: Jaszczur et al. 2021) *Lilian's side note*: Fig (a) in the illustration from the paper is actually $Y_\\text{sparse} = \\max\\big(0, (xW_1 + b_1) \\odot \\text{Controller}(x)\\big)$, but it doesn't change the results. Sparse QKV (attention) layer: In the attention layer, the dimensionality $d_\\text{model}$ is divided into $S$ modules, each of size $M=d_\\text{model} /S$. To make sure each subdivision can access any part of the embedding, Scaling Transformer introduces a multiplicative layer (i.e., a multiplication layer multiplies inputs from multiple neural network layers element-wise) which can represent arbitrary permutation but contains fewer parameters than a dense layer.\nGiven an input vector $x \\in \\mathbb{R}^{d_\\text{model}}$, the multiplicative layer outputs $y \\in \\mathbb{R}^{S \\times M}$:\n $$ y_{s,m} = \\sum_i x_i D_{i,s} E_{i,m} \\quad\\text{where }D \\in \\mathbb{R}^{d_\\text{model} \\times S}, D \\in \\mathbb{R}^{d_\\text{model} \\times M} $$ The output of the multiplicative layer is a tensor of size $\\in \\mathbb{R}^{\\text{batch size}\\times \\text{length} \\times S \\times M}$. It then gets processed by a two-dimensional convolutional layer, where $\\text{length}$ and $S$ are treated as the height and width of an image. Such a convolution layer further reduces the parameter count and computation time of attention layer.\nFig. 15. (a) A multiplicative layer is introduced to enable partitions to access any part of an embedding. (b) Combination of multiplicative dense layer and 2-D convolutional layer reduces the number of parameters and computation time of the attention layer. (Image source: Jaszczur et al. 2021) To better work with long sequences, Scaling Transformer is further equipped with LSH (locality-sensitive hashing) attention from Reformer (Kitaev, et al. 2020) and FFN block recurrence, resulting in Terraformer.\nMixture-of-Experts Mixture-of-experts (MoE) models depend on a collection of \u0026ldquo;expert\u0026rdquo; networks and each example only activates a subset of networks to get predictions. The idea originated back to the 1990s (Jacobs et al. 1991) and is strongly related to ensemble methods. For details on how to incorporate MoE module into transformer, please check my previous post on large model training techniques and a survey paper on MoE by Fedus et al. 2022.\nWith MoE architecture, only partial parameters are utilized at decoding time and therefore it saves inference cost. The capacity of each expert can be adjusted with a hyperparameter, capacity factor $C$, and the expert capacity is defined as:\n $$ \\text{Expert capacity} = \\text{round}(C \\cdot k \\cdot \\frac{\\text{total # tokens in one batch}}{\\text{# experts}}) $$ where top-$k$ experts are selected per token. Larger $C$ leads to higher expert capacity and improved performance but more expensive computationally. When $C\u0026gt;1$, a slack capacity is added; otherwise, when $C\u0026lt;1$, the routing network needs to ignore some tokens.\nRouting Strategy Improvement MoE layer has a routing network to assign a subset of experts for each input token. The routing strategy in vanilla MoE models is to route each token toward preferred experts differently as they come up in the natural order. If a token is routed to experts that have reached their capacity, the token would be marked \u0026ldquo;overflowed\u0026rdquo; and skipped.\nV-MoE (Vision MoE; Riquelme et al. 2021) adds MoE layers into ViT (Vision Transformer). It matches the performance of previous SoTA but only requires half of inference compute. V-MoE can be scaled up to 15B parameters. Their experiments used $k=2$, 32 experts and every-2 expert placement (meaning that MoEs are placed in every other layer).\nSince each expert has a limited capacity, some important and informative tokens may have to be discarded if they come up too late in the predefined sequence order (e.g. the order of words in a sentence, or the order of image patches). To avoid such a drawback in the vanilla routing scheme, V-MoE adopts BPR (Batch Priority Routing) to assign experts to tokens with a high priority score first. BPR computes a priority score (max or sum of top-$k$ router scores) per token before expert assignment and alters the order of tokens accordingly. This guarantees that the expert capacity buffer would be fulfilled with key tokens first.\nFig. 16. How image patches are discarded according to priority scores when $C Riquelme et al. 2021) BPR works much better than vanilla routing when $C\\leq 0.5$, where the model starts dropping a significant amount of tokens. It capacitates the model to be competitive with the dense network even at quite low capacities.\nWhen looking into how to interpret image class-expert association, they observed that early MoE layers are more general, while later MoE layers could be specialized for a few image classes.\nTask MoE (Task-level Mixture-of-Experts; Kudugunta et al. 2021 ) takes the task information into consideration and routes tokens at the task level instead of the word or token level for machine translation. They used MNMT (multilingual neural machine translation) as an example and group translation tasks based on the target language or language pairs.\nToken level routing is dynamic and the routing decision for each token is made disjointly. Hence, at inference time, the server needs to preload all the experts. In comparison, task level routing is static given a fixed task, so the inference server for one task only needs to preload $k$ experts (assuming top-$k$ routing). According to their experiments, Task MoE can achieve similar performance gain as token MoE compared to dense model baseline with 2.6x higher peak throughput and 1.6% of the decoder size.\nTask level MoE is essentially to categorize a distribution of tasks according to predefined heuristics and incorporate such human knowledge into the router. When such heuristics do not exist (e.g. consider a general sentence continuation task), it would not be straightforward how to utilize Task MoE.\nPR-MoE (Pyramid residual MoE; Rajbhandari et al. 2022) has each token pass one fixed MLP and one chosen expert. Due to the observation that MoE at later layers is more beneficial, PR-MoE adopts more exports at later layers. DeepSpeed library implements a flexible multi-expert, multi-data parallelism to enable training PR-MoE with different numbers of experts across layers.\nFig. 17. Illustration of PR-MoE architecture in comparison with a standard MoE. (Image source: Rajbhandari et al. 2022) Kernel Improvement Expert networks can be hosted on different devices. However, when the number of GPUs increases, the number of experts per GPU decreases and the communication between experts (\u0026ldquo;All-to-all\u0026rdquo;) grows to be more expensive. All-to-all communication between experts across a number of GPUs relies on P2P APIs of NCCL, which cannot saturate the bandwidth of high-speed links (e.g. NVLink, HDR InfiniBand) at a large scale, as individual chunk gets smaller with more nodes used. The existing all-to-all algorithm performs poorly at large scale with a small workload. There are a variety of kernel improvements to enable more efficient MoE computation, such as making all-to-all communication cheaper/faster.\nBoth the DeepSpeed library (Rajbhandari et al. 2022) and TUTEL (Hwang et al. 2022) implemented a tree-based hierarchical all-to-all algorithm, which runs an intra-node all-to-all followed by an inter-node all-to-all. It reduces the communication hops from $O(G)$ to $O(G_\\text{node} + G / G_\\text{node})$, where $G$ is the total number of GPU nodes and $G_\\text{node}$ is the number of GPU cores per node. Although the communication volume is doubled in such implementation, it enables better scaling with small batches at large scale as the bottleneck is on latency instead of communication bandwidth when the batch size is small.\nDynaMoE (Kossmann et al. 2022) uses dynamic recompilation to adapt the computational resources to dynamic workloads among experts. The RECOMPILE mechanism compiles the computation graph from scratch and only reallocates resources when needed. It measures how many samples are assigned to each expert and adjusts their capacity factors $C$ dynamically, in order to reduce the memory and computation requirements at run time. Based on the observation that sample-expert assignments converge early in training, sample assignment caching is introduced after convergence and then RECOMPILE is used to eliminate the dependency between the gating network and experts.\nArchitectural Optimization The survey paper on Efficient Transformers (Tay et al. 2020) reviewed a collection of new transformer architectures with improvement for better computational and memory efficiency. Strongly recommend a read. You can also check out my post \u0026ldquo;The Transformer Family Version 2.0\u0026rdquo; for introduction to a diverse set of transformer archiecture improvements in depth, including changes to make the model cheaper to run.\nFig. 18. Categorization of efficient transformer models.(Image source: Tay et al. 2020) Since the self-attention mechanism has quadratic time and memory complexity and that is the main bottleneck for better transformer decoding efficiency, all the efficient transformer models have applied some form of sparsity to the otherwise dense attention layer. Here only lists a high-level overview, several derived from Tay et al. 2020.\nSparse Attention Patterns Fixed Patterns limit the field of view for the attention matrix, using predefined, fixed patterns.\n Chunk input sequences into fixed blocks, such as Blockwise Attention; Image Transformer uses local attention; Sparse Transformer uses strided attention patterns. Combined Patterns learn to sort/cluster the input tokens - enabling a more optimal global view of the sequence while maintaining the efficiency benefits of fixed patterns.\n Sparse Transformer combines strided and local attention; Given a high dimensional input tensor, instead of applying attention to the flattened version of the input, Axial Transformer applies multiple attentions, each along a single axis of the input tensor. ETC, Longformer and Big Bird combines local and global context, as well as strided or random attention. Learnable Patterns identify the optimal attention pattern via learning.\n Reformer clusters tokens into clusters based on hash-based similarity (LSH); Routing Transformer runs $k$-means clustering on tokens; Sinkhorn Sorting Network learns to sort blocks of input sequence. Recurrence Recurrence mechanism connects multiple blocks/segments via recurrence.\n Transformer-XL makes use of longer context by reusing hidden states between segments. Universal Transformer combines self-attention with the recurrent mechanism in RNN. Compressive Transformer is an extension of Transformer-XL with additional memory, containing a set of memory slots for past activiations and compressive memory slots for compressed activations. Whenever the model accepts a new input segment, the oldest activations in the primary memory are moved to the compressed memory where a compression function is applied. Memory Saving Designs Memory saving designs refer to changes of the architecture to use less memory.\n Linformer projects the length dimension of keys and values to a lower-dimensional representation ($N \\to k$) and thus the memory complexity is reduced from $N \\times N$ to $N \\times k$. Shazeer (2019) proposed multi-query attention which has the keys and values shared across different attention \u0026ldquo;heads\u0026rdquo;, greatly reducing the size of these tensors and the memory cost. Random feature attention and Performer use kernel methods to achieve a cheaper mathematical format of the self-attention mechanism. Adaptive Attention Adaptive attention enables the model to learn the optimal attention span or decide on when to do early exiting for different input tokens.\n Adaptive Attention Span trains the model to learn the optimal attention span per token per head via a soft mask between the token and other keys. Universal Transformer incorporates recurrent mechanism and uses ACT (Adaptive computation time) to dynamically decide the number of recurrent steps. Depth-Adaptive Transformer and CALM learns when to early exit the computation layers per token using some confidence measures to achieve good performance-efficiency tradeoffs. Citation Cited as:\n Weng, Lilian. (Jan 2023). Large Transformer Model Inference Optimization. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-01-10-inference-optimization/.\n Or\n@article{weng2023inference, title = \u0026quot;Large Transformer Model Inference Optimization\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-01-10-inference-optimization/\u0026quot; } References [1] Bondarenko et al. \u0026ldquo;Understanding and overcoming the challenges of efficient transformer quantization\u0026rdquo; ACL 2021.\n[2] Dettmers et al. \u0026ldquo;LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale\u0026rdquo; NeuriPS 2022\n[3] Zadeh et al. \u0026ldquo;Gobo: Quantizing attention-based NLP models for low latency and energy efficient inference.\u0026quot; MICRO 2020\n[4] Shen, Dong \u0026amp; Ye, et al. \u0026ldquo;Q-BERT: Hessian based ultra low precision quantization of BERT\u0026rdquo; AAAI 2020.\n[5] Yao et al. \u0026ldquo;ZeroQuant: Efficient and affordable post-training quantization for large-scale transformers\u0026rdquo; arXiv preprint arXiv:2206.01861 (2022).\n[6] Frantar et al. \u0026ldquo;GPTQ: Accurate Quantization for Generative Pre-trained Transformers\u0026rdquo; arXiv preprint arXiv:2210.17323 (2022).\n[7] Xiao \u0026amp; Lin \u0026ldquo;SmoothQuant: Accelerated sparse neural training: A provable and efficient method to find N:M transposable masks.\u0026quot; arXiv preprint arXiv:2211.10438 (2022). | code\n[8] Pool \u0026amp; Yu. \u0026ldquo;Channel Permutations for N:M Sparsity.\u0026quot; NeuriPS 2021. | code\n[9] Zhou \u0026amp; Ma, et al. \u0026ldquo;Learning N:M fine-grained structured sparse neural networks from scratch.\u0026quot; arXiv preprint arXiv:2102.04010 (2021).\n[10] Jayakumar et al. \u0026ldquo;Top-KAST: Top-K Always Sparse Training.\u0026quot; NeuriPS 2020.\n[11] Nvidia. \u0026ldquo;Nvidia A100 tensor core GPU architecture.\u0026quot; 2020.\n[12] Gale, Elsen \u0026amp; Hooker \u0026ldquo;The State of Sparsity in Deep Neural Networks.\u0026quot; arXiv preprint arXiv:1902.09574 (2019).\n[13] Zhu \u0026amp; Gupta. \u0026ldquo;To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.\u0026quot; arXiv preprint arXiv:1710.01878 (2017).\n[14] Renda et al. \u0026ldquo;Comparing rewinding and fine-tuning in neural network pruning.\u0026quot; arXiv preprint arXiv:2003.02389 (2020).\n[15] Zhou \u0026amp; Ma, et al. \u0026ldquo;Learning N:M fine-grained structured sparse neural networks from scratch.\u0026quot; arXiv preprint arXiv:2102.04010 (2021).\n[16] Pool \u0026amp; Yu. \u0026ldquo;Channel Permutations for N:M Sparsity.\u0026quot; NeuriPS 2021. | code\n[17] Jaszczur et al. \u0026ldquo;Sparse is Enough in Scaling Transformers.\u0026quot; NeuriPS 2021.\n[18] Mishra et al. \u0026ldquo;An Survey of Neural Network Compression.\u0026quot; arXiv preprint arXiv:1710.09282 (2017).\n[19] Fedus et al. \u0026ldquo;A Review of Sparse Expert Models in Deep Learning.\u0026quot; arXiv preprint arXiv:2209.01667 (2022)..\n[20] Riquelme et al. \u0026ldquo;Scaling vision with sparse mixture of experts.\u0026quot; NeuriPS 2021.\n[21] Kudugunta et al. \u0026ldquo;Beyond Distillation: Task-level Mixture-of-Experts for Efficient Inference.\u0026quot; arXiv preprint arXiv:2110.03742 (2021).\n[22] Rajbhandari et al. \u0026ldquo;DeepSpeed-MoE: Advancing mixture-of-experts inference and training to power next-generation ai scale.\u0026quot; arXiv preprint arXiv:2201.05596 (2022).\n[23] Kossmann et al. \u0026ldquo;Optimizing mixture of experts using dynamic recompilations.\u0026quot; arXiv preprint arXiv:2205.01848 (2022).\n[24] Hwang et al. \u0026ldquo;Tutel: Adaptive mixture-of-experts at scale.\u0026quot; arXiv preprint arXiv:2206.03382 (2022). | code\n[25] Noam Shazeer. \u0026ldquo;Fast Transformer Decoding: One Write-Head is All You Need.\u0026quot; arXiv preprint arXiv:1911.02150 (2019).\n[26] Tay et al. \u0026ldquo;Efficient Transformers: A Survey.\u0026quot; ACM Computing Surveys 55.6 (2022): 1-28.\n[27] Pope et al. \u0026ldquo;Efficiently Scaling Transformer Inference.\u0026quot; arXiv preprint arXiv:2211.05102 (2022).\n[28] Frankle \u0026amp; Carbin. \u0026ldquo;The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks\u0026rdquo; ICLR 2019.\n[29] Elabyad et al. \u0026ldquo;Depth-Adaptive Transformer\u0026rdquo; ICLR 2020.\n[30] Schuster et al. \u0026ldquo;Confident Adaptive Language Modeling\u0026rdquo; arXiv preprint arXiv:2207.07061 (2022).\n[31] Gou et al. \u0026ldquo;https://arxiv.org/abs/2006.05525\u0026rdquo; arXiv preprint arXiv:2006.05525 (2020).\n[32] Hinton et al. \u0026ldquo;Distilling the Knowledge in a Neural Network\u0026rdquo; NIPS 2014.\n[33] Sanh et al. \u0026ldquo;DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter\u0026rdquo; Workshop on Energy Efficient Machine Learning and Cognitive Computing @ NeuriPS 2019.\n","permalink":"https://lilianweng.github.io/posts/2023-01-10-inference-optimization/","summary":"[Updated on 2023-01-24: add a small section on Distillation.]\nLarge transformer models are mainstream nowadays, creating SoTA results for a variety of tasks. They are powerful but very expensive to train and use. The extremely high inference cost, in both time and memory, is a big bottleneck for adopting a powerful transformer for solving real-world tasks at scale.\nWhy is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al.","title":"Large Transformer Model Inference Optimization"},{"content":"Neural networks are well known to be over-parameterized and can often easily fit data with near-zero training loss with decent generalization performance on test dataset. Although all these parameters are initialized at random, the optimization process can consistently lead to similarly good outcomes. And this is true even when the number of model parameters exceeds the number of training data points.\nNeural tangent kernel (NTK) (Jacot et al. 2018) is a kernel to explain the evolution of neural networks during training via gradient descent. It leads to great insights into why neural networks with enough width can consistently converge to a global minimum when trained to minimize an empirical loss. In the post, we will do a deep dive into the motivation and definition of NTK, as well as the proof of a deterministic convergence at different initializations of neural networks with infinite width by characterizing NTK in such a setting.\n 🤓 Different from my previous posts, this one mainly focuses on a small number of core papers, less on the breadth of the literature review in the field. There are many interesting works after NTK, with modification or expansion of the theory for understanding the learning dynamics of NNs, but they won\u0026rsquo;t be covered here. The goal is to show all the math behind NTK in a clear and easy-to-follow format, so the post is quite math-intensive. If you notice any mistakes, please let me know and I will be happy to correct them quickly. Thanks in advance!\n Basics This section contains reviews of several very basic concepts which are core to understanding of neural tangent kernel. Feel free to skip.\nVector-to-vector Derivative Given an input vector $\\mathbf{x} \\in \\mathbb{R}^n$ (as a column vector) and a function $f: \\mathbb{R}^n \\to \\mathbb{R}^m$, the derivative of $f$ with respective to $\\mathbf{x}$ is a $m\\times n$ matrix, also known as Jacobian matrix:\n $$ J = \\frac{\\partial f}{\\partial \\mathbf{x}} = \\begin{bmatrix} \\frac{\\partial f_1}{\\partial x_1} \u0026 \\dots \u0026\\frac{\\partial f_1}{\\partial x_n} \\\\ \\vdots \u0026 \u0026 \\\\ \\frac{\\partial f_m}{\\partial x_1} \u0026 \\dots \u0026\\frac{\\partial f_m}{\\partial x_n} \\\\ \\end{bmatrix} \\in \\mathbb{R}^{m \\times n} $$ Throughout the post, I use integer subscript(s) to refer to a single entry out of a vector or matrix value; i.e. $x_i$ indicates the $i$-th value in the vector $\\mathbf{x}$ and $f_i(.)$ is the $i$-th entry in the output of the function.\nThe gradient of a vector with respect to a vector is defined as $\\nabla_\\mathbf{x} f = J^\\top \\in \\mathbb{R}^{n \\times m}$ and this formation is also valid when $m=1$ (i.e., scalar output).\nDifferential Equations Differential equations describe the relationship between one or multiple functions and their derivatives. There are two main types of differential equations.\n (1) ODE (Ordinary differential equation) contains only an unknown function of one random variable. ODEs are the main form of differential equations used in this post. A general form of ODE looks like $(x, y, \\frac{dy}{dx}, \\dots, \\frac{d^ny}{dx^n}) = 0$. (2) PDE (Partial differential equation) contains unknown multivariable functions and their partial derivatives. Let\u0026rsquo;s review the simplest case of differential equations and its solution. Separation of variables (Fourier method) can be used when all the terms containing one variable can be moved to one side, while the other terms are all moved to the other side. For example,\n $$ \\begin{aligned} \\text{Given }a\\text{ is a constant scalar:}\\quad\\frac{dy}{dx} \u0026= ay \\\\ \\text{Move same variables to the same side:}\\quad\\frac{dy}{y} \u0026= adx \\\\ \\text{Put integral on both sides:}\\quad\\int \\frac{dy}{y} \u0026= \\int adx \\\\ \\ln (y) \u0026= ax + C' \\\\ \\text{Finally}\\quad y \u0026= e^{ax + C'} = C e^{ax} \\end{aligned} $$ Central Limit Theorem Given a collection of i.i.d. random variables, $x_1, \\dots, x_N$ with mean $\\mu$ and variance $\\sigma^2$, the Central Limit Theorem (CTL) states that the expectation would be Gaussian distributed when $N$ becomes really large.\n $$ \\bar{x} = \\frac{1}{N}\\sum_{i=1}^N x_i \\sim \\mathcal{N}(\\mu, \\frac{\\sigma^2}{n})\\quad\\text{when }N \\to \\infty $$ CTL can also apply to multidimensional vectors, and then instead of a single scale $\\sigma^2$ we need to compute the covariance matrix of random variable $\\Sigma$.\nTaylor Expansion The Taylor expansion is to express a function as an infinite sum of components, each represented in terms of this function\u0026rsquo;s derivatives. The Tayler expansion of a function $f(x)$ at $x=a$ can be written as: $$ f(x) = f(a) + \\sum_{k=1}^\\infty \\frac{1}{k!} (x - a)^k\\nabla^k_xf(x)\\vert_{x=a} $$ where $\\nabla^k$ denotes the $k$-th derivative.\nThe first-order Taylor expansion is often used as a linear approximation of the function value:\n $$ f(x) \\approx f(a) + (x - a)\\nabla_x f(x)\\vert_{x=a} $$ Kernel \u0026amp; Kernel Methods A kernel is essentially a similarity function between two data points, $K: \\mathcal{X} \\times \\mathcal{X} \\to \\mathbb{R}$. It describes how sensitive the prediction for one data sample is to the prediction for the other; or in other words, how similar two data points are. The kernel should be symmetric, $K(x, x') = K(x', x)$.\nDepending on the problem structure, some kernels can be decomposed into two feature maps, one corresponding to one data point, and the kernel value is an inner product of these two features: $K(x, x') = \\langle \\varphi(x), \\varphi(x') \\rangle$.\nKernel methods are a type of non-parametric, instance-based machine learning algorithms. Assuming we have known all the labels of training samples $\\{x^{(i)}, y^{(i)}\\}$, the label for a new input $x$ is predicted by a weighted sum $\\sum_{i} K(x^{(i)}, x)y^{(i)}$.\nGaussian Processes Gaussian process (GP) is a non-parametric method by modeling a multivariate Gaussian probability distribution over a collection of random variables. GP assumes a prior over functions and then updates the posterior over functions based on what data points are observed.\nGiven a collection of data points $\\{x^{(1)}, \\dots, x^{(N)}\\}$, GP assumes that they follow a jointly multivariate Gaussian distribution, defined by a mean $\\mu(x)$ and a covariance matrix $\\Sigma(x)$. Each entry at location $(i,j)$ in the covariance matrix $\\Sigma(x)$ is defined by a kernel $\\Sigma_{i,j} = K(x^{(i)}, x^{(j)})$, also known as a covariance function. The core idea is \u0026ndash; if two data points are deemed similar by the kernel, the function outputs should be close, too. Making predictions with GP for unknown data points is equivalent to drawing samples from this distribution, via a conditional distribution of unknown data points given observed ones.\nCheck this post for a high-quality and highly visualization tutorial on what Gaussian Processes are.\nNotation Let us consider a fully-connected neural networks with parameter $\\theta$, $f(.;\\theta): \\mathbb{R}^{n_0} \\to \\mathbb{R}^{n_L}$. Layers are indexed from 0 (input) to $L$ (output), each containing $n_0, \\dots, n_L$ neurons, including the input of size $n_0$ and the output of size $n_L$. There are $P = \\sum_{l=0}^{L-1} (n_l + 1) n_{l+1}$ parameters in total and thus we have $\\theta \\in \\mathbb{R}^P$.\nThe training dataset contains $N$ data points, $\\mathcal{D}=\\{\\mathbf{x}^{(i)}, y^{(i)}\\}_{i=1}^N$. All the inputs are denoted as $\\mathcal{X}=\\{\\mathbf{x}^{(i)}\\}_{i=1}^N$ and all the labels are denoted as $\\mathcal{Y}=\\{y^{(i)}\\}_{i=1}^N$.\nNow let\u0026rsquo;s look into the forward pass computation in every layer in detail. For $l=0, \\dots, L-1$, each layer $l$ defines an affine transformation $A^{(l)}$ with a weight matrix $\\mathbf{w}^{(l)} \\in \\mathbb{R}^{n_{l} \\times n_{l+1}}$ and a bias term $\\mathbf{b}^{(l)} \\in \\mathbb{R}^{n_{l+1}}$, as well as a pointwise nonlinearity function $\\sigma(.)$ which is Lipschitz continuous.\n $$ \\begin{aligned} A^{(0)} \u0026= \\mathbf{x} \\\\ \\tilde{A}^{(l+1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}} {\\mathbf{w}^{(l)}}^\\top A^{(l)} + \\beta\\mathbf{b}^{(l)}\\quad\\in\\mathbb{R}^{n_{l+1}} \u0026 \\text{; pre-activations}\\\\ A^{(l+1)}(\\mathbf{x}) \u0026= \\sigma(\\tilde{A}^{(l+1)}(\\mathbf{x}))\\quad\\in\\mathbb{R}^{n_{l+1}} \u0026 \\text{; post-activations} \\end{aligned} $$ Note that the NTK parameterization applies a rescale weight $1/\\sqrt{n_l}$ on the transformation to avoid divergence with infinite-width networks. The constant scalar $\\beta \\geq 0$ controls how much effort the bias terms have.\nAll the network parameters are initialized as an i.i.d Gaussian $\\mathcal{N}(0, 1)$ in the following analysis.\nNeural Tangent Kernel Neural tangent kernel (NTK) (Jacot et al. 2018) is an important concept for understanding neural network training via gradient descent. At its core, it explains how updating the model parameters on one data sample affects the predictions for other samples.\nLet\u0026rsquo;s start with the intuition behind NTK, step by step.\nThe empirical loss function $\\mathcal{L}: \\mathbb{R}^P \\to \\mathbb{R}_+$ to minimize during training is defined as follows, using a per-sample cost function $\\ell: \\mathbb{R}^{n_0} \\times \\mathbb{R}^{n_L} \\to \\mathbb{R}_+$:\n $$ \\mathcal{L}(\\theta) =\\frac{1}{N} \\sum_{i=1}^N \\ell(f(\\mathbf{x}^{(i)}; \\theta), y^{(i)}) $$ and according to the chain rule. the gradient of the loss is:\n $$ \\nabla_\\theta \\mathcal{L}(\\theta)= \\frac{1}{N} \\sum_{i=1}^N \\underbrace{\\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta)}_{\\text{size }P \\times n_L} \\underbrace{\\nabla_f \\ell(f, y^{(i)})}_{\\text{size } n_L \\times 1} $$ When tracking how the network parameter $\\theta$ evolves in time, each gradient descent update introduces a small incremental change of an infinitesimal step size. Because of the update step is small enough, it can be approximately viewed as a derivative on the time dimension:\n $$ \\frac{d\\theta}{d t} = - \\nabla_\\theta\\mathcal{L}(\\theta) = -\\frac{1}{N} \\sum_{i=1}^N \\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta) \\nabla_f \\ell(f, y^{(i)}) $$ Again, by the chain rule, the network output evolves according to the derivative:\n $$ \\frac{df(\\mathbf{x};\\theta)}{dt} = \\frac{df(\\mathbf{x};\\theta)}{d\\theta}\\frac{d\\theta}{dt} = -\\frac{1}{N} \\sum_{i=1}^N \\color{blue}{\\underbrace{\\nabla_\\theta f(\\mathbf{x};\\theta)^\\top \\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta)}_\\text{Neural tangent kernel}} \\color{black}{\\nabla_f \\ell(f, y^{(i)})} $$ Here we find the Neural Tangent Kernel (NTK), as defined in the blue part in the above formula, $K: \\mathbb{R}^{n_0}\\times\\mathbb{R}^{n_0} \\to \\mathbb{R}^{n_L \\times n_L}$ :\n $$ K(\\mathbf{x}, \\mathbf{x}'; \\theta) = \\nabla_\\theta f(\\mathbf{x};\\theta)^\\top \\nabla_\\theta f(\\mathbf{x}'; \\theta) $$ where each entry in the output matrix at location $(m, n), 1 \\leq m, n \\leq n_L$ is:\n $$ K_{m,n}(\\mathbf{x}, \\mathbf{x}'; \\theta) = \\sum_{p=1}^P \\frac{\\partial f_m(\\mathbf{x};\\theta)}{\\partial \\theta_p} \\frac{\\partial f_n(\\mathbf{x}';\\theta)}{\\partial \\theta_p} $$ The \u0026ldquo;feature map\u0026rdquo; form of one input $\\mathbf{x}$ is $\\varphi(\\mathbf{x}) = \\nabla_\\theta f(\\mathbf{x};\\theta)$.\nInfinite Width Networks To understand why the effect of one gradient descent is so similar for different initializations of network parameters, several pioneering theoretical work starts with infinite width networks. We will look into detailed proof using NTK of how it guarantees that infinite width networks can converge to a global minimum when trained to minimize an empirical loss.\nConnection with Gaussian Processes Deep neural networks have deep connection with gaussian processes (Neal 1994). The output functions of a $L$-layer network, $f_i(\\mathbf{x}; \\theta)$ for $i=1, \\dots, n_L$ , are i.i.d. centered Gaussian process of covariance $\\Sigma^{(L)}$, defined recursively as:\n $$ \\begin{aligned} \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\frac{1}{n_0}\\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 \\\\ \\lambda^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\begin{bmatrix} \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) \u0026 \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}') \\\\ \\Sigma^{(l)}(\\mathbf{x}', \\mathbf{x}) \u0026 \\Sigma^{(l)}(\\mathbf{x}', \\mathbf{x}') \\end{bmatrix} \\\\ \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}_{f \\sim \\mathcal{N}(0, \\lambda^{(l)})}[\\sigma(f(\\mathbf{x})) \\sigma(f(\\mathbf{x}'))] + \\beta^2 \\end{aligned} $$ Lee \u0026amp; Bahri et al. (2018) showed a proof by mathematical induction:\n(1) Let\u0026rsquo;s start with $L=1$, when there is no nonlinearity function and the input is only processed by a simple affine transformation:\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) = \\tilde{A}^{(1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_0}}{\\mathbf{w}^{(0)}}^\\top\\mathbf{x} + \\beta\\mathbf{b}^{(0)} \\\\ \\text{where }\\tilde{A}_m^{(1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{im}x_i + \\beta b^{(0)}_m\\quad \\text{for }1 \\leq m \\leq n_1 \\end{aligned} $$ Since the weights and biases are initialized i.i.d., all the output dimensions of this network ${\\tilde{A}^{(1)}_1(\\mathbf{x}), \\dots, \\tilde{A}^{(1)}_{n_1}(\\mathbf{x})}$ are also i.i.d. Given different inputs, the $m$-th network outputs $\\tilde{A}^{(1)}_m(.)$ have a joint multivariate Gaussian distribution, equivalent to a Gaussian process with covariance function (We know that mean $\\mu_w=\\mu_b=0$ and variance $\\sigma^2_w = \\sigma^2_b=1$)\n $$ \\begin{aligned} \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}[\\tilde{A}_m^{(1)}(\\mathbf{x})\\tilde{A}_m^{(1)}(\\mathbf{x}')] \\\\ \u0026= \\mathbb{E}\\Big[\\Big( \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{i,m}x_i + \\beta b^{(0)}_m \\Big) \\Big( \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{i,m}x'_i + \\beta b^{(0)}_m \\Big)\\Big] \\\\ \u0026= \\frac{1}{n_0} \\sigma^2_w \\sum_{i=1}^{n_0} \\sum_{j=1}^{n_0} x_i{x'}_j + \\frac{\\beta \\mu_b}{\\sqrt{n_0}} \\sum_{i=1}^{n_0} w_{im}(x_i + x'_i) + \\sigma^2_b \\beta^2 \\\\ \u0026= \\frac{1}{n_0}\\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 \\end{aligned} $$ (2) Using induction, we first assume the proposition is true for $L=l$, a $l$-layer network, and thus $\\tilde{A}^{(l)}_m(.)$ is a Gaussian process with covariance $\\Sigma^{(l)}$ and $\\{\\tilde{A}^{(l)}_i\\}_{i=1}^{n_l}$ are i.i.d.\nThen we need to prove the proposition is also true for $L=l+1$. We compute the outputs by:\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) = \\tilde{A}^{(l+1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}}{\\mathbf{w}^{(l)}}^\\top \\sigma(\\tilde{A}^{(l)}(\\mathbf{x})) + \\beta\\mathbf{b}^{(l)} \\\\ \\text{where }\\tilde{A}^{(l+1)}_m(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}}\\sum_{i=1}^{n_l} w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x})) + \\beta b^{(l)}_m \\quad \\text{for }1 \\leq m \\leq n_{l+1} \\end{aligned} $$ We can infer that the expectation of the sum of contributions of the previous hidden layers is zero:\n $$ \\begin{aligned} \\mathbb{E}[w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] \u0026= \\mathbb{E}[w^{(l)}_{im}]\\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] = \\mu_w \\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] = 0 \\\\ \\mathbb{E}[\\big(w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))\\big)^2] \u0026= \\mathbb{E}[{w^{(l)}_{im}}^2]\\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))^2] = \\sigma_w^2 \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) = \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) \\end{aligned} $$ Since $\\{\\tilde{A}^{(l)}_i(\\mathbf{x})\\}_{i=1}^{n_l}$ are i.i.d., according to central limit theorem, when the hidden layer gets infinitely wide $n_l \\to \\infty$, $\\tilde{A}^{(l+1)}_m(\\mathbf{x})$ is Gaussian distributed with variance $\\beta^2 + \\text{Var}(\\tilde{A}_i^{(l)}(\\mathbf{x}))$. Note that ${\\tilde{A}^{(l+1)}_1(\\mathbf{x}), \\dots, \\tilde{A}^{(l+1)}_{n_{l+1}}(\\mathbf{x})}$ are still i.i.d.\n$\\tilde{A}^{(l+1)}_m(.)$ is equivalent to a Gaussian process with covariance function:\n $$ \\begin{aligned} \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}[\\tilde{A}^{(l+1)}_m(\\mathbf{x})\\tilde{A}^{(l+1)}_m(\\mathbf{x}')] \\\\ \u0026= \\frac{1}{n_l} \\sigma\\big(\\tilde{A}^{(l)}_i(\\mathbf{x})\\big)^\\top \\sigma\\big(\\tilde{A}^{(l)}_i(\\mathbf{x}')\\big) + \\beta^2 \\quad\\text{;similar to how we get }\\Sigma^{(1)} \\end{aligned} $$ When $n_l \\to \\infty$, according to central limit theorem,\n $$ \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \\to \\mathbb{E}_{f \\sim \\mathcal{N}(0, \\Lambda^{(l)})}[\\sigma(f(\\mathbf{x}))^\\top \\sigma(f(\\mathbf{x}'))] + \\beta^2 $$ The form of Gaussian processes in the above process is referred to as the Neural Network Gaussian Process (NNGP) (Lee \u0026amp; Bahri et al. (2018)).\nDeterministic Neural Tangent Kernel Finally we are now prepared enough to look into the most critical proposition from the NTK paper:\nWhen $n_1, \\dots, n_L \\to \\infty$ (network with infinite width), the NTK converges to be:\n (1) deterministic at initialization, meaning that the kernel is irrelevant to the initialization values and only determined by the model architecture; and (2) stays constant during training. The proof depends on mathematical induction as well:\n(1) First of all, we always have $K^{(0)} = 0$. When $L=1$, we can get the representation of NTK directly. It is deterministic and does not depend on the network initialization. There is no hidden layer, so there is nothing to take on infinite width.\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) \u0026= \\tilde{A}^{(1)}(\\mathbf{x}) = \\frac{1}{\\sqrt{n_0}} {\\mathbf{w}^{(0)}}^\\top\\mathbf{x} + \\beta\\mathbf{b}^{(0)} \\\\ K^{(1)}(\\mathbf{x}, \\mathbf{x}';\\theta) \u0026= \\Big(\\frac{\\partial f(\\mathbf{x}';\\theta)}{\\partial \\mathbf{w}^{(0)}}\\Big)^\\top \\frac{\\partial f(\\mathbf{x};\\theta)}{\\partial \\mathbf{w}^{(0)}} + \\Big(\\frac{\\partial f(\\mathbf{x}';\\theta)}{\\partial \\mathbf{b}^{(0)}}\\Big)^\\top \\frac{\\partial f(\\mathbf{x};\\theta)}{\\partial \\mathbf{b}^{(0)}} \\\\ \u0026= \\frac{1}{n_0} \\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 = \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \\end{aligned} $$ (2) Now when $L=l$, we assume that a $l$-layer network with $\\tilde{P}$ parameters in total, $\\tilde{\\theta} = (\\mathbf{w}^{(0)}, \\dots, \\mathbf{w}^{(l-1)}, \\mathbf{b}^{(0)}, \\dots, \\mathbf{b}^{(l-1)}) \\in \\mathbb{R}^\\tilde{P}$, has a NTK converging to a deterministic limit when $n_1, \\dots, n_{l-1} \\to \\infty$.\n $$ K^{(l)}(\\mathbf{x}, \\mathbf{x}';\\tilde{\\theta}) = \\nabla_{\\tilde{\\theta}} \\tilde{A}^{(l)}(\\mathbf{x})^\\top \\nabla_{\\tilde{\\theta}} \\tilde{A}^{(l)}(\\mathbf{x}') \\to K^{(l)}_{\\infty}(\\mathbf{x}, \\mathbf{x}') $$ Note that $K_\\infty^{(l)}$ has no dependency on $\\theta$.\nNext let\u0026rsquo;s check the case $L=l+1$. Compared to a $l$-layer network, a $(l+1)$-layer network has additional weight matrix $\\mathbf{w}^{(l)}$ and bias $\\mathbf{b}^{(l)}$ and thus the total parameters contain $\\theta = (\\tilde{\\theta}, \\mathbf{w}^{(l)}, \\mathbf{b}^{(l)})$.\nThe output function of this $(l+1)$-layer network is:\n $$ f(\\mathbf{x};\\theta) = \\tilde{A}^{(l+1)}(\\mathbf{x};\\theta) = \\frac{1}{\\sqrt{n_l}} {\\mathbf{w}^{(l)}}^\\top \\sigma\\big(\\tilde{A}^{(l)}(\\mathbf{x})\\big) + \\beta \\mathbf{b}^{(l)} $$ And we know its derivative with respect to different sets of parameters; let denote $\\tilde{A}^{(l)} = \\tilde{A}^{(l)}(\\mathbf{x})$ for brevity in the following equation:\n $$ \\begin{aligned} \\nabla_{\\color{blue}{\\mathbf{w}^{(l)}}} f(\\mathbf{x};\\theta) \u0026= \\color{blue}{ \\frac{1}{\\sqrt{n_l}} \\sigma\\big(\\tilde{A}^{(l)}\\big)^\\top } \\color{black}{\\quad \\in \\mathbb{R}^{1 \\times n_l}} \\\\ \\nabla_{\\color{green}{\\mathbf{b}^{(l)}}} f(\\mathbf{x};\\theta) \u0026= \\color{green}{ \\beta } \\\\ \\nabla_{\\color{red}{\\tilde{\\theta}}} f(\\mathbf{x};\\theta) \u0026= \\frac{1}{\\sqrt{n_l}} \\nabla_\\tilde{\\theta}\\sigma(\\tilde{A}^{(l)}) \\mathbf{w}^{(l)} \\\\ \u0026= \\color{red}{ \\frac{1}{\\sqrt{n_l}} \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_1} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_1} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_\\tilde{P}} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_\\tilde{P}}\\\\ \\end{bmatrix} \\mathbf{w}^{(l)} \\color{black}{\\quad \\in \\mathbb{R}^{\\tilde{P} \\times n_{l+1}}} } \\end{aligned} $$ where $\\dot{\\sigma}$ is the derivative of $\\sigma$ and each entry at location $(p, m), 1 \\leq p \\leq \\tilde{P}, 1 \\leq m \\leq n_{l+1}$ in the matrix $\\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)$ can be written as\n $$ \\frac{\\partial f_m(\\mathbf{x};\\theta)}{\\partial \\tilde{\\theta}_p} = \\sum_{i=1}^{n_l} w^{(l)}_{im} \\dot{\\sigma}\\big(\\tilde{A}_i^{(l)} \\big) \\nabla_{\\tilde{\\theta}_p} \\tilde{A}_i^{(l)} $$ The NTK for this $(l+1)$-layer network can be defined accordingly:\n $$ \\begin{aligned} \u0026 K^{(l+1)}(\\mathbf{x}, \\mathbf{x}'; \\theta) \\\\ =\u0026 \\nabla_{\\theta} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\theta} f(\\mathbf{x};\\theta) \\\\ =\u0026 \\color{blue}{\\nabla_{\\mathbf{w}^{(l)}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\mathbf{w}^{(l)}} f(\\mathbf{x};\\theta)} + \\color{green}{\\nabla_{\\mathbf{b}^{(l)}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\mathbf{b}^{(l)}} f(\\mathbf{x};\\theta)} + \\color{red}{\\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)} \\\\ =\u0026 \\frac{1}{n_l} \\Big[ \\color{blue}{\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)})^\\top} + \\color{green}{\\beta^2} \\\\ \u0026+ \\color{red}{ {\\mathbf{w}^{(l)}}^\\top \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p} \\\\ \\end{bmatrix} \\mathbf{w}^{(l)} } \\color{black}{\\Big]} \\\\ =\u0026 \\frac{1}{n_l} \\Big[ \\color{blue}{\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)})^\\top} + \\color{green}{\\beta^2} \\\\ \u0026+ \\color{red}{ {\\mathbf{w}^{(l)}}^\\top \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})K^{(l)}_{11} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})K^{(l)}_{1n_l} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})K^{(l)}_{n_l1} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})K^{(l)}_{n_ln_l} \\\\ \\end{bmatrix} \\mathbf{w}^{(l)} } \\color{black}{\\Big]} \\end{aligned} $$ where each individual entry at location $(m, n), 1 \\leq m, n \\leq n_{l+1}$ of the matrix $K^{(l+1)}$ can be written as:\n $$ \\begin{aligned} K^{(l+1)}_{mn} =\u0026 \\frac{1}{n_l}\\Big[ \\color{blue}{\\sigma(\\tilde{A}_m^{(l)})\\sigma(\\tilde{A}_n^{(l)})} + \\color{green}{\\beta^2} + \\color{red}{ \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{ij}^{(l)} } \\Big] \\end{aligned} $$ When $n_l \\to \\infty$, the section in blue and green has the limit (See the proof in the previous section):\n $$ \\frac{1}{n_l}\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)}) + \\beta^2\\to \\Sigma^{(l+1)} $$ and the red section has the limit:\n $$ \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{ij}^{(l)} \\to \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{\\infty,ij}^{(l)} $$ Later, Arora et al. (2019) provided a proof with a weaker limit, that does not require all the hidden layers to be infinitely wide, but only requires the minimum width to be sufficiently large.\nLinearized Models From the previous section, according to the derivative chain rule, we have known that the gradient update on the output of an infinite width network is as follows; For brevity, we omit the inputs in the following analysis:\n $$ \\begin{aligned} \\frac{df(\\theta)}{dt} \u0026= -\\eta\\nabla_\\theta f(\\theta)^\\top \\nabla_\\theta f(\\theta) \\nabla_f \\mathcal{L} \u0026 \\\\ \u0026= -\\eta\\nabla_\\theta f(\\theta)^\\top \\nabla_\\theta f(\\theta) \\nabla_f \\mathcal{L} \u0026 \\\\ \u0026= -\\eta K(\\theta) \\nabla_f \\mathcal{L} \\\\ \u0026= \\color{cyan}{-\\eta K_\\infty \\nabla_f \\mathcal{L}} \u0026 \\text{; for infinite width network}\\\\ \\end{aligned} $$ To track the evolution of $\\theta$ in time, let\u0026rsquo;s consider it as a function of time step $t$. With Taylor expansion, the network learning dynamics can be simplified as:\n $$ f(\\theta(t)) \\approx f^\\text{lin}(\\theta(t)) = f(\\theta(0)) + \\underbrace{\\nabla_\\theta f(\\theta(0))}_{\\text{formally }\\nabla_\\theta f(\\mathbf{x}; \\theta) \\vert_{\\theta=\\theta(0)}} (\\theta(t) - \\theta(0)) $$ Such formation is commonly referred to as the linearized model, given $\\theta(0)$, $f(\\theta(0))$, and $\\nabla_\\theta f(\\theta(0))$ are all constants. Assuming that the incremental time step $t$ is extremely small and the parameter is updated by gradient descent:\n $$ \\begin{aligned} \\theta(t) - \\theta(0) \u0026= - \\eta \\nabla_\\theta \\mathcal{L}(\\theta) = - \\eta \\nabla_\\theta f(\\theta)^\\top \\nabla_f \\mathcal{L} \\\\ f^\\text{lin}(\\theta(t)) - f(\\theta(0)) \u0026= - \\eta\\nabla_\\theta f(\\theta(0))^\\top \\nabla_\\theta f(\\mathcal{X};\\theta(0)) \\nabla_f \\mathcal{L} \\\\ \\frac{df(\\theta(t))}{dt} \u0026= - \\eta K(\\theta(0)) \\nabla_f \\mathcal{L} \\\\ \\frac{df(\\theta(t))}{dt} \u0026= \\color{cyan}{- \\eta K_\\infty \\nabla_f \\mathcal{L}} \u0026 \\text{; for infinite width network}\\\\ \\end{aligned} $$ Eventually we get the same learning dynamics, which implies that a neural network with infinite width can be considerably simplified as governed by the above linearized model (Lee \u0026amp; Xiao, et al. 2019).\nIn a simple case when the empirical loss is an MSE loss, $\\nabla_\\theta \\mathcal{L}(\\theta) = f(\\mathcal{X}; \\theta) - \\mathcal{Y}$, the dynamics of the network becomes a simple linear ODE and it can be solved in a closed form:\n $$ \\begin{aligned} \\frac{df(\\theta)}{dt} =\u0026 -\\eta K_\\infty (f(\\theta) - \\mathcal{Y}) \u0026 \\\\ \\frac{dg(\\theta)}{dt} =\u0026 -\\eta K_\\infty g(\\theta) \u0026 \\text{; let }g(\\theta)=f(\\theta) - \\mathcal{Y} \\\\ \\int \\frac{dg(\\theta)}{g(\\theta)} =\u0026 -\\eta \\int K_\\infty dt \u0026 \\\\ g(\\theta) \u0026= C e^{-\\eta K_\\infty t} \u0026 \\end{aligned} $$ When $t=0$, we have $C=f(\\theta(0)) - \\mathcal{Y}$ and therefore,\n $$ f(\\theta) = (f(\\theta(0)) - \\mathcal{Y})e^{-\\eta K_\\infty t} + \\mathcal{Y} \\\\ = f(\\theta(0))e^{-K_\\infty t} + (I - e^{-\\eta K_\\infty t})\\mathcal{Y} $$ Lazy Training People observe that when a neural network is heavily over-parameterized, the model is able to learn with the training loss quickly converging to zero but the network parameters hardly change. Lazy training refers to the phenomenon. In other words, when the loss $\\mathcal{L}$ has a decent amount of reduction, the change in the differential of the network $f$ (aka the Jacobian matrix) is still very small.\nLet $\\theta(0)$ be the initial network parameters and $\\theta(T)$ be the final network parameters when the loss has been minimized to zero. The delta change in parameter space can be approximated with first-order Taylor expansion:\n $$ \\begin{aligned} \\hat{y} = f(\\theta(T)) \u0026\\approx f(\\theta(0)) + \\nabla_\\theta f(\\theta(0)) (\\theta(T) - \\theta(0)) \\\\ \\text{Thus }\\Delta \\theta \u0026= \\theta(T) - \\theta(0) \\approx \\frac{\\|\\hat{y} - f(\\theta(0))\\|}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\end{aligned} $$ Still following the first-order Taylor expansion, we can track the change in the differential of $f$:\n $$ \\begin{aligned} \\nabla_\\theta f(\\theta(T)) \u0026\\approx \\nabla_\\theta f(\\theta(0)) + \\nabla^2_\\theta f(\\theta(0)) \\Delta\\theta \\\\ \u0026= \\nabla_\\theta f(\\theta(0)) + \\nabla^2_\\theta f(\\theta(0)) \\frac{\\|\\hat{y} - f(\\mathbf{x};\\theta(0))\\|}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\\\ \\text{Thus }\\Delta\\big(\\nabla_\\theta f\\big) \u0026= \\nabla_\\theta f(\\theta(T)) - \\nabla_\\theta f(\\theta(0)) = \\|\\hat{y} - f(\\mathbf{x};\\theta(0))\\| \\frac{\\nabla^2_\\theta f(\\theta(0))}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\end{aligned} $$ Let $\\kappa(\\theta)$ be the relative change of the differential of $f$ to the change in the parameter space:\n $$ \\kappa(\\theta = \\frac{\\Delta\\big(\\nabla_\\theta f\\big)}{\\| \\nabla_\\theta f(\\theta(0)) \\|} = \\|\\hat{y} - f(\\theta(0))\\| \\frac{\\nabla^2_\\theta f(\\theta(0))}{\\| \\nabla_\\theta f(\\theta(0)) \\|^2} $$ Chizat et al. (2019) showed the proof for a two-layer neural network that $\\mathbb{E}[\\kappa(\\theta_0)] \\to 0$ (getting into the lazy regime) when the number of hidden neurons $\\to \\infty$. Also, recommend this post for more discussion on linearized models and lazy training.\nCitation Cited as:\n Weng, Lilian. (Sep 2022). Some math behind neural tangent kernel. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-09-08-ntk/.\n Or\n@article{weng2022ntk, title = \u0026quot;Some Math behind Neural Tangent Kernel\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Sep\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-09-08-ntk/\u0026quot; } References [1] Jacot et al. \u0026ldquo;Neural Tangent Kernel: Convergence and Generalization in Neural Networks.\u0026quot; NeuriPS 2018.\n[2]Radford M. Neal. \u0026ldquo;Priors for Infinite Networks.\u0026quot; Bayesian Learning for Neural Networks. Springer, New York, NY, 1996. 29-53.\n[3] Lee \u0026amp; Bahri et al. \u0026ldquo;Deep Neural Networks as Gaussian Processes.\u0026quot; ICLR 2018.\n[4] Chizat et al. \u0026ldquo;On Lazy Training in Differentiable Programming\u0026rdquo; NeuriPS 2019.\n[5] Lee \u0026amp; Xiao, et al. \u0026ldquo;Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent.\u0026quot; NeuriPS 2019.\n[6] Arora, et al. \u0026ldquo;On Exact Computation with an Infinitely Wide Neural Net.\u0026quot; NeurIPS 2019.\n[7] (YouTube video) \u0026ldquo;Neural Tangent Kernel: Convergence and Generalization in Neural Networks\u0026rdquo; by Arthur Jacot, Nov 2018.\n[8] (YouTube video) \u0026ldquo;Lecture 7 - Deep Learning Foundations: Neural Tangent Kernels\u0026rdquo; by Soheil Feizi, Sep 2020.\n[9] \u0026ldquo;Understanding the Neural Tangent Kernel.\u0026quot; Rajat\u0026rsquo;s Blog.\n[10] \u0026ldquo;Neural Tangent Kernel.\u0026quot;Applied Probability Notes, Mar 2021.\n[11] \u0026ldquo;Some Intuition on the Neural Tangent Kernel.\u0026quot; inFERENCe, Nov 2020.\n","permalink":"https://lilianweng.github.io/posts/2022-09-08-ntk/","summary":"Neural networks are well known to be over-parameterized and can often easily fit data with near-zero training loss with decent generalization performance on test dataset. Although all these parameters are initialized at random, the optimization process can consistently lead to similarly good outcomes. And this is true even when the number of model parameters exceeds the number of training data points.\nNeural tangent kernel (NTK) (Jacot et al. 2018) is a kernel to explain the evolution of neural networks during training via gradient descent.","title":"Some Math behind Neural Tangent Kernel"},{"content":"Processing images to generate text, such as image captioning and visual question-answering, has been studied for years. Traditionally such systems rely on an object detection network as a vision encoder to capture visual features and then produce text via a text decoder. Given a large amount of existing literature, in this post, I would like to only focus on one approach for solving vision language tasks, which is to extend pre-trained generalized language models to be capable of consuming visual signals.\nI roughly group such vision language models (VLMs) into four buckets:\n Translating images into embedding features that can be jointly trained with token embeddings. Learning good image embeddings that can work as a prefix for a frozen, pre-trained language model. Using a specially designed cross-attention mechanism to fuse visual information into layers of the language model. Combine vision and language models without any training. Jointly Training with Image and Text One straightforward approach to fuse visual information into language models is to treat images as normal text tokens and train the model on a sequence of joint representations of both text and images. Precisely, images are divided into multiple smaller patches and each patch is treated as one \u0026ldquo;token\u0026rdquo; in the input sequence.\nVisualBERT (Li et al. 2019) feeds both text inputs and image regions into BERT such that it is able to discover the internal alignment between images and text with self-attention mechanism.\nFig. 1. VisualBERT is trained on the combination of both text and image embeddings. (Image source: Li et al. 2019) Similar to text embedding in BERT, each visual embedding in VisualBERT also sums up three types of embeddings, tokenized features $f_o$, segmentation embedding $f_s$ and position embedding $f_p$, precisely:\n $f_o$ is a visual feature vector computed for a bounding region of the image by a convolutional neural network; $f_s$ is a segment embedding to indicate whether the embedding is for vision not for text; $f_p$ is a position embedding used for aligning the order of bounding regions. The model is trained on MS COCO image caption dataset with both text and image as inputs to predict text captions, using two visually-grounded language model objectives:\n MLM with the image. The model needs to predict masked text tokens, while image embeddings always stay not masked. Sentence-image prediction. When provided with an image and two associated captions, one of two captions might be a random unrelated caption with 50% probability. The model is asked to distinguish these two situations. According to ablation experiments, the most important configuration is to fuse visual information early on into the transformer layers and to pretrain the model on the COCO caption dataset. Initialization from a pre-trained BERT and the adoption of the sentence-image prediction training objective have relatively small impacts.\nFig. 2. Ablation study results of VisualBERT on NLVR. (Image source: Li et al. 2019) VisualBERT outperforms SoTA at the time on NLVR and Flickr30K, but still has some performance gap with SoTA on VQA.\nSimVLM (Simple Visual Language Model; Wang et al. 2022) is a simple prefix language model, where the prefix sequence is processed with bi-directional attention like BERT, but the main input sequence only has causal attention like GPT. Images are encoded as prefix tokens such that the model can fully consume the visual information and then generates associated text in an autoregressive manner.\nInspired by ViT and CoAtNet, SimVLM splits the image into smaller patches in a flatten 1D sequence of patches. They use the convolutional stage consisting of the first 3 blocks of ResNet to extract contextualized patches and this setup is found to work better than a naive linear projection.\nFig. 3. Training architecture for SimVLM, where the image patches are processed by the cross-attention encoder and the text decoder has causal attention. (Image source: Wang et al. 2022) Training data for SimVLM consists of a large number of image-text pairs from ALIGN (Jia et al. 2021) and text-only data from C4 dataset (Raffel et al. 2019). They mix the two pretraining datasets within each batch, containing 4,096 image-text pairs (ALIGN) and 512 text-only documents (C4).\nAccording to ablation studies, it is important to have both image-text and text-only data for training. The PrefixLM objective outperforms both span corruption and naive LM.\nFig. 4. Ablation study results of SimVLM on VQA. (Image source: Wang et al. 2022) CM3 (Causally-Masked Multimodal Modeling; Aghajanyan, et al. 2022) is a hyper-text language model, learning to generate the content (hypertext markup, hyperlinks and images) of large scale HTML web pages of CC-NEWS and Wikipedia articles. The resulting CM3 models can generate rich structured, multi-modal outputs while conditioning on arbitrary masked document contexts.\nArchitecture-wise, CM3 is an autoregressive model. However, in order to combine causal and masked language modeling, CM3 also masks out a small number of long token spans and tries to generate them at the end of the sequences.\nFig. 5. Illustration of how a causally masked language model works. (Image source: Aghajanyan, et al. 2022) The training dataset for CM3 contains close to 1T Web data. During preprocessing, images are first downloaded from src and resized to 256 x 256 with random cropping. Then they are tokenized by VQVAE-GAN, resulting in 256 tokens per image. These tokens, joined with spaces, are inserted back into the src attribute.\nCM3 can be used to complete several types of tasks by prompt engineering:\n Image in-filling: Infilling Prompt: \u0026lt;img src=\u0026quot;{prefix}\u0026lt;mask:0\u0026gt;{postfix}\u0026quot;\u0026gt;\u0026lt;mask:0\u0026gt; Conditional image in-filling: Conditional Infilling Prompt: \u0026lt;img alt=\u0026quot;Photo: {text}\u0026quot; src=\u0026quot;{prefix}\u0026lt;mask:0\u0026gt;{postfix}\u0026quot;\u0026gt;\u0026lt;mask:0\u0026gt; Conditional image generation: Conditional Generation Prompt: \u0026lt;img alt=\u0026quot;{prompt} Image captions: Captioning Masked Prompt #1: \u0026lt;img alt=\u0026quot;Photo: A photo taken of\u0026lt;mask:0\u0026gt;\u0026quot; src=\u0026quot;{image}\u0026quot;\u0026gt; Captioning Causal Prompt #1: \u0026lt;img src=\u0026quot;{image}\u0026quot; title=\u0026quot;Photo: A photo taken of Entity disambiguation Original: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;Memphis, Egypt\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt; Prompt: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;\u0026lt;mask:0\u0026gt;\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt;...\u0026lt;mask:0\u0026gt; Target: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;\u0026lt;mask:0\u0026gt;\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt;...\u0026lt;mask:0\u0026gt; Memphis, Egypt Learned Image Embedding as (Frozen) LM Prefix What if we don’t want to change the language model parameters when adapting it to handle visual signals? Instead we learn such an embedding space for images that it is compatible with the language model’s.\nInspired by prefix or prompt tuning, both Frozen (Tsimpoukelli et al. 2021) and ClipCap (Mokady, Hertz \u0026amp; Hertz, 2021) only update the parameters of the vision module during training to produce image embeddings that can work with a pretrained, frozen language model. Both are trained with aligned image caption datasets to produce the next text token in caption conditioned on the image and previous text tokens. The powerful language capability is retained by freezing LM parameters. In addition, even though such setup is trained with limited image caption data, they can also rely on the encyclopedic knowledge of the language model at test time.\nThe vision encoder of Frozen is based on NF-ResNet-50 and uses the final output vector of the NF-Resnet after the global pooling layer. The Frozen VLM can be used as a multi-model few-shot learner to adapt to new tasks at test time for zero-shot or few-shot transfer with a sequence of interleaved images and text.\nFig. 6. Illustration of Frozen model (left) training architecture and (right) testing pipeline. (Image source: Tsimpoukelli et al. 2021) Experiments showed that fine-tuning the pre-trained LM interestingly leads to worse performance on VQA tasks. It is important to initialize the language model from a pre-trained version, as training from scratch (${Frozen}_\\text{scratch}$) does not show any meaningful progress. The baseline ${Frozen}_\\text{train-blind}$ blacks out the image but still can achieve decent performance because of the innate power of the pre-trained LM.\nFig. 7. Performance of different versions of Frozen on (left) VQAv2 and (right) OKVQA, trained on Conceptual Captions. \"Frozen scratch\" does not load a pre-trained LM and is trained from scratch. \"Frozen finetuned\" has the language model finetuned, while \"Frozen\" keeps LM frozen. \"Frozen train-blind\" blacks out the image. (Image source: Tsimpoukelli et al. 2021) ClipCap relies on CLIP (Radford et al. 2021) for vision encoding, but it needs to be processed by a light mapping network $F$ such that image embedding vectors are translated into the same semantic space as the pre-trained LM. The network $F$ maps CLIP embeddings into a sequence of $k$ embedding vectors, each with the same dimension as a word embedding in GPT2. Increasing the prefix size $k$ helps improve the performance. Both CLIP vision encoder and the LM are frozen during training and only the mapping network $F$ is learned. They found that when LM is frozen, $F$ should be a transformer, with 8 multi-head self-attention layers with 8 heads each, but when LM can be fine-tuned, a MLP is enough.\nEven though ClipCap only trains such a minimum set of parameters, it still achieves decent performance on image captioning tasks, comparable with SoTA at the time (e.g. Oscar, VLP, BUTD). Hence they postulate that \u0026ldquo;the CLIP space already encapsulates the required information, and adapting it towards specific styles does not contribute to flexibility.\u0026rdquo;\nFig. 8. Overview of ClipCap training pipeline where only the mapping network needs to be train to transform CLIP image embedding to work with the pre-trained LM. (Image source: Mokady, Hertz \u0026 Hertz, 2021) The fun fact is - because ClipCap translates CLIP image embeddings into LM space, the processed prefixes can be even interpreted as words.\nFig. 9. The learned image embedding can be interpreted as text, containing words related to the image context. (Image source: Mokady, Hertz \u0026 Hertz, 2021) Text-Image Cross-Attention Fuse Mechanisms To more efficiently fuse visual information into different layers of the language model, we can consider a specially designed cross-attention fuse mechanism to balance the mixture of text generation capacity and visual information.\nVisualGPT (Chen et al. 2021) employs a self-resurrecting encoder-decoder attention mechanism to quickly adapt the pre-trained LM with a small amount of in-domain image-text data.\nFig. 10. Illustration of VisualGPT architecture. (Image source: Chen et al. 2021) Let $I$ be the output of a visual encoder and $H$ be the hidden state of the LM decoder. VisualGPT introduced a self-resurrecting activation unit (SRAU) to control the tradeoff between a mixture of pre-trained linguistic information $H$ and visual component, $\\text{EncDecAttn}(H, I)$ via two complementary gates $B^\\text{vis}$ and $B^\\text{lan}$:\n$$ \\begin{aligned} \u0026amp; B^\\text{vis} \\otimes \\text{EncDecAttn}(H, I) + B^\\text{lan} \\otimes H \\\\ \\text{where } \u0026amp; B^\\text{vis}[i,j] = \\sigma(H[i,j]) \\mathbb{1}[\\sigma(H[i,j]) \u0026gt; \\tau] \\\\ \u0026amp; B^\\text{lan}[i,j] = (1 - \\sigma(H[i,j])) \\mathbb{1}[1 - \\sigma(H[i,j]) \u0026gt; \\tau] \\\\ \\end{aligned} $$ where $\\otimes$ is element-wise multiplication and $[i,j]$ denotes one element in the matrix. $\\tau$ is a predefined threshold hyperparameter.\nFig. 11. Comparison of different models trained on 0.1% and 1% of the MS COCO and Conceptual Caption datasets. (Image source: Chen et al. 2021) VC-GPT (Visual Conditioned GPT; Luo et al. 2022) combines a pretrained visual transformer (CLIP-ViT) as visual encoder and a pretrained LM as language decoder.\nFig. 12. Illustration of VC-GPT training framework. (Image source: Luo et al. 2022) The CLIP-ViT takes a sequence of image patches as inputs and outputs representation for each patch. To avoid catastrophic forgetting, instead of injecting the visual information directly into GPT2, VC-GPT introduces extra cross-attention layers on top of the output of visual encoder and language decoder. Then a self-ensemble module linearly combines the single model language decoder logits $h^G$ and cross-model vision-language fused module logits $h^\\text{fuse}$. The self-ensemble module (see \u0026ldquo;VC-GPT w/o SE\u0026rdquo; in Fig. 13) is important for the performance.\n$$ \\text{logits} = W^G h^G + W^\\text{fuse}h^\\text{fuse} $$\nwhere $W^G$ is a linear projection of the language decoder, initialized by the word embedding matrix of GPT2 and $W^\\text{fuse}$ is a linear projection of the fusion module and initialized randomly.\nFig. 13. Performance of VC-GPT on the MS COCO test set, in comparison with other end-to-end image captioning baseline models. Metric abbreviation: C = CIDEr; B = BLEU; M = METEOR; S = SPICE. (Image source: Luo et al. 2022) MERLOT (Zellers, et al. 2021) is trained with 6 millions of YouTube videos with transcribed speech (YT-Temporal-180M) to learn both spatial (frame-level) and temporal (video-level) objectives and demonstrated strong performance on VQA and visual reasoning tasks when fine-tuned.\nEach video $\\mathcal{V}$ is split into multiple segments $\\{ \\boldsymbol{s}_t \\}$, each segment $\\boldsymbol{s}_t$ containing an image frame $\\mathbf{I}_t$ extracted from the middle timestep and $L=32$ tokens of words associated. Images are encoded by a learned image encoder and words are encoded using a learned embedding. Then both are encoded together within a joint vision-language transformer.\nThere are 3 learning objectives in MERLOT:\n Masked language modeling (MLM) is useful especially because in videos, people tend to ramble, resulting in many repeated keywords or filler words. Contrastive frame-caption matching uses the language-only part from the joint vision-language transformer. Matched representations for each frame $\\mathbf{I}_t$ and caption $\\boldsymbol{w}_t$ are treated as positive examples, while the negative examples come from all other frame-caption pairs in the minibatch. Temporal reordering learns temporal reasoning: scramble random $i$ frames and replace the segment-level position embeddings with a random and unique position embedding. The random position embeddings are learned, allowing the model to unshuffle these \u0026ldquo;\u0026lsquo;shuffled\u0026rsquo;\u0026rdquo; frames conditioned on correctly-ordered ones. The loss is to predict whether $t_i \u0026lt; t_j$ or $t_j \u0026lt; t_i$ for each frame-frame pair. Fig. 14. Illustration of MERLOT training framework: (Left) contrastive frame-caption matching training; (Right) joint vision-language transformer is trained with MLM loss, as well as on the temporal reordering task to unshuffle scrambled video frames. (Image source: Zellers, et al. 2021) Ablation studies showed that it is important to (1) train on videos instead of images, (2) scale up the size and diversity of the training dataset and (3) use diverse objectives to encourage full-stack multimodal reasoning.\nFlamingo (Alayrac et al. 2022) is a visual language model that accepts text interleaved with images/videos and outputs free-form text. Flamingo connects a pretrained LM and a pretrained vision encoder (i.e. CLIP image encoder) via a transformer-based mapper. To more efficiently incorporate vision signals, Flamingo adopts a Perceiver-based architecture to produce a few hundreds of tokens out of a large number of visual input features and then use cross-attention layers interleaved with the LM layers to fuse visual information into the language decoding process. The training objective is an autoregressive, NLL loss.\n The Perceiver resampler receives spatio-temporal features from the vision encoder of image/video inputs to produce fixed-size visual tokens. The frozen LM is equipped with newly initialized cross-attention layers interleaved between the pretrained LM layers. Thus the LM can generate text conditioned on the above visual tokens. Similar to ClipCap, both pretrained models are frozen during training and thus Flamingo is only trained to harmoniously connect existing, powerful language and vision models together. Tha main difference between ClipCap and Flamingo is that the former treats the image embedding as simple prefix for LM, while the latter uses the gated cross-attention-dense layer to fuse image information. In addition, Flamingo incorporates a lot more training data than ClipCap.\nFig. 15. Overview of the Flamingo model. (Image source: Alayrac et al. 2022) Fig. 16. The architecture illustration and pseudo code of the gated cross-attention-dense layer in Flamingo. (Image source: Alayrac et al. 2022) To easily handle text with interleaved images, masking in Flamingo is designed such that text token only cross-attends to visual tokens corresponding to the last preceding image, largely reducing the number of visual tokens that a certain text token can see. They found this works better than allowing text tokens to attend to all preceding images directly. Text still can attend to all previous images because there is a causal self-attention dependency in the text encoder. This design can deal with an arbitrary number of images in the context.\nThey scraped 43 million webpages from the Internet, named MultiModal MassiveWeb (M3W) dataset, containing text with interleaved images. In addition, Flamingo is also trained on paired image/text and video/text datasets, including ALIGN, LTIP and VTP.\nData processing of the Internet dataset includes:\n The input Web page text is processed by inserting \u0026lt;image\u0026gt; tags at the location of visual inputs, as well as special tokens, \u0026lt;BOS\u0026gt; (beginning of sentence) and \u0026lt;EOC\u0026gt; (end of chunks; always at the end of the document, before any image tag). From each document, they sample a random subsequence of $L = 256$ tokens and take up to $N = 5$ images included in the sampled sequence (using only the first $N$ within that sampled subsequence if there are more, or padding to $N$ if fewer) A function $\\phi: [1,L] \\to [0,N]$ is computed to track the text and image interleaving order, which assigns to each text position the index of the last image/video appearing before this position; 0 if no preceding visual data. Since Flamingo is trained on a mixture of three different datasets, it optimizes for a weighted sum of dataset-specific NLL losses. Tuning the dataset weights is very important for the final performance. In practice, instead of round-robin between datasets, they actually sample one batch from each dataset and apply a weighted sum of these gradients in each update. Gradient accumulation across different heterogeneous datasets can be viewed as a mean to stabilize training, as it reduces the gradient variance between each update.\nAt test time, Flamingo naturally supports few-shot learning since it can work with any sequence of interleaved text and images. And more examples in the context contribute to better performance.\nFig. 17. Larger model sizes and more few-shot examples lead to better performance. (Image source: Alayrac et al. 2022) Flamingo outperforms SoTA fine-tuned models on 6 out of the 16 tasks despite even when not using any fine-tuning but only few-shot prompting. Fine-tuning Flamingo is expensive and it is difficult to do hyperparemeter tuning, but it does lead to better results.\nFig. 18. Performance of Flamingo model using different numbers of shots and of different sizes, in comparison with SoTA fine-tuned baseline. (Image source: Alayrac et al. 2022) CoCa (Contrastive Captioner; Yu \u0026amp; Wang et al., 2022) captures both the merits of contrastive learning and image-to-caption generation. It is a model jointly trained with contrastive loss on CLIP-style representation and generative loss on image captioning, achieving SoTA zero-shot transfer on a variety of multi-modal evaluation tasks.\nFig. 19. Overview of CoCa training framework. (Image source: Yu \u0026 Wang et al., 2022) CoCa is pretrained from scratch, using web-scale alt-text data ALIGN and annotated images by treating all labels as texts in JTB-3B.\nThere are two major training components in CoCa. The final loss is a weighted sum of the following two losses, with weight scalars $\\lambda_\\text{cap}=2.0, \\lambda_\\text{con} = 1.0$.:\n $\\mathcal{L}_\\text{con}$ - Dual-encoder contrastive learning optimizes the symmetric contrastive learning loss, similar to CLIP. $\\mathcal{L}_\\text{cap}$ - Encoder-decoder captioning has the decoder predict the caption based on the latent encoded features from the image encoder, by optimizing an autoregressive loss. The text decoder is decoupled into two components, unimodal and multimodal; a good balance is to split the decoder by half for these two components: The bottom unimodal component encodes the input text with causally-masked self-attention. The top multimodal component applies both causally-masked self-attention and cross-attention to the output of the vision encoder. CoCa performs better than the contrastive-only model and on par with the captioning-only model on VQA. Captioning loss is found to be beneficial to the zero-shot classification capacity too.\nFig. 20. Illustration of how CoCa can be used to solve various downstream tasks at test time. (Image source: Yu \u0026 Wang et al., 2022) They use task-specific attention pooling, or attention pooler, as a natural task adapter, as they found that a single pooled image embedding helps visual recognition tasks (e.g. ImageNet classification), while a more fine-grained embedding helps multimodal understanding tasks (e.g. VQA). A pooler is a single multi-head attention layer with $n_\\text{query}$ learnable queries (note that $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$, $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_q}$, and $d_k = d_q$), with the encoder output as both keys and values. CoCa uses attentional poolers in pretraining for generative loss $n_\\text{query} = 256$ and contrastive loss $n_\\text{query} = 1$. This enables the model to obtain strong performance as a frozen encoder where we only learn a new pooler to aggregate features.\nFig. 21. Pseudo code for CoCa architecture and training. (Image source: Yu \u0026 Wang et al., 2022) No Training Finally it is possible to solve vision language tasks by stitching pretrained language and vision models together without training any additional parameters.\nDecoding Guided with Vision-based Scores MAGiC (iMAge-Guided text generatIon with CLIP; Su et al. 2022) does guided decoding according to a CLIP-based score named magic score to sample the next token, without fine-tuning. The generated text is encouraged to be relevant to the given image, while still stay coherent to the previously generated text.\nThe next token $x_t$ at a time step $t$ is selected according to the following equation. Model confidence and degeneration penalty (Su et al. 2022) are added to avoid corrupted generation from LM.\n$$ \\begin{aligned} \u0026amp; x_t = \\arg\\max_{v \\in \\mathcal{V}^{(k)}} \\big\\{ (1-\\alpha) \\underbrace{p(v \\vert \\boldsymbol{x}_{\u0026lt;t})}_\\text{model confidence} - \\alpha \\underbrace{\\max_{1 \\leq j \\leq t-1} { \\text{cosine}(h_v, h_{x_j})}}_\\text{degeneration penalty} + \\beta \\underbrace{f_\\text{magic}(v \\vert \\mathcal{I}, \\boldsymbol{x}_{\u0026lt;t}, \\mathcal{V}^{(k)})}_\\text{magic score} \\big\\} \\\\ \\text{where } \u0026amp; f_\\text{magic} ( v \\vert \\mathcal{I}, \\mathbf{x}_{\u0026lt;t}, \\mathcal{V}^{(k)} ) = \\frac{ \\exp(\\text{CLIP}(\\mathcal{I}, [\\boldsymbol{x}_{\u0026lt;t}:v])) }{ \\sum_{z \\in \\mathcal{V}^{(k)}} \\exp(\\text{CLIP}(\\mathcal{I}, [\\boldsymbol{x}_{\u0026lt;t}:z])) } = \\frac{ \\exp\\big({h^\\text{image}(\\mathcal{I})}^\\top h^\\text{text}([\\boldsymbol{x}_{\u0026lt;t}:v])\\big) }{ \\sum_{z \\in \\mathcal{V}^{(k)}} \\exp\\big({h^\\text{image}(\\mathcal{I})}^\\top h^\\text{text}([\\boldsymbol{x}_{\u0026lt;t}:z])\\big) } \\end{aligned} $$\nwhere $\\mathcal{I}$ is the input image; $\\mathcal{V}^{(k)}$ contains top-$k$ possible tokens predicted by the language model $p$; $\\boldsymbol{x}_{\u0026lt;t}$ refers to the past generated tokens before time step $t$; $h_v$ is the representation of the token $v$ computed by LM conditioned on the concatenation of $\\boldsymbol{x}_{\u0026lt;t}$ and $v$; $h^\\text{image}(.)$ and $h^\\text{text}(.)$ are embeddings generated by CLIP image and text encoders, respectively.\nMAGiC has decent performance compared to other unsupervised approaches, but still has big gaps with supervised methods.\nFig. 22. Image captioning performance on COCO and Flickr30k. (Image source: Su et al. 2022) Language as Communication Interface For knowledge-based VQA tasks, PICa (Prompts GPT-3 via the use of Image Captions; Yang et al. 2021) first converts the images into captions or tags and then uses few-shot examples to prompt GPT3 to provide answers. Image captions or tags are extracted by some existing models (e.g. VinVL) or Azure Tagging API. And GPT3 is considered as an unstructured, implicit knowledge base.\nFig. 23. How PICa works for $n$-shot VQA at inference time. (Image source: Yang et al. 2021) PICa explored two ways to improve few-shot examples to achieve better results:\n In-context examples are selected based on how similar they are to the question using CLIP embedding. Multi-query ensembling is to prompt the model multiple times to get multiple answers and the one with highest logprob is selected. This simple approach with only 16 examples improved SoTA on OK-VQA by +8.6 points and got decent performance on VQAv2.\nFig. 24. Performance of PICa on OK-VQA. \"PICa-Base\" has random in-context examples, while \"PICa-Full\" incorporates both similar in-context example selection and multi-query ensembling. (Image source: Yang et al. 2021) Socratic Models (SM) (Zeng et al. 2022) is a framework to compose multiple pretrained models for different modality via language (prompting) into one model without further training. Here language is considered as the intermediate representation by which different models can exchange information. The key idea is to use multi-model multimodal prompting, in which output of a non-language model is inserted into a language prompt and then it is used for LM for reasoning.\nLet’s examine a concrete example. Given an ego-centric video (images + audio), SM can produce a summary of the person’s activity using text-to-text LM, image-to-text VLM and speech-to-text ALM. They are chained as follows:\n(Image source: Zeng et al. 2022) the VLM detects visual entities; the LM suggests sounds that may be heard; the ALM chooses the most likely sound; the LM suggests possible activities; the VLM ranks the most likely activity; the LM generates a summary of the Socratic interaction. Fig. 25. Illustration of the Socratic Model solution for image captioning. (Image source: Zeng et al. 2022) SM can generate image captions by first using VLM to zero-shot predict different place categories, object categories, image type and the number of people; and then the VLM-filled language prompt is fed into a causal LM to generate caption candidates. The Socratic approach still has performance gap with ClipCap on image captioning but pretty decent given it does not involve any training.\nFig. 26. Comparison of image captioning performance of different models on random 100 COCO text examples. (Image source: Zeng et al. 2022) SM framework is very flexible and can be used on a lot more complicated tasks other than image captions. For example, the egocentric perception (User inputs + VLM + LM + ALM) task is to take as inputs egocentric videos to (1) summarize content; (2) answer free-form reasoning questions; (3) and do forecasting.\nFig. 27. The Socratic Model approach for generating captions and question answering based on the egocentric videos. (Image source: Zeng et al. 2022) Datasets Image Caption Datasets MS COCO (Chen et al. 2015): contains 328K images and each paired with 5 independent captions. NoCaps (Agrawal et al., 2019) is designed to measure generalization to unseen classes and concepts, where in-domain contains images portraying only COCO classes, near-domain contains both COCO and novel classes, and out-of-domain consists of only novel classes. Conceptual Captions (Sharma et al. 2018) contains 3 million pairs of images and captions, mined from the web and post-processed. To focus on the concepts, specific entities in this dataset are replaced with general notions (e.g. a politician’s name is replaced with \u0026ldquo;politician\u0026rdquo;) Crisscrossed Captions (CxC) (Parekh et al. 2021) contains 247,315 human-labeled annotations including positive and negative associations between image pairs, caption pairs and image-caption pairs. Concadia (Kreiss et al. 2021) is a Wikipedia-based dataset containing 96,918 images with corresponding English-language descriptions, captions, and surrounding context. Pair Image-Text Datasets (*) Not a public dataset.\n ALIGN (Jia et al., 2021) contains 1.8 billion images with alt-text. The dataset is large but noisy with only minimal frequency-based filtration. (*) LTIP (Long text \u0026amp; image pairs; Alayrac et al. 2022): 312 million images, paired with descriptive captions. (*) VTP (Video \u0026amp; text pairs; Alayrac et al. 2022): 27 million short videos (~22 seconds on average), paired with descriptive captions. (*) JFT-300M / JFT-3B are internal Google datasets, containing 300M / 3B images annotated with a class-hierarchy of around 30k labels via a semi-automatic pipeline. Thus the data and associated labels are noisy. Evaluation Tasks Visual Question-Answering Given an image and a question, the task is to correctly answer the question.\n VQAv2 (Goyal et al., 2017) contains 1+ million questions about 200K images from COCO. OK-VQA (Marino et al. 2019) contains 14K open-ended questions that require outside knowledge (e.g. from Wikipedia). A-OKVQA: the augmented successor of OK-VQA, with no overlapped questions with OK-VAQ. TextVQA (Singh, et al. 2019) contains 45,336 questions on 28,408 images that require reasoning about text to answer. VizWiz (Gurari, et al. 2018) contains over 31,000 visual questions originating from blind people who each took a picture using a mobile phone and recorded a spoken question about it, together with 10 crowdsourced answers per visual question. Visual Language Reasoning VCR (Visual Commonsense Reasoning; Zellers et al. 2018) contains 290k multiple choice QA questions derived from 110k movie scenes, with focus on visual commonsense. NLVR2 (Natural Language for Visual Reasoning; Suhr et al. 2019) contains 100k+ examples of sentences paired with web images and the task is to determine whether a natural language caption is true about a pair of images, with a focus on semantic diversity. Flickr30K (Jia et al. 2015) contains 30k images collected from Flickr and 250k annotations and the task is to select the bounding regions given spans of a sentence. SNLI-VE (Visual Entailment; Xie et al. 2019) is built on top of SNLI and Flickr30K and the task is to reason about the relationship between an image premise and a text hypothesis. Video QA and Understanding MSR-VTT (MSR Video to Text; Xu et al. 2016) contains 10K web video clips with 41.2 hours and 200K clip-sentence pairs in total; the task is to translate videos to text. ActivityNet-QA (Yu et al. 2019) contains 58,000 human-annotated QA pairs on 5,800 videos derived from the popular ActivityNet dataset. TGIF (Tumblr GIF; Li et al. .2016) contains 100K animated GIFs and 120K sentences describing visual content of the animated GIFs, randomly selected posts published between May and June of 2015 on Tumblr. TGIF-QA contains 165K QA pairs for the animated GIFs from the TGIF dataset. LSMDC (Large Scale Movie Description Challenge; Rohrbach et al. 2015) contains 118,081 short video clips extracted from 202 movies. Each video has a caption, either extracted from the movie script or from transcribed DVS (descriptive video services) for the visually impaired. TVQA (Lei et al. 2018) / TVQA+ (Lei et al. 2019) is a large-scale video QA dataset based on 6 popular TV shows (Friends, The Big Bang Theory, How I Met Your Mother, House M.D., Grey\u0026rsquo;s Anatomy, Castle). It consists of 152.5K QA pairs from 21.8K video clips, spanning over 460 hours of video. DramaQA (Choi et al. 2020) is a large-scale video QA dataset based on a Korean popular TV show, \u0026ldquo;Another Miss Oh\u0026rdquo;. This dataset contains four levels of QA on difficulty and multi-level character-centered story descriptions. VLEP (Video-and-Language Event Prediction; Lei et al. 2020) contains 28,726 future event prediction examples (along with their rationales) from 10,234 diverse TV Show and YouTube Lifestyle Vlog video clips. Citation Cited as:\n Weng, Lilian. (Jun 2022). Generalized visual language models. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-06-09-vlm/.\n Or\n@article{weng2022vlm, title = \u0026quot;Generalized Visual Language Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Jun\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-06-09-vlm/\u0026quot; } References [1] Li et al. \u0026ldquo;VisualBERT: A Simple and Performant Baseline for Vision and Language.\u0026quot; arXiv preprint:1908.03557 (2019).\n[2] Wang et al. \u0026ldquo;SimVLM: Simple Visual Language Model Pretraining with Weak Supervision.\u0026quot; ICLR 2022.\n[3] Aghajanyan, et al. \u0026ldquo;CM3: A Causal Masked Multimodal Model of the Internet.\u0026quot; arXiv preprint arXiv: 2201.07520 (2022).\n[4] Tsimpoukelli et al. \u0026ldquo;Multimodal Few-Shot Learning with Frozen Language Models.\u0026quot; NeuriPS 2021.\n[5] Mokady, Hertz \u0026amp; Hertz. \u0026ldquo;ClipCap: CLIP Prefix for Image Captioning.\u0026quot; 2021.\n[6] Chen et al. \u0026ldquo;VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning.\u0026quot; arXiv preprint arXiv:2111.09734 (2021).\n[7] Luo et al. \u0026ldquo;A Frustratingly Simple Approach for End-to-End Image Captioning.\u0026quot; arXiv preprint arXiv:2201.12723 (2022).\n[8] Zellers et al. \u0026ldquo;MERLOT: Multimodal neural script knowledge models.\u0026quot; NeuriPS 2021.\n[9] Alayrac et al. \u0026ldquo;Flamingo: a Visual Language Model for Few-Shot Learning.\u0026quot; arXiv preprint arXiv:2204.14198 (2022).\n[10] Yu \u0026amp; Wang et al. \u0026ldquo;CoCa: Contrastive Captioners are Image-Text Foundation Models.\u0026quot; arXiv preprint arXiv:2205.01917 (2022).\n[11] Yang et al. \u0026ldquo;An Empirical Study of GPT-3 for Few-Shot Knowledge-Based VQA.\u0026quot; arXiv preprint arXiv:2109.05014 (2021).\n[12] Su et al. \u0026ldquo;Language models can see: Plugging visual controls in text generation.\u0026quot; arXiv preprint arXiv:2205.02655 (2022).\n[13] Zeng et al. \u0026ldquo;Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language.\u0026quot; arXiv preprint arXiv:2204.00598 (2022).\n","permalink":"https://lilianweng.github.io/posts/2022-06-09-vlm/","summary":"Processing images to generate text, such as image captioning and visual question-answering, has been studied for years. Traditionally such systems rely on an object detection network as a vision encoder to capture visual features and then produce text via a text decoder. Given a large amount of existing literature, in this post, I would like to only focus on one approach for solving vision language tasks, which is to extend pre-trained generalized language models to be capable of consuming visual signals.","title":"Generalized Visual Language Models"},{"content":"Here comes the Part 3 on learning with not enough data (Previous: Part 1 and Part 2). Let’s consider two approaches for generating synthetic data for training.\n Augmented data. Given a set of existing training samples, we can apply a variety of augmentation, distortion and transformation to derive new data points without losing the key attributes. We have covered a bunch of augmentation methods on text and images in a previous post on contrastive learning. For the sake of post completeness, I duplicate the section on data augmentation here with some edits. New data. Given few or even no data points, we can rely on powerful pretrained models to generate a number of new data points. This is especially true in recent years given the fast progress in large pretrained language models (LM). Few shot prompting is shown to be effective for LM to learn within context without extra training. Data Augmentation The goal of data augmentation is to modify the input format (e.g. text wording, visual appearance) while the semantic meaning stays unchanged.\nImage Augmentation Basic Image Processing Operations There are several ways to modify an image while retaining its semantic information. We can use any one of the following augmentation or a composition of multiple operations.\n Random cropping and then resize back to the original size. Random color distortions Random Gaussian blur Random color jittering Random horizontal flip Random grayscale conversion And many more. Check PIL.ImageOps for inspiration. Task-Specific Augmentation Strategies If the downstream task is known, it is possible to learn the optimal augmentation strategies (i.e. what processing operations to use and how to combine them in sequence) to maximize the downstream task performance.\n AutoAugment (Cubuk, et al. 2018) is inspired by neural architecture search, AutoAugment frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem and looks for the combination that leads to the highest accuracy on the evaluation set. AutoAugment can be executed in adversarial fashion (Zhang, et al 2019). RandAugment (Cubuk et al., 2019) greatly reduces the search space of AutoAugment by controlling the magnitudes of different transformation operations with a single magnitude parameter. Population based augmentation (PBA; Ho et al., 2019) combines PBT (\u0026ldquo;population based training\u0026rdquo;; Jaderberg et al, 2017) with AutoAugment, using the evolutionary algorithm to train a population of children models in parallel to evolve the best augmentation strategies. Unsupervised Data Augmentation (UDA; Xie et al., 2019), among a set of possible augmentation strategies, selects a subset to minimize the KL divergence between the predicted distribution over an unlabelled example and its unlabelled augmented version. Image Mixture Image mixture methods can construct new training examples from existing data points.\n Mixup (Zhang et al., 2018) runs global-level mixture by creating a weighted pixel-wise combination of two existing images $I_1$ and $I_2$: $I_\\text{mixup} \\gets \\alpha I_1 + (1-\\alpha) I_2$ and $\\alpha \\in [0, 1]$. Cutmix (Yun et al., 2019) does region-level mixture by generating a new example by combining a local region of one image with the rest of the other image. $I_\\text{cutmix} \\gets \\mathbf{M}_b \\odot I_1 + (1-\\mathbf{M}_b) \\odot I_2$, where $\\mathbf{M}_b \\in \\{0, 1\\}^I$ is a binary mask and $\\odot$ is element-wise multiplication. It is equivalent to filling the cutout (DeVries \u0026amp; Taylor 2017) region with the same region from another image. Given a query $\\mathbf{q}$, MoCHi (\u0026ldquo;mixing of contrastive hard negatives\u0026rdquo;; Kalantidis et al. 2020) maintains a queue of $K$ negative features $Q={\\mathbf{n}_1, \\dots, \\mathbf{n}_K }$ and sorts these negative features by similarity to the query, $\\mathbf{q}^\\top \\mathbf{n}$, in descending order. The first $N$ items in the queue are considered as the hardest negatives, $Q^N$. Then synthetic hard examples can be generated by $\\mathbf{h} = \\tilde{\\mathbf{h}} / |\\tilde{\\mathbf{h}}|_2$ where $\\tilde{\\mathbf{h}} = \\alpha\\mathbf{n}_i + (1-\\alpha) \\mathbf{n}_j$ and $\\alpha \\in (0, 1)$. Even harder examples can be created by mixing with the query feature, $\\mathbf{h}' = \\tilde{\\mathbf{h}'} / |\\tilde{\\mathbf{h}'}|_2$ where $\\tilde{\\mathbf{h}'} = \\beta\\mathbf{q} + (1-\\beta) \\mathbf{n}_j$ and $\\beta \\in (0, 0.5)$. Text Augmentation Lexical Edits Easy Data Augmentation (EDA; Wei \u0026amp; Zou 2019) defines a set of simple but powerful operations for text augmentation. Given a sentence, EDA randomly chooses and applies one of four simple operations:\n Synonym replacement (SR): Replace $n$ random non-stop words with their synonyms. Random insertion (RI): Place a random synonym of a randomly selected non-stop word in the sentence at a random position. Random swap (RS): Randomly swap two words and repeat $n$ times. Random deletion (RD): Randomly delete each word in the sentence with probability $p$. where $p=\\alpha$ and $n=\\alpha \\times \\text{sentence_length}$, with the intuition that longer sentences can absorb more noise while maintaining the original label. The hyperparameter $\\alpha$ roughly indicates the percent of words in one sentence that may be changed by one augmentation.\nEDA is shown to improve the classification accuracy on several classification benchmark datasets compared to baseline without EDA. The performance lift is more significant on a smaller training set. All the four operations in EDA help improve the classification accuracy, but get to optimal at different $\\alpha$\u0026rsquo;s.\nFig. 1. EDA leads to performance improvement on several classification benchmarks. (Image source: Wei \u0026 Zou 2019) Contextual Augmentation (Kobayashi, 2018) replaces word $w_i$ at position $i$ by sampling from a probability distribution learned by a bidirectional LM such as BERT, $p(.\\mid S\\setminus{w_i})$. In this way, the words are substituted by synonyms, or similar words suitable for the context. To guarantee such operations do not alter the labels, the LM is fit to be label-conditioned bidirectional LM. Conditional BERT (CBERT; Xing Wu et al. 2018) extends BERT to predict masked tokens conditioned on the class label and can be used for contextual augmentation prediction.\nBack-translation Back-translation produces augmented data by translating text samples to another language and then translating them back. The translation happens in two ways and both directions should have decent enough performance to avoid significant loss of semantic meaning.\nMix-up It is also possible to apply Mixup to text (Guo et al. 2019) but on the embedding space to obtain some performance gain. The proposed method relies on a specially designed model architecture to operate the prediction on the word or sentence embedding. Adding adversarial noise in the embedding space as a way of data augmentation is shown to improve the generalization of model training (Zhu et al. 2019).\nAudio Augmentation Here is a list of several commonly used audio data augmentation methods, operated on raw audio or spectrograms, summarized by Wang \u0026amp; van den Oord (2021).\nAudio mixup. Given two audio clips $\\mathbf{x}_1$ and $\\mathbf{x}_2$, the mixed-up version $\\hat{\\mathbf{x}} = \\alpha \\mathbf{x}_1 + (1-\\alpha)\\mathbf{x}_2$ should be associated with the label of the more dominant input. The audio mixup augments the data with more realistic noise.\nTime masking. A small consecutive chunk of the audio can be masked without losing semantic information.\nFrequency masking. A small amount of frequency components on the spectrogram can be dropped off and it should not change the associated label.\nFrequency shift. The spectrogram can be shifted by an integer between $[-F, F]$, where $F$ is the maximum shift size. It is a cheap augmentation to change the pitch of the audio.\nArchitectural Augmentation Models with dropout layers can create augmented samples by applying different dropout masks on the same input sample. For example, in the contrastive learning model SimCSE (Guo et al. 2021), a sample is simply fed into the encoder twice with different dropout masks and these two versions are the positive pair where the other in-batch samples are considered as negative pairs.\nDropout augments data by adding noise onto the internal representation of the model. It can be applied in a more structured way, such as in cutoff (Shen et al. (2020)), where random chunks of the token embedding matrix are removed.\nData Synthesis Given that generating high-quality, photorealistic images is a lot more difficult than generating human-like natural language text and recent success with large pretrained language models, this section only focuses on text generation. To read more on how to synthesize realistic images, check posts on GAN, VAE, flow and diffusion models.\nLanguage Model as Noisy Annotator Wang et al. (2021) explored ways to leverage GPT-3 as a weak annotator via few-shot prompting, achieving 10x cheaper than human labeling. The paper argues that by using data labeled by GPT-3, it essentially performs self-training: The predictions on unlabeled samples apply entropy regularization on the model to avoid high class overlaps so as to help improve the model performance.\nFig. 2. Illustration of how to use GPT-3 to generate more training data with the human-in-the-loop active learning pipeline to improve the data quality. (Image source: Wang et al. 2021) GPT-3-labeled samples selected by active learning with highest uncertainty are sent to human labelers to be re-annotated. The few-shot prompt contains a small number of human labeled examples and thus the labeling cost is restricted. Synthetic samples are ranked by predicted logits of label $y$ and those with the lowest scores go through relabeling.\nGPT-3 labeling achieves better results in the low-cost regime, but has a gap with human labeling when enough money is spent on data collection. This implies the following inequation, although to what extent \u0026ldquo;a lot\u0026rdquo; or \u0026ldquo;noisy\u0026rdquo; means depends on the task details.\n A lot of high-quality data \u0026gt; A lot of noisy data \u0026gt; A little high quality data.\n Fig. 3. GPT-3 labeling technique improves the classification performance in the low-cost regime. (Image source: Wang et al. 2021) Language Model as Data Generator If enough training dataset for text classification tasks are available, we can fine-tune language models to synthesize more training samples conditioned on labels (Anaby-Tavor et al. 2019, Kumar et al. 2021).\nLanguage-model-based data augmentation (LAMBADA; Anaby-Tavor et al. 2019) takes such an idea, where the process involves fine-tuning both a classifier and a sample generation model.\n Train a baseline classifier using the existing training dataset: $h = \\mathcal{A}(\\mathcal{D}_\\text{train})$. Independently of step 1, a LM $\\mathcal{M}$ is fine-tuned on $\\mathcal{D}_{\\text{train}}$ to obtain $\\mathcal{M}_{\\text{tuned}}$. Synthesize a labeled dataset $\\mathcal{D}^*$ by generating the continuation of the sequence y[SEP] until EOS using $\\mathcal{M}_\\text{tuned}$. Filter synthesized dataset by, (1) Verifying that the predicted label is correct $h(x)=y$; (2) Selecting the top ranked samples when they are ranked by the classifier probability. $\\mathcal{D}_\\text{syn} \\subset \\mathcal{D}^*$. They generate 10x more samples needed for augmentation and only the top 10% synthesized samples with highest confidence scores remain. The final classifier is trained on $\\mathcal{D}_\\text{syn} \\cup \\mathcal{D}_\\text{train}$ . The process can be repeated multiple times, but it is unclear whether the benefit would quickly diminish or the repetitive process would bring in self-bias.\nFig. 4. Accuracy of LAMBADA vs. other generative approaches over all datasets and classifiers. (Image source: Anaby-Tavor et al. 2019) To simplify LAMBADA, we can actually remove the dependency of a fine-tuned generation model and an existing training dataset of a decent size (Step 2 above). Unsupervised data generation (UDG; Wang et al. 2021) relies on few-shot prompting on a large pretrained language model to generate high-quality synthetic data for training. Opposite to the above approach where LM is asked to predict $y$ given $\\mathbf{x}$, UDG instead synthetizes the inputs $\\mathbf{x}$ given labels $y$. Then a task-specific model is trained on this synthetic dataset.\nSchick \u0026amp; Schutze (2021) proposed a similar idea but on the NLI task instead of classification, asking PLM to write sentence pairs that are similar or different while the model is prompted with task-specific instructions.\nFig. 5. Illustration of the unsupervised data generation (UDG) framework. (Image source: Wang et al., 2021) The few-shot prompts of UDG contain a small number of unlabeled examples, as well as a task-specific natural language description of the desired label. Because some generated examples are noisy, they implemented noisy label annealing (NLA) techniques to filter potentially misaligned samples out during the training processes. NLA gradually removes noisy training signals in time during training when the model starts to disagree with its pseudo label with high confidence. At each training step $t$, a given example $(\\mathbf{x}_i, \\hat{y}_i)$ is considered noisy and should be removed if:\n The model predicted probability is higher than a threshold $p(\\bar{y}_i \\vert \\mathbf{x}_i) \u0026gt; \\mu_t$ where $\\bar{y}_i = \\arg\\max_y p(y \\vert \\mathbf{x}_i)$; And the predicted label is different from the synthetic label, $\\bar{y}_i \\neq \\hat{y}_i$. Note that the threshold $\\mu_t$ is time-dependent, initialized as 0.9 and then gradually annealed to $1/\\text{num_of_classes}$ in time.\nAs shown in their experiments, the improvement of UDG over few-shot inference is quit significant, where NLA brings in some extra boost. The results are even comparable with supervised fine-tuning on several cases.\nFig. 6. Comparison of accuracy of UDG and other methods on different classification datasets. (Image source: Wang et al., 2021) Han et al (2021) achieved SOTA results on translation tasks using few-shot data generation, distillation and back-translation. The proposed method contains the following steps, assuming no access to paired translation data:\n Zero-shot Generation. First use the zero-shot translation ability of a pre-trained LM to generate translations for a small set of unlabeled sentences. Few-shot Generation. Then amplify these zero-shot translations by using them as few-shot demonstrations to gather an even larger synthetic dataset. Distillation. Fine-tune the model on this dataset. The translation task is formulated as a language modeling task [L1] \u0026lt;seq1\u0026gt; [[TRANSLATE]] [L2] \u0026lt;seq2\u0026gt;. given a pair of two sequences \u0026lt;seq1, seq2\u0026gt; in two different languages. At test-time, the LM is prompted with [L1] \u0026lt;seq\u0026gt; [[TRANSLATE]] [L2] and a candidate translation \u0026lt;sampledSeq\u0026gt; is parsed from the sampled completion. Back-translation. Continue fine-tuning on the back-translation dataset where the order of samples is reversed, \u0026lt;sampledSeq, seq\u0026gt;. Step 1-4 can be repeated. Fig. 7. Algorithm of using distillation and back-translation to train a language model on translation tasks. (Image source: Han et al. 2021) The success of the above method depends on a good pretrained LM to kick off the initial translation dataset. Iterative few-shot generation and distillation with back-translation is an effective way to extract and refine the translation capability out of a pretrained LM and further to distill that into a new model.\nFig. 8. Comparison of BLEU scores of the translation models of different training runs using: only distillation, back-translation, both and with more monolingual training data. (Image source: Han et al. 2021) How to Quantify Generated Data Quality? Given all the generated data, either by data augmentation or data synthesis, how can we quantify data quality in terms of how they improve model generalization? Gontijo-Lopes et al. (2020) introduced two dimensions to track, affinity and diversity.\n Affinity is a model-sensitive metric for distribution shift, quantifying how much an augmentation shifts the training data distribution from what a model learned. Definition: The performance difference between the model tested on clean data vs augmented data, while the model is trained on clean data. As a comparison, KL can also measure distribution shift but does not consider the model performance. Diversity is a measure of augmentation complexity, measuring the complexity of the augmented data with respect to the model and learning procedure. Definition: The final training loss of a model trained with a given augmentation. Another potential diversity measure is the entropy of the transformed data. A third potential diversity measure is the training time needed for a model to reach a given training accuracy threshold. All three metrics above are correlated. The final model performance is dependent on both metrics to be high enough.\nFig. 9. (a) Left: A scatter plot of affinity vs diversity metric, where each point represents a different augmentation method and its color indicates the final test accuracy. (b) Right: The conceptual illustration of the relationship between clean and augmented data in different regions of affinity and diversity metrics. (Image source: Gontijo-Lopes et al. 2020) There are many quantitative metrics on relevancy and diversity, in different formations depending on whether a reference is available, such as perplexity, BLEU for text and inception score for images. I\u0026rsquo;m skipping the list of concrete quantitative metrics on quality here, given it could be very long.\nTraining with Noisy Data It is convenient to collect a large amount of noisy data via model generation or data augmentation, but it is hard to guarantee that augmented and generated data can be 100% accurate. Knowing that deep neural networks can easily overfit noisy labels and \u0026ldquo;memotize\u0026rdquo; corrupted labels, we can apply the techniques for training on noisy labels (noise-robust training) when using generated data to stabilize and optimize the performance. Please check this survey paper (Song et al. 2021) on learning from noisy labels for a more thorough coverage of related work.\nRegularization and Robust Architecture Generally speaking, mechanisms designed for avoiding overfitting should help improve training robustness when working with moderately noisy data, such as weight decay, dropout, batch normalization. In fact, good data augmentation (i.e. only non-essential attributes are modified) can be considered as a way of regularization as well.\nA different approach is to enhance the network with a dedicated noisy adaptation layer to approximate the unknown projection of label corruption (Sukhbaatar et al. 2015, Goldberger \u0026amp; Ben-Reuven, 2017).\nSukhbaatar et al. (2015) introduced an extra linear layer $Q$ into the network architecture to adapt the predictions to match the noisy label distribution. The noise matrix $Q$ is initially fixed to the identity function while only the base model parameters is updated. After some time, $Q$ starts to be updated and expected to capture the noise in the data. The noise matrix is trained with regularization to encourage it to match the noise distribution while keeping the base model prediction accurate for true labels.\nFig. 10. (a) Left: A noise matrix $Q$ is added between softmax and the final output for the loss. (b) Right: The noise matrix $Q$ is fixed at the identity function initially and only gets updated with regularization after some training. (Image source: Sukhbaatar et al. 2015) However, it is hard to guarantee such a noise matrix layer would only capture the noise transition distribution and it is actually non-trivial to learn. Goldberger \u0026amp; Ben-Reuven (2017)) proposed to add an additional softmax layer end-to-end with the base model and apply the EM algorithm by treating the correct labels as latent random variable and the noise processes as a communication channel with unknown parameters.\nRobust Learning Objective Besides the most commonly used cross entropy loss, some other choices of learning objectives are shown to be more robust to noisy labels.\nFor example, MAE (mean absolute error) is more robust to noisy labels than CCE (categorical cross entropy), as it treats every sample equally (Ghosh et al. 2017). Lack of different weighting among training samples of MAE lead to significantly longer training time. Motivated by the tradeoff between MAE and CCE, Zhang \u0026amp; Sabuncu (2018) proposed generalized cross entropy (GCE), a generalization of CCE loss to be robust to noisy data.\nTo exploit the benefits of both the noise-robustness provided by MAE and the implicit weighting scheme of CCE, GCE adopts the the negative Box-Cox transformation as a loss function:\n$$ \\mathcal{L}_q(f(\\mathbf{x}_i, y_i = j)) = \\frac{1 - f^{(j)}(\\mathbf{x}_i)^q}{q} $$\nwhere $f^{(j)}$ denotes the $j$-th element of $f(.)$ and $q \\in (0, 1]$. $\\mathcal{L}_q$ is equivalent to CCE when $q \\to 0$ and becomes MAE when $q=1$. Empirical experiments show that there exists a threshold of $q$ with which overfitting never emerges and the noisier the data the higher such a threshold should be.\nGiven true and predicted labels, $y_i, \\hat{y}_i \\in \\{0, 1\\}$ and let $u_i=y_i \\cdot \\hat{y}_i$, the zero-one loss, $\\mathcal{L}_{01}(\\mathbf{u}) = \\sum_{i=1}^n \\mathbb{1}[u_i \u0026lt; 0]$, is another learning subjective shown to be robust to noisy data. Minimizing the empirical risk with the zero-one loss is shown to be equivalent to minimizing the empirical adversarial (worse-case) risk (Hu et al 2018). Because the worst-case risk is the upper bound of the classification risk of the clean data distribution, minimizing the worst-case risk can lead to decreased true risk, which makes the zero-one loss especially robust. However, the zero-one loss is non-differentiable and cannot be optimized directly. One solution is to approximate an upper bound of the zero-one loss and to minimize the upper bound loss instead.\nThe hinge loss, $\\mathcal{L}_\\text{hinge}(\\mathbf{u}) = \\sum_{i=1}^n \\max(0, 1 - u_i)$, defines a rough upper bound of the zero-one loss. Lyu \u0026amp; Tsang (2020) proposed a curriculum loss (CL), which is a tighter upper bound compared to a conventional surrogate loss like the hinge loss, $\\mathcal{L}_\\text{01}(\\mathbf{u}) \\leq \\mathcal{L}_\\text{CL}(\\mathbf{u}) \\leq \\mathcal{L}_\\text{hinge}(\\mathbf{u})$.\n$$ \\mathcal{L}_\\text{CL}(\\mathbf{u}) = \\min_{\\mathbf{w}\\in\\{0,1\\}^n}\\max(\\sum_{i=1}^n w_i \\ell(u_i), n - \\sum_{i=1}^n w_i + \\sum_{i=1}^n\\mathbb{1}[u_i \u0026lt; 0]) $$\nwhere $\\ell(u_i)$ is a base surrogate loss for the zero-one loss (e.g. hinge loss) and the optimal weighting variable $\\mathbf{w}$ is to be learned.\nGiven a label corruption rate $\\rho$, the noise pruned curriculum loss (NPCL) is constructed based on the intuition that an ideal model should correctly classify $n(1-\\rho)$ samples with clean labels but misclassify $n\\rho$ corrupted labels. If $\\rho$ is a known prior, we would know how many samples (with largest losses) to be pruned. Assuming $\\ell(u_1) \\leq \\dots \\leq \\ell(u_n)$, then $u_{n(1-\\rho)+1} = \\dots = u_n =0$ and the following NPCL is the basic CL for only $n(1-\\rho)$ samples:\n$$ \\text{NPCL}(\\mathbf{u}) = \\min_{\\mathbf{w}\\in\\{0,1\\}^{n(1-\\rho)}} \\max(\\sum_{i=1}^{n(1-\\rho)} w_i \\ell(u_i), n(1-\\rho) - \\sum_{i=1}^{n(1-\\rho)} w_i) $$\nWhen experimenting on CIFAR-10, NPCL is comparable with GCE and performs better when the noise rate increases.\nLabel Correction Since it is known some labels are incorrect, noise-robust training can explicitly take the label correction into consideration.\nOne approach is to rely on the estimation of a noise transition matrix and use that to correct the forward or backward loss, named F-correction (Patrini et al. 2017). Let’s first assume that there are $k$ classes and the noise transition matrix $C \\in [0, 1]^{k\\times k}$ is observable and the label flipping probability does not depend on the sample input but only the label (i.e. known as random classification noise, RCN). Let $\\tilde{y}$ denote a corrupted label. Each entry of $C$ represents the probability of one label flipping to another1,\n$$ C_{ij} = p(\\tilde{y}= j \\vert y =i, \\mathbf{x}) \\approx p(\\tilde{y}= j \\vert y =i) $$\nThen we can proceed a forward label correction procedure to incorporate the prior knowledge of noisy transition matrix into the prediction.\n$$ \\begin{aligned} \\mathcal{L}(\\hat{p}(\\tilde{y}\\vert\\mathbf{x}), y) \u0026amp;= - \\log \\hat{p}(\\tilde{y}=i\\vert\\mathbf{x}) \\\\ \u0026amp;= - \\log \\sum_{j=1}^k p(\\tilde{y}=i\\vert y=j) \\hat{p}(y=j\\vert\\mathbf{x}) \\\\ \u0026amp;= - \\log \\sum_{j=1}^k C_{ji} \\hat{p}(y=j\\vert\\mathbf{x}) \\end{aligned} $$\nIn matrix form, we have $\\mathcal{L}(\\hat{p}(y \\vert \\mathbf{x})) = - \\log C^\\top \\hat{p}(y \\vert \\mathbf{x})$. However, such a noise transition matrix is usually unknown. If we have access to a clean dataset, the noise matrix $C$ can be estimated (Hendrycks et al. 2018) by calculating confusion matrix on the clean data. Let’s denote a clean trusted dataset as $\\mathcal{D}_c$ and a noisy dataset as $\\mathcal{D}_n$ going forward.\n$$ \\hat{C}_{ij} = \\frac{1}{\\vert \\mathcal{A}_i\\vert} \\sum_{\\mathbf{x} \\in \\mathcal{A}_i} \\hat{p}(\\tilde{y}=j \\vert y=i, \\mathbf{x}) \\approx p(\\tilde{y}=j \\vert y=i) $$\nwhere $\\mathcal{A}_i$ is a subset of data points from $\\mathcal{D}_c$ with label $i$.\nLet $f(x) = \\hat{p}(\\tilde{y} \\vert \\mathbf{x}; \\theta)$ and this model should be trained with $\\mathcal{L}(f(\\mathbf{x}), y)$ on clean data $\\mathcal{D}_c$ and with $\\mathcal{L}(\\hat{C}^\\top f(\\mathbf{x}), \\hat{y})$ on noisy data $\\mathcal{D}_n$.\nFig. 11. Algorithm of gold loss correction (GLC), estimating the noise transition matrix with a trusted dataset. (Image source: Hendrycks et al. 2018) If the trusted training dataset $\\mathcal{D}_c$ gets large, we can train a neural network only on clean data and distill its knowledge into the primary model (i.e. the final model to make predictions at test time) using corrected pseudo labels (Li et al. 2017). The primary model is trained on the entire dataset, $\\mathcal{D} = \\mathcal{D}_c \\cup \\mathcal{D}_n$. Optionally the \u0026ldquo;side\u0026rdquo; information of label relations in the knowledge graph, if available, can be incorporated into distillation to help the robustness of the predictions of the network that is trained on limited data.\nThe label correction distillation works as following:\n First train an auxiliary model $f_c$ from the small clean dataset $\\mathcal{D}_c$ to provide a soft label for each sample $x_i$, $s_i = \\delta(f_c(\\mathbf{x}_i)/T)$ is the sigmoid activation with temperature $T$. Because the clean dataset is not large, $f_c$ is likely to overfit, Li et al. (2017) turn to a knowledge graph $\\mathcal{G}$ that defines the relations in the label space and propagate the prediction among labels accordingly. The new soft label is donated as $\\hat{s}_i = \\mathcal{G}(s_i)$. The primary model $f$ is trained with predictions from $f_c$ to imitate, $$ \\mathcal{L}(y_i, f(\\mathbf{x}_i)) = \\text{CE}(\\underbrace{\\lambda y_i + (1 - \\lambda) \\hat{s}_i}_\\text{pseudo label}, f(\\mathbf{x}_i)) $$\nSample Reweighting and Selection Some samples may be more likely to have inaccurate labels than others. Such estimation gives us intuition on which samples should be weighted less or more in the loss function. However, considering two types of biases in training data, class imbalance and noisy labels, there is actually a contradictory preference \u0026mdash; We would prefer samples with larger loss to balance the label distribution but those with smaller loss for mitigating the potential noise. Some work (Ren et al. 2018) thus argue that in order to learn general forms of training data biases, it is necessary to have a small unbiased validation to guide training. The sample reweighting methods presented in this section all assume access to a small trusted set of clean data.\nConsidering a binary classification task with random classification noise, $y, \\hat{y} \\in \\{-1, +1\\}$, the label flipping probabilities, $\\rho_{-1}, \\rho_{+1} \\in [0, 0.5)$, are defined as:\n$$ \\rho_{-1} = P(\\tilde{y} = +1 \\vert y=-1)\\quad\\rho_{+1} = P(\\tilde{y}=-1 \\vert y =+1) $$\nLiu \u0026amp; Tao (2015) applies importance reweighting to adjust the weighted distribution of observed $\\hat{y}$ to match the distribution of unobservable $y$. Let $\\mathcal{D}$ be the true data distribution and $\\mathcal{D}_\\rho$ be the corrupted version.\n$$ \\begin{aligned} \\mathcal{L}_{\\ell,\\mathcal{D}}(f) \u0026amp;= \\mathbb{E}_{(\\mathbf{x},y)\\sim \\mathcal{D}}[\\ell(f(\\mathbf{x}), y)] \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} \\Big[ \\frac{P_\\mathcal{D}(\\mathbf{x}, y=\\tilde{y})}{P_{\\mathcal{D}_\\rho}(\\mathbf{x}, \\tilde{y})} \\ell(f(\\mathbf{x}), \\tilde{y}) \\Big] \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} \\Big[ \\frac{P_\\mathcal{D}(y=\\tilde{y} \\vert \\mathbf{x})}{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} \\ell(f(\\mathbf{x}), \\tilde{y}) \\Big] \u0026amp; \\text{; because }P_\\mathcal{D}(\\mathbf{x})=P_{\\mathcal{D}_\\rho}(\\mathbf{x}) \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} [ w(\\mathbf{x}, \\hat{y})\\ell(f(\\mathbf{x}), \\tilde{y}) ] = \\mathcal{L}_{w\\ell,\\mathcal{D}}(f) \\end{aligned} $$\nBecause,\n$$ \\begin{aligned} P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x}) \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y=\\tilde{y}) + P_\\mathcal{D}(y = - \\tilde{y} \\vert \\mathbf{x}) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y = - \\tilde{y}) \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) (1 - P_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert y=\\tilde{y})) + (1 - P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y = - \\tilde{y}) \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) (1 - \\rho_{\\tilde{y}}) + (1 - P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})) \\rho_{-\\tilde{y}} \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})(1 - \\rho_{\\tilde{y}} - \\rho_{-\\tilde{y}}) + \\rho_{-\\tilde{y}} \\end{aligned} $$\nThus the weight assigned to a noisy sample is,\n$$ w(x, \\tilde{y}) = \\frac{P_\\mathcal{D}(y=\\tilde{y} \\vert \\mathbf{x})}{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} = \\frac{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x}) - \\rho_{-\\tilde{y}}}{(1-\\rho_0-\\rho_1) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} $$\nwhere $P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})$ can be estimated using a simple logistic regression, but estimating the note rates is more challenging. Naive cross-validation can work out but is costly as the quality depends on the amount of trusted labels available. The paper approximates the upper bounds for noise rates first, $\\rho_\\tilde{y} \\leq P_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert \\mathbf{x})$ and then use a mild assumption to efficiently estimate them, $\\hat{\\rho}_{\\tilde{y}} = \\min_{\\mathbf{x} \\in {\\mathbf{x}_1, \\dots, \\mathbf{x}_n}} \\hat{P}_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert \\mathbf{x})$. In their experiments, the advantage of importance reweighting only varies across datasets and is more beneficial when the noise rates are high in general.\nSample reweighting schemes can be learned by a separate network. Learning to reweight (L2R; Ren et al. 2018) is a meta-learning approach to directly optimize the weights in pursuit of best validation performance on a known set of clean data. Each example gets assigned with the weight based on its gradient direction. The weighted loss to minimize $\\theta^*(\\mathbf{w})$ involves a set of training weights $\\{w_i\\}_{i=1}^n$ as unknown hyperparameters. These sample training weights $w_i$ are learned to minimize the loss on this unbiased validate set, $\\mathcal{D}_c = \\{x^\\text{valid}_j\\}_{j=1}^m$.\n$$ \\begin{aligned} \\theta^{*}(\\mathbf{w}) \u0026amp;= \\arg\\min_\\theta \\sum_{i=1}^n w_i f(x_i; \\theta) \\\\ \\text{where optimal }\\mathbf{w}^{*} \u0026amp;= \\arg\\min_{\\mathbf{w}, \\mathbf{w} \\geq \\mathbf{0}} \\frac{1}{m} \\sum_{j=1}^m f(\\mathbf{x}^\\text{valid}_j; \\theta^{*}(\\mathbf{w})) \\end{aligned} $$\nThe learning process involves two nested loops of optimization, so pretty expensive, 3x training time.\nFig. 12. Illustration of updates implemented by second order automatic differentiation. (Image source: Ren et al. 2018) They ran experiments on (1) two-class MNIST to test the robustness of L2R when the class distribution is imbalanced and (2) CIFAR-10 with noisy labels. L2R is shown to be better than other baseline methods at the time on both tasks.\nFig. 13. Left: Imbalanced classes on MNIST (class 4 and 9); Right: Effect of the number of clean samples. Task is on CIFAR-10 with 40% of data flipped to label 3. (Image source: Ren et al. 2018) MentorNet (Jiang et al. 2018) uses teach-student curriculum learning to weight data. It incorporates two different networks, a mentor and a student. The mentor network provides a data-driven curriculum (i.e. sample training weighting scheme) for the student to focus on learning likely correct labels.\nLet $g_\\psi$ be the MentorNet parameterized by $\\psi$ , $f_\\theta$ be the StudentNet parametrized by $\\theta$ and $G$ be a predefined curriculum parameterized by $\\lambda$. Given the training data $\\mathcal{D} = \\{(\\mathbf{x}_i, y_i)\\}_{i=1}^n$ for a $k$-class classification task, the MentorNet needs to predict a time-varying latent weight variable $\\mathbf{w} \\in [0, 1]^{n \\times k}$ to guide the learning of StudentNet, taking an intermediate feature processed by StudentNet $f$ , $\\mathbf{z}_i = \\phi_{f_\\theta}(\\mathbf{x}_i, y_i)$:\n$$ g_{\\psi^{*}}(\\mathbf{z}_i) = \\arg\\min_{w_i \\in [0,1]} \\mathcal{L}(\\theta, \\mathbf{w}), \\forall i \\in [1, n] $$\nStudentNet learns to minimize the following learning objective,\n$$ \\begin{aligned} \\mathcal{L}(\\theta, \\mathbf{w}) \u0026amp;= \\frac{1}{n}\\sum_{i=1}^n \\mathbf{w}_i^\\top \\ell(y_i, f_\\theta(\\mathbf{x}_i)) + G_\\lambda(\\mathbf{w}) + \\alpha |\\theta|^2_2 \\\\ \u0026amp;= \\frac{1}{n}\\sum_{i=1}^n g_\\psi(\\mathbf{z}_i)^\\top \\ell_i + G_\\lambda(\\mathbf{w}) + \\alpha |\\theta|^2_2 \u0026amp; \\text{; Let }\\ell_i = \\ell(y_i, f_\\theta(\\mathbf{x}_i)) \\\\ \\end{aligned} $$\nThe mentor network $g_\\psi$ is trained with cross entropy on the input $(\\phi_{f_\\theta}(\\mathbf{x}_i, y_i), w^{*}_i)$ , where $v^*_i=1$ if $y_i$ is known to be a correct label, otherwise 0. The architecture of MentorNet does not have to be very complicated. In the paper, they adopted a LSTM layer to capture the prediction variance in time.\nFig. 14. Model architecture of MentorNet and StudentNet which are trained simultaneously, where MentorNet predicts the sample weights for StudentNet to train on. (Image source: Jiang et al. 2018) Different from MentorNet where one network explicitly learns weighting scheme and curriculum for the other network, Co-teaching (Han et al. 2018) trains two neural networks, $f_1$ and $f_2$, simultaneously and lets them teach each other by feeding data to each other selectively. Co-teaching consists of three steps:\n First, each network feeds forward the current mini-batch and selects samples with potentially clean labels; Then two networks exchange information on which samples in the batch should be used for training. Small-loss instances are selected as they are more likely to be associated with correct labels. The percentage of the batch to select is determined by a time-dependent function $R(T)$. The value of $R(T)$ decreases in time because the network is more likely to overfit and memorize noisy labels as training progresses and thus we use a smaller sampling percentage to keep the selected data quality high. Finally, each network runs back-propagation updates with the data selected by its peer. According to their experiments, co-teaching performs better than F-correction where the noise rates are high or the corruption transition matrix is not symmetric.\nFig. 15. Algorithm of co-teaching in which two networks are trained separately in parallel and each selects samples for the other to train on. (Image source: Han et al. 2018) Citation Cited as:\n Weng, Lilian. (Apr 2022). Learning with not enough data part 3: data generation. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-04-15-data-gen/.\n Or\n@article{weng2022datagen, title = \u0026quot;Learning with not Enough Data Part 3: Data Generation\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Apr\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-04-15-data-gen/\u0026quot; } Reference [1] Zhang et al. \u0026ldquo;Adversarial AutoAgument\u0026rdquo; ICLR 2020.\n[2] Kumar et al. \u0026ldquo;Data Augmentation using Pre-trained Transformer Models.\u0026quot; AACL 2020 Workshop.\n[3] Anaby-Tavor et al. \u0026ldquo;Not enough data? Deep learning to rescue!\u0026quot; AAAI 2020.\n[4] Wang et al. \u0026ldquo;Want To Reduce Labeling Cost? GPT-3 Can Help.\u0026quot; EMNLP 2021.\n[5] Wang et al. \u0026ldquo;Towards Zero-Label Language Learning.\u0026quot; arXiv preprint arXiv:2109.09193 (2021).\n[6] Schick \u0026amp; Schutze. Generating Datasets with Pretrained Language Models.\u0026quot; EMNLP 2021.\n[7] Han et al. \u0026ldquo;Unsupervised Neural Machine Translation with Generative Language Models Only.\u0026quot; arXiv preprint arXiv:2110.05448 (2021).\n[8] Guo et al. \u0026ldquo;Augmenting data with mixup for sentence classification: An empirical study.\u0026quot; arXiv preprint arXiv:1905.08941 (2019).\n[9] Ekin D. Cubuk et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv preprint arXiv:1805.09501 (2018).\n[10] Daniel Ho et al. \u0026ldquo;Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules.\u0026quot; ICML 2019.\n[11] Cubuk \u0026amp; Zoph et al. \u0026ldquo;RandAugment: Practical automated data augmentation with a reduced search space.\u0026quot; arXiv preprint arXiv:1909.13719 (2019).\n[12] Zhang et al. \u0026ldquo;mixup: Beyond Empirical Risk Minimization.\u0026quot; ICLR 2017.\n[13] Yun et al. \u0026ldquo;CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features.\u0026quot; ICCV 2019.\n[14] Kalantidis et al. \u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo; NeuriPS 2020.\n[15] Wei \u0026amp; Zou. \u0026ldquo;EDA: Easy data augmentation techniques for boosting performance on text classification tasks.\u0026quot; EMNLP-IJCNLP 2019.\n[16] Kobayashi. \u0026ldquo;Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations.\u0026quot; NAACL 2018\n[17] Fang et al. \u0026ldquo;CERT: Contrastive self-supervised learning for language understanding.\u0026quot; arXiv preprint arXiv:2005.12766 (2020).\n[18] Gao et al. \u0026ldquo;SimCSE: Simple Contrastive Learning of Sentence Embeddings.\u0026quot; arXiv preprint arXiv:2104.08821 (2020). [code]\n[19] Shen et al. \u0026ldquo;A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation.\u0026quot; arXiv preprint arXiv:2009.13818 (2020) [code]\n[20] Wang \u0026amp; van den Oord. \u0026ldquo;Multi-Format Contrastive Learning of Audio Representations.\u0026quot; NeuriPS Workshop 2020.\n[21] Wu et al. \u0026ldquo;Conditional BERT Contextual Augmentation\u0026rdquo; arXiv preprint arXiv:1812.06705 (2018).\n[22 Zhu et al. \u0026ldquo;FreeLB: Enhanced Adversarial Training for Natural Language Understanding.\u0026quot; ICLR 2020.\n[23] Affinity and Diversity: Quantifying Mechanisms of Data Augmentation Gontijo-Lopes et al. 2020 (https://arxiv.org/abs/2002.08973)\n[24] Song et al. \u0026ldquo;Learning from Noisy Labels with Deep Neural Networks: A Survey.\u0026quot; TNNLS 2020.\n[25] Zhang \u0026amp; Sabuncu. \u0026ldquo;Generalized cross entropy loss for training deep neural networks with noisy labels.\u0026quot; NeuriPS 2018.\n[26] Goldberger \u0026amp; Ben-Reuven. \u0026ldquo;Training deep neural-networks using a noise adaptation layer.\u0026quot; ICLR 2017.\n[27] Sukhbaatar et al. \u0026ldquo;Training convolutional networks with noisy labels.\u0026quot; ICLR Workshop 2015.\n[28] Patrini et al. \u0026ldquo;Making Deep Neural Networks Robust to Label Noise: a Loss Correction Approach\u0026rdquo; CVPR 2017.\n[29] Hendrycks et al. \u0026ldquo;Using trusted data to train deep networks on labels corrupted by severe noise.\u0026quot; NeuriPS 2018.\n[30] Zhang \u0026amp; Sabuncu. \u0026ldquo;Generalized cross entropy loss for training deep neural networks with noisy labels.\u0026quot; NeuriPS 2018.\n[31] Lyu \u0026amp; Tsang. \u0026ldquo;Curriculum loss: Robust learning and generalization against label corruption.\u0026quot; ICLR 2020.\n[32] Han et al. \u0026ldquo;Co-teaching: Robust training of deep neural networks with extremely noisy labels.\u0026quot; NeuriPS 2018. (code)\n[33] Ren et al. \u0026ldquo;Learning to reweight examples for robust deep learning.\u0026quot; ICML 2018.\n[34] Jiang et al. \u0026ldquo;MentorNet: Learning data-driven curriculum for very deep neural networks on corrupted labels.\u0026quot; ICML 2018.\n[35] Li et al. \u0026ldquo;Learning from noisy labels with distillation.\u0026quot; ICCV 2017.\n[36] Liu \u0026amp; Tao. \u0026ldquo;Classification with noisy labels by importance reweighting.\u0026quot; TPAMI 2015.\n[37] Ghosh, et al. \u0026ldquo;Robust loss functions under label noise for deep neural networks.\u0026quot; AAAI 2017.\n[38] Hu et al. \u0026ldquo;Does Distributionally Robust Supervised Learning Give Robust Classifiers? \u0026ldquo; ICML 2018.\n $y=i$ is not a technically correct way to annotate a label being a certain value, since we usually use one-hot encoding (i.e. $\\mathbf{y} = \\mathbf{e}_i$). We use this form for simplicity.\u0026#160;\u0026#x21a9;\u0026#xfe0e;\n ","permalink":"https://lilianweng.github.io/posts/2022-04-15-data-gen/","summary":"Here comes the Part 3 on learning with not enough data (Previous: Part 1 and Part 2). Let’s consider two approaches for generating synthetic data for training.\n Augmented data. Given a set of existing training samples, we can apply a variety of augmentation, distortion and transformation to derive new data points without losing the key attributes. We have covered a bunch of augmentation methods on text and images in a previous post on contrastive learning.","title":"Learning with not Enough Data Part 3: Data Generation"},{"content":"This is part 2 of what to do when facing a limited amount of labeled data for supervised learning tasks. This time we will get some amount of human labeling work involved, but within a budget limit, and therefore we need to be smart when selecting which samples to label.\nNotations Symbol Meaning $K$ Number of unique class labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^K$ Labeled dataset. $y$ is a one-hot representation of the true label. $\\mathbf{u} \\sim \\mathcal{U}$ Unlabeled dataset. $\\mathcal{D} = \\mathcal{X} \\cup \\mathcal{U}$ The entire dataset, including both labeled and unlabeled examples. $\\mathbf{x}$ Any sample which can be either labeled or unlabeled. $\\mathbf{x}_i$ The $i$-th sample. $U(\\mathbf{x})$ Scoring function for active learning selection. $P_\\theta(y \\vert \\mathbf{x})$ A softmax classifier parameterized by $\\theta$. $\\hat{y} = \\arg\\max_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x})$ The most confident prediction by the classifier. $B$ Labeling budget (the maximum number of samples to label). $b$ Batch size. What is Active Learning? Given an unlabeled dataset $\\mathcal{U}$ and a fixed amount of labeling cost $B$, active learning aims to select a subset of $B$ examples from $\\mathcal{U}$ to be labeled such that they can result in maximized improvement in model performance. This is an effective way of learning especially when data labeling is difficult and costly, e.g. medical images. This classical survey paper in 2010 lists many key concepts. While some conventional approaches may not apply to deep learning, discussion in this post mainly focuses on deep neural models and training in batch mode.\nFig. 1. Illustration of a cyclic workflow of active learning, producing better models more efficiently by smartly choosing which samples to label. To simplify the discussion, we assume that the task is a $K$-class classification problem in all the following sections. The model with parameters $\\theta$ outputs a probability distribution over the label candidates, which may or may not be calibrated, $P_\\theta(y \\vert \\mathbf{x})$ and the most likely prediction is $\\hat{y} = \\arg\\max_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x})$.\nAcquisition Function The process of identifying the most valuable examples to label next is referred to as \u0026ldquo;sampling strategy\u0026rdquo; or \u0026ldquo;query strategy\u0026rdquo;. The scoring function in the sampling process is named \u0026ldquo;acquisition function\u0026rdquo;, denoted as $U(\\mathbf{x})$. Data points with higher scores are expected to produce higher value for model training if they get labeled.\nHere is a list of basic sampling strategies.\nUncertainty Sampling Uncertainty sampling selects examples for which the model produces most uncertain predictions. Given a single model, uncertainty can be estimated by the predicted probabilities, although one common complaint is that deep learning model predictions are often not calibrated and not correlated with true uncertainty well. In fact, deep learning models are often overconfident.\n Least confident score, also known as variation ratio: $U(\\mathbf{x}) = 1 - P_\\theta(\\hat{y} \\vert \\mathbf{x})$. Margin score: $U(\\mathbf{x}) = P_\\theta(\\hat{y}_1 \\vert \\mathbf{x}) - P_\\theta(\\hat{y}_2 \\vert \\mathbf{x})$, where $\\hat{y}_1$ and $\\hat{y}_2$ are the most likely and the second likely predicted labels. Entropy: $U(\\mathbf{x}) = \\mathcal{H}(P_\\theta(y \\vert \\mathbf{x})) = - \\sum_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x}) \\log P_\\theta(y \\vert \\mathbf{x})$. Another way to quantify uncertainty is to rely on a committee of expert models, known as Query-By-Committee (QBC). QBC measures uncertainty based on a pool of opinions and thus it is critical to keep a level of disagreement among committee members. Given $C$ models in the committee pool, each parameterized by $\\theta_1, \\dots, \\theta_C$.\n Voter entropy: $U(\\mathbf{x}) = \\mathcal{H}(\\frac{V(y)}{C})$, where $V(y)$ counts the number of votes from the committee on the label $y$. Consensus entropy: $U(\\mathbf{x}) = \\mathcal{H}(P_\\mathcal{C})$, where $P_\\mathcal{C}$ is the prediction averaging across the committee. KL divergence: $U(\\mathbf{x}) = \\frac{1}{C} \\sum_{c=1}^C D_\\text{KL} (P_{\\theta_c} | P_\\mathcal{C})$ Diversity Sampling Diversity sampling intend to find a collection of samples that can well represent the entire data distribution. Diversity is important because the model is expected to work well on any data in the wild, just not on a narrow subset. Selected samples should be representative of the underlying distribution. Common approaches often rely on quantifying the similarity between samples.\nExpected Model Change Expected model change refers to the impact that a sample brings onto the model training. The impact can be the influence on the model weights or the improvement over the training loss. A later section reviews several works on how to measure model impact triggered by selected data samples.\nHybrid Strategy Many methods above are not mutually exclusive. A hybrid sampling strategy values different attributes of data points, combining different sampling preferences into one. Often we want to select uncertain but also highly representative samples.\nDeep Acquisition Function Measuring Uncertainty The model uncertainty is commonly categorized into two buckets (Der Kiureghian \u0026amp; Ditlevsen 2009, Kendall \u0026amp; Gal 2017):\n Aleatoric uncertainty is introduced by noise in the data (e.g. sensor data, noise in the measurement process) and it can be input-dependent or input-independent. It is generally considered as irreducible since there is missing information about the ground truth. Epistemic uncertainty refers to the uncertainty within the model parameters and therefore we do not know whether the model can best explain the data. This type of uncertainty is theoretically reducible given more data Ensemble and Approximated Ensemble There is a long tradition in machine learning of using ensembles to improve model performance. When there is a significant diversity among models, ensembles are expected to yield better results. This ensemble theory is proved to be correct by many ML algorithms; for example, AdaBoost aggregates many weak learners to perform similar or even better than a single strong learner. Bootstrapping ensembles multiple trials of resampling to achieve more accurate estimation of metrics. Random forests or GBM is also a good example for the effectiveness of ensembling.\nTo get better uncertainty estimation, it is intuitive to aggregate a collection of independently trained models. However, it is expensive to train a single deep neural network model, let alone many of them. In reinforcement learning, Bootstrapped DQN (Osband, et al. 2016) is equipped with multiple value heads and relies on the uncertainty among an ensemble of Q value approximation to guide exploration in RL.\nIn active learning, a commoner approach is to use dropout to \u0026ldquo;simulate\u0026rdquo; a probabilistic Gaussian process (Gal \u0026amp; Ghahramani 2016). We thus ensemble multiple samples collected from the same model but with different dropout masks applied during the forward pass to estimate the model uncertainty (epistemic uncertainty). The process is named MC dropout (Monte Carlo dropout), where dropout is applied before every weight layer, is approved to be mathematically equivalent to an approximation to the probabilistic deep Gaussian process (Gal \u0026amp; Ghahramani 2016). This simple idea has been shown to be effective for classification with small datasets and widely adopted in scenarios when efficient model uncertainty estimation is needed.\nDBAL (Deep Bayesian active learning; Gal et al. 2017) approximates Bayesian neural networks with MC dropout such that it learns a distribution over model weights. In their experiment, MC dropout performed better than random baseline and mean standard deviation (Mean STD), similarly to variation ratios and entropy measurement.\nFig. 2. Active learning results of DBAL on MNIST. (Image source: Gal et al. 2017). Beluch et al. (2018) compared ensemble-based models with MC dropout and found that the combination of naive ensemble (i.e. train multiple models separately and independently) and variation ratio yields better calibrated predictions than others. However, naive ensembles are very expensive, so they explored a few alternative cheaper options:\n Snapshot ensemble: Use a cyclic learning rate schedule to train an implicit ensemble such that it converges to different local minima. Diversity encouraging ensemble (DEE): Use a base network trained for a small number of epochs as initialization for $n$ different networks, each trained with dropout to encourage diversity. Split head approach: One base model has multiple heads, each corresponding to one classifier. Unfortunately all the cheap implicit ensemble options above perform worse than naive ensembles. Considering the limit on computational resources, MC dropout is still a pretty good and economical choice. Naturally, people also try to combine ensemble and MC dropout (Pop \u0026amp; Fulop 2018) to get a bit of additional performance gain by stochastic ensemble.\nUncertainty in Parameter Space Bayes-by-backprop (Blundell et al. 2015) measures weight uncertainty in neural networks directly. The method maintains a probability distribution over the weights $\\mathbf{w}$, which is modeled as a variational distribution $q(\\mathbf{w} \\vert \\theta)$ since the true posterior $p(\\mathbf{w} \\vert \\mathcal{D})$ is not tractable directly. The loss is to minimize the KL divergence between $q(\\mathbf{w} \\vert \\theta)$ and $p(\\mathbf{w} \\vert \\mathcal{D})$,\n $$ \\begin{aligned} \\mathcal{L}(\\theta) \u0026= \\text{KL}[q(\\mathbf{w}\\vert\\theta) \\| p(\\mathbf{w} \\vert \\mathcal{D})] \\\\ \u0026= \\int q(\\mathbf{w}\\vert\\theta) \\log \\frac{q(\\mathbf{w}\\vert\\theta)}{p(\\mathbf{w}) p(\\mathcal{D}\\vert \\mathbf{w})} d\\mathbf{w} \\\\ \u0026= \\text{KL}[q(\\mathbf{w}\\vert\\theta) \\| p(w)] - \\mathbb{E}_{q(\\mathbf{w}\\vert\\theta)} [\\log p(\\mathcal{D} \\vert \\mathbf{w})] \\\\ \u0026\\approx \\log q(\\mathbf{w} \\vert \\theta) - \\log p(\\mathbf{w}) p(\\mathcal{D}\\vert \\mathbf{w}) \u0026 \\text{; monte carlo sampling; }q(\\mathbf{w} \\vert \\theta)\\text{ \u0026 }p(\\mathbf{w})\\text{ are close.} \\end{aligned} $$ The variational distribution $q$ is typically a Gaussian with diagonal covariance and each weight is sampled from $\\mathcal{N}(\\mu_i, \\sigma_i^2)$. To ensure non-negativity of $\\sigma_i$, it is further parameterized via softplus, $\\sigma_i = \\log(1 + \\exp(\\rho_i))$ where the variational parameters are $\\theta = \\{\\mu_i , \\rho_i\\}^d_{i=1}$.\nThe process of Bayes-by-backprop can be summarized as:\n Sample $\\epsilon \\sim \\mathcal{N}(0, I)$ Let $\\mathbf{w} = \\mu + \\log(1+ \\exp(\\rho)) \\circ \\epsilon$ Let $\\theta = (\\mu, \\rho)$ Let $f(\\mathbf{w}, \\theta) = \\log q(\\mathbf{w} \\vert \\theta) - \\log p(\\mathbf{w})p(\\mathcal{D}\\vert \\mathbf{w})$ Calculate the gradient of $f(\\mathbf{w}, \\theta)$ w.r.t. to $\\mu$ and $\\rho$ and then update $\\theta$. Uncertainty is measured by sampling different model weights during inference. Loss Prediction The loss objective guides model training. A low loss value indicates that a model can make good and accurate predictions. Yoo \u0026amp; Kweon (2019) designed a loss prediction module to predict the loss value for unlabeled inputs, as an estimation of how good a model prediction is on the given data. Data samples are selected if the loss prediction module makes uncertain predictions (high loss value) for them. The loss prediction module is a simple MLP with dropout, that takes several intermediate layer features as inputs and concatenates them after a global average pooling.\nFig. 3. Use the model with a loss prediction module to do active learning selection. (Image source: Yoo \u0026 Kweon 2019) Let $\\hat{l}$ be the output of the loss prediction module and $l$ be the true loss. When training the loss prediction module, a simple MSE loss $=(l - \\hat{l})^2$ is not a good choice, because the loss decreases in time as the model learns to behave better. A good learning objective should be independent of the scale changes of the target loss. They instead rely on the comparison of sample pairs. Within each batch of size $b$, there are $b/2$ pairs of samples $(\\mathbf{x}_i, \\mathbf{x}_j)$ and the loss prediction model is expected to correctly predict which sample has a larger loss.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loss}(\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\max\\big( 0, -\\mathbb{1}(l(\\mathbf{x}_i), l(\\mathbf{x}_j)) \\cdot (\\hat{l}(\\mathbf{x}_i) - \\hat{l}(\\mathbf{x}_j)) + \\epsilon \\big) \\\\ \\text{where } \\mathbb{1}(l_i, l_j) \u0026= \\begin{cases} +1 \u0026 \\text{if }l_i l_j \\\\ -1 \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ where $\\epsilon$ is a predefined positive margin constant.\nIn experiments on three vision tasks, active learning selection based on the loss prediction performs better than random baseline, entropy based acquisition and core-set.\nFig. 4. Active learning results of loss prediction module based selection, in comparison with other approaches. (Image source: Yoo \u0026 Kweon 2019) Adversarial Setup Sinha et al. (2019) proposed a GAN-like setup, named VAAL (Variational Adversarial Active Learning), where a discriminator is trained to distinguish unlabeled data from labeled data. Interestingly, active learning acquisition criteria does not depend on the task performance in VAAL.\nFig. 5. Illustration of VAAL (Variational adversarial active learning). (Image source: Sinha et al. 2019) The $\\beta$-VAE learns a latent feature space $\\mathbf{z}^l \\cup \\mathbf{z}^u$, for labeled and unlabeled data respectively, aiming to trick the discriminator $D(.)$ that all the data points are from the labeled pool; The discriminator $D(.)$ predicts whether a sample is labeled (1) or not (0) based on a latent representation $\\mathbf{z}$. VAAL selects unlabeled samples with low discriminator scores, which indicates that those samples are sufficiently different from previously labeled ones. The loss for VAE representation learning in VAAL contains both a reconstruction part (minimizing the ELBO of given samples) and an adversarial part (labeled and unlabeled data is drawn from the same probability distribution $q_\\phi$):\n $$ \\begin{aligned} \\mathcal{L}_\\text{VAE} \u0026= \\lambda_1 \\mathcal{L}^\\text{rec}_\\text{VAE} + \\lambda_2 \\mathcal{L}^\\text{adv}_\\text{VAE} \\\\ \\mathcal{L}^\\text{rec}_\\text{VAE} \u0026= \\mathbb{E}[\\log p_\\theta(\\mathbf{x}^l \\vert \\mathbf{z}^l)] - \\beta \\text{KL}(q_\\phi(\\mathbf{z}^l \\vert \\mathbf{x}^l) \\| p(\\mathbf{\\tilde{z}})) + \\mathbb{E}[\\log p_\\theta(\\mathbf{u} \\vert \\mathbf{z}^u)] - \\beta \\text{KL}(q_\\phi(\\mathbf{z}^u \\vert \\mathbf{u}) \\| p(\\mathbf{\\tilde{z}})) \\\\ \\mathcal{L}^\\text{adv}_\\text{VAE} \u0026= - \\mathbb{E}[\\log D(q_\\phi (\\mathbf{z}^l \\vert \\mathbf{x}^l))] - \\mathbb{E}[\\log D(q_\\phi(\\mathbf{z}^u \\vert \\mathbf{u}))] \\end{aligned} $$ where $p(\\mathbf{\\tilde{z}})$ is a unit Gaussian as a predefined prior and $\\beta$ is the Lagrangian parameter.\nThe discriminator loss is:\n $$ \\mathcal{L}_D = -\\mathbb{E}[\\log D(q_\\phi (\\mathbf{z}^l \\vert \\mathbf{x}^l))] - \\mathbb{E}[\\log (1 - D(q_\\phi (\\mathbf{z}^u \\vert \\mathbf{u})))] $$ Fig. 6. Experiment results of VAAL (variational adversarial active learning) on several image classification tasks. (Image source: Sinha et al. 2019 Ablation studies showed that jointly training VAE and discriminator is critical. Their results are robust to the biased initial labeled pool, different labeling budgets and noisy oracle.\nMAL (Minimax Active Learning; Ebrahimiet al. 2021) is an extension of VAAL. The MAL framework consists of an entropy minimizing feature encoding network $F$ followed by an entropy maximizing classifier $C$. This minimax setup reduces the distribution gap between labeled and unlabeled data.\nFig. 7. Illustration of the MAL (minimax active learning) framework. (Image source: Ebrahimiet al. 2021) A feature encoder $F$ encodes a sample into a $\\ell_2$-normalized $d$-dimensional latent vector. Assuming there are $K$ classes, a classifier $C$ is parameterized by $\\mathbf{W} \\in \\mathbb{R}^{d \\times K}$.\n(1) First $F$ and $C$ are trained on labeled samples by a simple cross entropy loss to achieve good classification results,\n $$ \\mathcal{L}_\\text{CE} = -\\mathbb{E}_{(\\mathbf{x}^l, y) \\sim \\mathcal{X}} \\sum_{k=1}^K \\mathbb{1}[k=y] \\log\\Big( \\sigma(\\frac{1}{T} \\frac{\\mathbf{W}^\\top F\\big(\\mathbf{x}^l)}{\\|F(\\mathbf{x}^l)\\|}\\big) \\Big) $$ (2) When training on the unlabeled examples, MAL relies on a minimax game setup\n $$ \\begin{aligned} \\mathcal{L}_\\text{Ent} \u0026= -\\sum^K_{k=1} p(y=k \\vert \\mathbf{u}) \\log p(y=k\\vert \\mathbf{u}) \\\\ \\theta^*_F, \\theta^*_C \u0026= \\min_F\\max_C \\mathcal{L}_\\text{Ent} \\\\ \\theta_F \u0026\\gets \\theta_F - \\alpha_1 \\nabla \\mathcal{L}_\\text{Ent} \\\\ \\theta_C \u0026\\gets \\theta_C + \\alpha_2 \\nabla \\mathcal{L}_\\text{Ent} \\end{aligned} $$ where,\n First, minimizing the entropy in $F$ encourages unlabeled samples associated with similar predicted labels to have similar features. Maximizing the entropy in $C$ adversarially makes the prediction to follow a more uniform class distribution. (My understanding here is that because the true label of an unlabeled sample is unknown, we should not optimize the classifier to maximize the predicted labels just yet.) The discriminator is trained in the same way as in VAAL.\nSampling strategy in MAL considers both diversity and uncertainty:\n Diversity: the score of $D$ indicates how similar a sample is to previously seen examples. A score closer to 0 is better to select unfamiliar data points. Uncertainty: use the entropy obtained by $C$. A higher entropy score indicates that the model cannot make a confident prediction yet. The experiments compared MAL to random, entropy, core-set, BALD and VAAL baselines, on image classification and segmentation tasks. The results look pretty strong.\nFig. 8. Performance of MAL on ImageNet. (Table source: Ebrahimiet al. 2021) CAL (Contrastive Active Learning; Margatina et al. 2021) intends to select contrastive examples. If two data points with different labels share similar network representations $\\Phi(.)$, they are considered as contrastive examples in CAL. Given a pair of contrastive examples $(\\mathbf{x}_i, \\mathbf{x}_j)$, they should\n $$ d(\\Phi(\\mathbf{x}_i), \\Phi(\\mathbf{x}_j)) Given an unlabeled sample $\\mathbf{x}$, CAL runs the following process:\n Select the top $k$ nearest neighbors in the model feature space among the labeled samples, $\\{(\\mathbf{x}^l_i, y_i\\}_{i=1}^M \\subset \\mathcal{X}$. Compute the KL divergence between the model output probabilities of $\\mathbf{x}$ and each in $\\{\\mathbf{x}^l\\}$. The contrastive score of $\\mathbf{x}$ is the average of these KL divergence values: $s(\\mathbf{x}) = \\frac{1}{M} \\sum_{i=1}^M \\text{KL}(p(y \\vert \\mathbf{x}^l_i | p(y \\vert \\mathbf{x}))$. Samples with high contrastive scores are selected for active learning. On a variety of classification tasks, the experiment results of CAL look similar to the entropy baseline.\nMeasuring Representativeness Core-sets Approach A core-set is a concept in computational geometry, referring to a small set of points that approximates the shape of a larger point set. Approximation can be captured by some geometric measure. In the active learning, we expect a model that is trained over the core-set to behave comparably with the model on the entire data points.\nSener \u0026amp; Savarese (2018) treats active learning as a core-set selection problem. Let’s say, there are $N$ samples in total accessible during training. During active learning, a small set of data points get labeled at every time step $t$, denoted as $\\mathcal{S}^{(t)}$. The upper bound of the learning objective can be written as follows, where the core-set loss is defined as the difference between average empirical loss over the labeled samples and the loss over the entire dataset including unlabelled ones.\n $$ \\begin{aligned} \\mathbb{E}_{(\\mathbf{x}, y) \\sim p} [\\mathcal{L}(\\mathbf{x}, y)] \\leq\u0026 \\bigg\\vert \\mathbb{E}_{(\\mathbf{x}, y) \\sim p} [\\mathcal{L}(\\mathbf{x}, y)] - \\frac{1}{N} \\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) \\bigg\\vert \u0026 \\text{; Generalization error}\\\\ +\u0026 \\frac{1}{\\vert \\mathcal{S}^{(t)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \u0026 \\text{; Training error}\\\\ +\u0026 \\bigg\\vert \\frac{1}{N} \\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) - \\frac{1}{\\vert \\mathcal{S}^{(t)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \\bigg\\vert \u0026 \\text{; Core-set error} \\end{aligned} $$ Then the active learning problem can be redefined as:\n $$ \\min_{\\mathcal{S}^{(t+1)} : \\vert \\mathcal{S}^{(t+1)} \\vert \\leq b} \\bigg\\vert \\frac{1}{N}\\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) - \\frac{1}{\\vert \\mathcal{S}^{(t)} \\cup \\mathcal{S}^{(t+1)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\cup \\mathcal{S}^{(t+1)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \\bigg\\vert $$ It is equivalent to the $k$-Center problem: choose $b$ center points such that the largest distance between a data point and its nearest center is minimized. This problem is NP-hard. An approximate solution depends on the greedy algorithm.\nFig. 9. Active learning results of core-sets algorithm in comparison with several common baselines on CIFAR-10, CIFAR-100, SVHN. (Image source: Sener \u0026 Savarese 2018) It works well on image classification tasks when there is a small number of classes. When the number of classes grows to be large or the data dimensionality increases (\u0026ldquo;curse of dimensionality\u0026rdquo;), the core-set method becomes less effective (Sinha et al. 2019).\nBecause the core-set selection is expensive, Coleman et al. (2020) experimented with a weaker model (e.g. smaller, weaker architecture, not fully trained) and found that empirically using a weaker model as a proxy can significantly shorten each repeated data selection cycle of training models and selecting samples, without hurting the final error much. Their method is referred to as SVP (Selection via Proxy).\nDiverse Gradient Embedding BADGE (Batch Active learning by Diverse Gradient Embeddings; Ash et al. 2020) tracks both model uncertainty and data diversity in the gradient space. Uncertainty is measured by the gradient magnitude w.r.t. the final layer of the network and diversity is captured by a diverse set of samples that span in the gradient space.\n Uncertainty. Given an unlabeled sample $\\mathbf{x}$, BADGE first computes the prediction $\\hat{y}$ and the gradient $g_\\mathbf{x}$ of the loss on $(\\mathbf{x}, \\hat{y})$ w.r.t. the last layer’s parameters. They observed that the norm of $g_\\mathbf{x}$ conservatively estimates the example\u0026rsquo;s influence on the model learning and high-confidence samples tend to have gradient embeddings of small magnitude. Diversity. Given many gradient embeddings of many samples, $g_\\mathbf{x}$, BADGE runs $k$-means++ to sample data points accordingly. Fig. 10. Algorithm of BADGE (batch active learning by diverse gradient embeddings). (Image source: Ash et al. 2020) Measuring Training Effects Quantify Model Changes Settles et al. (2008) introduced an active learning query strategy, named EGL (Expected Gradient Length). The motivation is to find samples that can trigger the greatest update on the model if their labels are known.\nLet $\\nabla \\mathcal{L}(\\theta)$ be the gradient of the loss function with respect to the model parameters. Specifically, given an unlabeled sample $\\mathbf{x}_i$, we need to calculate the gradient assuming the label is $y \\in \\mathcal{Y}$, $\\nabla \\mathcal{L}^{(y)}(\\theta)$. Because the true label $y_i$ is unknown, EGL relies on the current model belief to compute the expected gradient change:\n $$ \\text{EGL}(\\mathbf{x}_i) = \\sum_{y_i \\in \\mathcal{Y}} p(y=y_i \\vert \\mathbf{x}) \\|\\nabla \\mathcal{L}^{(y_i)}(\\theta)\\| $$ BALD (Bayesian Active Learning by Disagreement; Houlsby et al. 2011) aims to identify samples to maximize the information gain about the model weights, that is equivalent to maximize the decrease in expected posterior entropy.\n $$ \\begin{aligned} I[\\boldsymbol{\\theta}, y \\vert x,\\mathcal{D}] \u0026= H(\\boldsymbol{\\theta} \\vert \\mathcal{D}) - \\mathbb{E}_{y \\sim p(y \\vert \\boldsymbol{x}, \\mathcal{D})} \\big[ H(\\boldsymbol{\\theta} \\vert y, \\boldsymbol{x}, \\mathcal{D}) \\big] \u0026 \\text{; Decrease in expected posterior entropy}\\\\ \u0026= H(y \\vert \\boldsymbol{x}, \\mathcal{D}) - \\mathbb{E}_{\\boldsymbol{\\theta} \\sim p(\\boldsymbol{\\theta} \\vert \\mathcal{D})} \\big[ H(y \\vert \\boldsymbol{x}, \\mathcal{\\theta}) \\big] \\end{aligned} $$ The underlying interpretation is to \u0026ldquo;seek $\\mathbf{x}$ for which the model is marginally most uncertain about $y$ (high $H(y \\vert \\mathbf{x}, \\mathcal{D})$), but for which individual settings of the parameters are confident (low $H(y \\vert \\mathbf{x}, \\boldsymbol{\\theta})$).\u0026rdquo; In other words, each individual posterior draw is confident but a collection of draws carry diverse opinions.\nBALD was originally proposed for an individual sample and Kirsch et al. (2019) extended it to work in batch mode.\nForgetting Events To investigate whether neural networks have a tendency to forget previously learned information, Mariya Toneva et al. (2019) designed an experiment: They track the model prediction for each sample during the training process and count the transitions for each sample from being classified correctly to incorrectly or vice-versa. Then samples can be categorized accordingly,\n Forgettable (redundant) samples: If the class label changes across training epochs. Unforgettable samples: If the class label assignment is consistent across training epochs. Those samples are never forgotten once learned. They found that there are a large number of unforgettable examples that are never forgotten once learnt. Examples with noisy labels or images with \u0026ldquo;uncommon\u0026rdquo; features (visually complicated to classify) are among the most forgotten examples. The experiments empirically validated that unforgettable examples can be safely removed without compromising model performance.\nIn the implementation, the forgetting event is only counted when a sample is included in the current training batch; that is, they compute forgetting across presentations of the same example in subsequent mini-batches. The number of forgetting events per sample is quite stable across different seeds and forgettable examples have a small tendency to be first-time learned later in the training. The forgetting events are also found to be transferable throughout the training period and between architectures.\nForgetting events can be used as a signal for active learning acquisition if we hypothesize a model changing predictions during training is an indicator of model uncertainty. However, ground truth is unknown for unlabeled samples. Bengar et al. (2021) proposed a new metric called label dispersion for such a purpose. Let’s see across the training time, $c^*$ is the most commonly predicted label for the input $\\mathbf{x}$ and the label dispersion measures the fraction of training steps when the model does not assign $c^**$ to this sample:\n $$ \\text{Dispersion}(\\mathbf{x}) = 1 - \\frac{f_\\mathbf{x}}{T} \\text{ where } f_\\mathbf{x} = \\sum_{t=1}^T \\mathbb{1}[\\hat{y}_t = c^*], c^* = \\arg\\max_{c=1,\\dots,C}\\sum_{t=1}^T \\mathbb{1}[\\hat{y}_t = c] $$ In their implementation, dispersion is computed at every epoch. Label dispersion is low if the model consistently assigns the same label to the same sample but high if the prediction changes often. Label dispersion is correlated with network uncertainty, as shown in Fig. 11.\nFig. 11. Label dispersion is correlated with network uncertainty. On the x-axis, data points are sorted by label dispersion scores. The y-axis is the model prediction accuracy when the model trys to infer the labels for those samples. (Image source: Bengar et al. 2021) Hybrid When running active learning in batch mode, it is important to control diversity within a batch. Suggestive Annotation (SA; Yang et al. 2017) is a two-step hybrid strategy, aiming to select both high uncertainty \u0026amp; highly representative labeled samples. It uses uncertainty obtained from an ensemble of models trained on the labeled data and core-sets for choosing representative data samples.\n First, SA selects top $K$ images with high uncertainty scores to form a candidate pool $\\mathcal{S}_c \\subseteq \\mathcal{S}_U$. The uncertainty is measured as disagreement between multiple models training with bootstrapping. The next step is to find a subset $\\mathcal{S}_a \\subseteq \\mathcal{S}_c$ with highest representativeness. The cosine similarity between feature vectors of two inputs approximates how similar they are. The representativeness of $\\mathcal{S}_a$ for $\\mathcal{S}_U$ reflects how well $\\mathcal{S}_a$ can represent all the samples in $\\mathcal{S}_u$, defined as: $$ F(\\mathcal{S}_a, \\mathcal{S}_u) = \\sum_{\\mathbf{x}_j \\in \\mathcal{S}_u} f(\\mathcal{S}_a, \\mathbf{x}_j) = \\sum_{\\mathbf{x}_j \\in \\mathcal{S}_u} \\max_{\\mathbf{x}_i \\in \\mathcal{S}_a} \\text{sim}(\\mathbf{x}_i, \\mathbf{x}_j) $$ Formulating $\\mathcal{S}_a \\subseteq \\mathcal{S}_c$ with $k$ data points that maximizes $F(\\mathcal{S}_a, \\mathcal{S}_u)$ is a generalized version of the maximum set cover problem. It is NP-hard and its best possible polynomial time approximation algorithm is a simple greedy method.\n Initially, $\\mathcal{S}_a = \\emptyset$ and $F(\\mathcal{S}_a, \\mathcal{S}_u) = 0$. Then, iteratively add $\\mathbf{x}_i \\in \\mathcal{S}_c$ that maximizes $F(\\mathcal{S}_a \\cup I_i, \\mathcal{S}_u)$ over $\\mathcal{S}_a$, until $\\mathcal{S}_s$ contains $k$ images. Zhdanov (2019) runs a similar process as SA, but at step 2, it relies on $k$-means instead of core-set, where the size of the candidate pool is configured relative to the batch size. Given batch size $b$ and a constant $beta$ (between 10 and 50), it follows these steps:\n Train a classifier on the labeled data; Measure informativeness of every unlabeled example (e.g. using uncertainty metrics); Prefilter top $\\beta b \\geq b$ most informative examples; Cluster $\\beta b$ examples into $B$ clusters; Select $b$ different examples closest to the cluster centers for this round of active learning. Active learning can be further combined with semi-supervised learning to save the budget. CEAL (Cost-Effective Active Learning; Yang et al. 2017) runs two things in parallel:\n Select uncertain samples via active learning and get them labeled; Select samples with the most confident prediction and assign them pseudo labels. The confidence prediction is judged by whether the prediction entropy is below a threshold $\\delta$. As the model is getting better in time, the threshold $\\delta$ decays in time as well. Fig. 12. Illustration of CEAL (cost-effective active learning). (Image source: Yang et al. 2017) Citation Cited as:\n Weng, Lilian. (Feb 2022). Learning with not enough data part 2: active learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-02-20-active-learning/.\n Or\n@article{weng2022active, title = \u0026quot;Learning with not Enough Data Part 2: Active Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Feb\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-02-20-active-learning/\u0026quot; } References [1] Burr Settles. Active learning literature survey. University of Wisconsin, Madison, 52(55-66):11, 2010.\n[2] https://jacobgil.github.io/deeplearning/activelearning\n[3] Yang et al. \u0026ldquo;Cost-effective active learning for deep image classification\u0026rdquo; TCSVT 2016.\n[4] Yarin Gal et al. \u0026ldquo;Dropout as a Bayesian Approximation: representing model uncertainty in deep learning.\u0026quot; ICML 2016.\n[5] Blundell et al. \u0026ldquo;Weight uncertainty in neural networks (Bayes-by-Backprop)\u0026quot; ICML 2015.\n[6] Settles et al. \u0026ldquo;Multiple-Instance Active Learning.\u0026quot; NIPS 2007.\n[7] Houlsby et al. Bayesian Active Learning for Classification and Preference Learning.\u0026quot; arXiv preprint arXiv:1112.5745 (2020).\n[8] Kirsch et al. \u0026ldquo;BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning.\u0026quot; NeurIPS 2019.\n[9] Beluch et al. \u0026ldquo;The power of ensembles for active learning in image classification.\u0026quot; CVPR 2018.\n[10] Sener \u0026amp; Savarese. \u0026ldquo;Active learning for convolutional neural networks: A core-set approach.\u0026quot; ICLR 2018.\n[11] Donggeun Yoo \u0026amp; In So Kweon. \u0026ldquo;Learning Loss for Active Learning.\u0026quot; CVPR 2019.\n[12] Margatina et al. \u0026ldquo;Active Learning by Acquiring Contrastive Examples.\u0026quot; EMNLP 2021.\n[13] Sinha et al. \u0026ldquo;Variational Adversarial Active Learning\u0026rdquo; ICCV 2019\n[14] Ebrahimiet al. \u0026ldquo;Minmax Active Learning\u0026rdquo; arXiv preprint arXiv:2012.10467 (2021).\n[15] Mariya Toneva et al. \u0026ldquo;An empirical study of example forgetting during deep neural network learning.\u0026quot; ICLR 2019.\n[16] Javad Zolfaghari Bengar et al. \u0026ldquo;When Deep Learners Change Their Mind: Learning Dynamics for Active Learning.\u0026quot; CAIP 2021.\n[17] Yang et al. \u0026ldquo;Suggestive annotation: A deep active learning framework for biomedical image segmentation.\u0026quot; MICCAI 2017.\n[18] Fedor Zhdanov. \u0026ldquo;Diverse mini-batch Active Learning\u0026rdquo; arXiv preprint arXiv:1901.05954 (2019).\n","permalink":"https://lilianweng.github.io/posts/2022-02-20-active-learning/","summary":"This is part 2 of what to do when facing a limited amount of labeled data for supervised learning tasks. This time we will get some amount of human labeling work involved, but within a budget limit, and therefore we need to be smart when selecting which samples to label.\nNotations Symbol Meaning $K$ Number of unique class labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^K$ Labeled dataset.","title":"Learning with not Enough Data Part 2: Active Learning"},{"content":"When facing a limited amount of labeled data for supervised learning tasks, four approaches are commonly discussed.\n Pre-training + fine-tuning: Pre-train a powerful task-agnostic model on a large unsupervised data corpus, e.g. pre-training LMs on free text, or pre-training vision models on unlabelled images via self-supervised learning, and then fine-tune it on the downstream task with a small set of labeled samples. Semi-supervised learning: Learn from the labelled and unlabeled samples together. A lot of research has happened on vision tasks within this approach. Active learning: Labeling is expensive, but we still want to collect more given a cost budget. Active learning learns to select most valuable unlabeled samples to be collected next and helps us act smartly with a limited budget. Pre-training + dataset auto-generation: Given a capable pre-trained model, we can utilize it to auto-generate a lot more labeled samples. This has been especially popular within the language domain driven by the success of few-shot learning. I plan to write a series of posts on the topic of “Learning with not enough data”. Part 1 is on Semi-Supervised Learning.\nWhat is semi-supervised learning? Semi-supervised learning uses both labeled and unlabeled data to train a model.\nInterestingly most existing literature on semi-supervised learning focuses on vision tasks. And instead pre-training + fine-tuning is a more common paradigm for language tasks.\nAll the methods introduced in this post have a loss combining two parts: $\\mathcal{L} = \\mathcal{L}_s + \\mu(t) \\mathcal{L}_u$. The supervised loss $\\mathcal{L}_s$ is easy to get given all the labeled examples. We will focus on how the unsupervised loss $\\mathcal{L}_u$ is designed. A common choice of the weighting term $\\mu(t)$ is a ramp function increasing the importance of $\\mathcal{L}_u$ in time, where $t$ is the training step.\n Disclaimer: The post is not gonna cover semi-supervised methods with focus on model architecture modification. Check this survey for how to use generative models and graph-based methods in semi-supervised learning.\n Notations Symbol Meaning $L$ Number of unique labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^L$ Labeled dataset. $y$ is a one-hot representation of the true label. $\\mathbf{u} \\sim \\mathcal{U}$ Unlabeled dataset. $\\mathcal{D} = \\mathcal{X} \\cup \\mathcal{U}$ The entire dataset, including both labeled and unlabeled examples. $\\mathbf{x}$ Any sample which can be either labeled or unlabeled. $\\bar{\\mathbf{x}}$ $\\mathbf{x}$ with augmentation applied. $\\mathbf{x}_i$ The $i$-th sample. $\\mathcal{L}$, $\\mathcal{L}_s$, $\\mathcal{L}_u$ Loss, supervised loss, and unsupervised loss. $\\mu(t)$ The unsupervised loss weight, increasing in time. $p(y \\vert \\mathbf{x}), p_\\theta(y \\vert \\mathbf{x})$ The conditional probability over the label set given the input. $f_\\theta(.)$ The implemented neural network with weights $\\theta$, the model that we want to train. $\\mathbf{z} = f_\\theta(\\mathbf{x})$ A vector of logits output by $f$. $\\hat{y} = \\text{softmax}(\\mathbf{z})$ The predicted label distribution. $D[.,.]$ A distance function between two distributions, such as MSE, cross entropy, KL divergence, etc. $\\beta$ EMA weighting hyperparameter for teacher model weights. $\\alpha, \\lambda$ Parameters for MixUp, $\\lambda \\sim \\text{Beta}(\\alpha, \\alpha)$. $T$ Temperature for sharpening the predicted distribution. $\\tau$ A confidence threshold for selecting the qualified prediction. Hypotheses Several hypotheses have been discussed in literature to support certain design decisions in semi-supervised learning methods.\n H1: Smoothness Assumptions: If two data samples are close in a high-density region of the feature space, their labels should be the same or very similar.\n H2: Cluster Assumptions: The feature space has both dense regions and sparse regions. Densely grouped data points naturally form a cluster. Samples in the same cluster are expected to have the same label. This is a small extension of H1.\n H3: Low-density Separation Assumptions: The decision boundary between classes tends to be located in the sparse, low density regions, because otherwise the decision boundary would cut a high-density cluster into two classes, corresponding to two clusters, which invalidates H1 and H2.\n H4: Manifold Assumptions: The high-dimensional data tends to locate on a low-dimensional manifold. Even though real-world data might be observed in very high dimensions (e.g. such as images of real-world objects/scenes), they actually can be captured by a lower dimensional manifold where certain attributes are captured and similar points are grouped closely (e.g. images of real-world objects/scenes are not drawn from a uniform distribution over all pixel combinations). This enables us to learn a more efficient representation for us to discover and measure similarity between unlabeled data points. This is also the foundation for representation learning. [see a helpful link].\n Consistency Regularization Consistency Regularization, also known as Consistency Training, assumes that randomness within the neural network (e.g. with Dropout) or data augmentation transformations should not modify model predictions given the same input. Every method in this section has a consistency regularization loss as $\\mathcal{L}_u$.\nThis idea has been adopted in several self-supervised learning methods, such as SimCLR, BYOL, SimCSE, etc. Different augmented versions of the same sample should result in the same representation. Cross-view training in language modeling and multi-view learning in self-supervised learning all share the same motivation.\nΠ-model Fig. 1. Overview of the Π-model. Two versions of the same input with different stochastic augmentation and dropout masks pass through the network and the outputs are expected to be consistent. (Image source: Laine \u0026 Aila (2017)) Sajjadi et al. (2016) proposed an unsupervised learning loss to minimize the difference between two passes through the network with stochastic transformations (e.g. dropout, random max-pooling) for the same data point. The label is not explicitly used, so the loss can be applied to unlabeled dataset. Laine \u0026amp; Aila (2017) later coined the name, Π-Model, for such a setup.\n $$ \\mathcal{L}_u^\\Pi = \\sum_{\\mathbf{x} \\in \\mathcal{D}} \\text{MSE}(f_\\theta(\\mathbf{x}), f'_\\theta(\\mathbf{x})) $$ where $f'$ is the same neural network with different stochastic augmentation or dropout masks applied. This loss utilizes the entire dataset.\nTemporal ensembling Fig. 2. Overview of Temporal Ensembling. The per-sample EMA label prediction is the learning target. (Image source: Laine \u0026 Aila (2017)) Π-model requests the network to run two passes per sample, doubling the computation cost. To reduce the cost, Temporal Ensembling (Laine \u0026amp; Aila 2017) maintains an exponential moving average (EMA) of the model prediction in time per training sample $\\tilde{\\mathbf{z}}_i$ as the learning target, which is only evaluated and updated once per epoch. Because the ensemble output $\\tilde{\\mathbf{z}}_i$ is initialized to $\\mathbf{0}$, it is normalized by $(1-\\alpha^t)$ to correct this startup bias. Adam optimizer has such bias correction terms for the same reason.\n $$ \\tilde{\\mathbf{z}}^{(t)}_i = \\frac{\\alpha \\tilde{\\mathbf{z}}^{(t-1)}_i + (1-\\alpha) \\mathbf{z}_i}{1-\\alpha^t} $$ where $\\tilde{\\mathbf{z}}^{(t)}$ is the ensemble prediction at epoch $t$ and $\\mathbf{z}_i$ is the model prediction in the current round. Note that since $\\tilde{\\mathbf{z}}^{(0)} = \\mathbf{0}$, with correction, $\\tilde{\\mathbf{z}}^{(1)}$ is simply equivalent to $\\mathbf{z}_i$ at epoch 1.\nMean teachers Fig. 3. Overview of the Mean Teacher framework. (Image source: Tarvaninen \u0026 Valpola, 2017) Temporal Ensembling keeps track of an EMA of label predictions for each training sample as a learning target. However, this label prediction only changes every epoch, making the approach clumsy when the training dataset is large. Mean Teacher (Tarvaninen \u0026amp; Valpola, 2017) is proposed to overcome the slowness of target update by tracking the moving average of model weights instead of model outputs. Let’s call the original model with weights $\\theta$ as the student model and the model with moving averaged weights $\\theta’$ across consecutive student models as the mean teacher: $\\theta’ \\gets \\beta \\theta’ + (1-\\beta)\\theta$\nThe consistency regularization loss is the distance between predictions by the student and teacher and the student-teacher gap should be minimized. The mean teacher is expected to provide more accurate predictions than the student. It got confirmed in the empirical experiments, as shown in Fig. 4.\nFig. 4. Classification error on SVHN of Mean Teacher and the Π Model. The mean teacher (in orange) has better performance than the student model (in blue). (Image source: Tarvaninen \u0026 Valpola, 2017) According to their ablation studies,\n Input augmentation (e.g. random flips of input images, Gaussian noise) or student model dropout is necessary for good performance. Dropout is not needed on the teacher model. The performance is sensitive to the EMA decay hyperparameter $\\beta$. A good strategy is to use a small $\\beta=0.99$ during the ramp up stage and a larger $\\beta=0.999$ in the later stage when the student model improvement slows down. They found that MSE as the consistency cost function performs better than other cost functions like KL divergence. Noisy samples as learning targets Several recent consistency training methods learn to minimize prediction difference between the original unlabeled sample and its corresponding augmented version. It is quite similar to the Π-model but the consistency regularization loss is only applied to the unlabeled data.\nFig. 5. Consistency training with noisy samples. Adversarial Training (Goodfellow et al. 2014) applies adversarial noise onto the input and trains the model to be robust to such adversarial attack. The setup works in supervised learning,\n $$ \\begin{aligned} \\mathcal{L}_\\text{adv}(\\mathbf{x}^l, \\theta) \u0026= D[q(y\\mid \\mathbf{x}^l), p_\\theta(y\\mid \\mathbf{x}^l + r_\\text{adv})] \\\\ r_\\text{adv} \u0026= {\\arg\\max}_{r; \\|r\\| \\leq \\epsilon} D[q(y\\mid \\mathbf{x}^l), p_\\theta(y\\mid \\mathbf{x}^l + r_\\text{adv})] \\\\ r_\\text{adv} \u0026\\approx \\epsilon \\frac{g}{\\|g\\|_2} \\approx \\epsilon\\text{sign}(g)\\quad\\text{where }g = \\nabla_{r} D[y, p_\\theta(y\\mid \\mathbf{x}^l + r)] \\end{aligned} $$ where $q(y \\mid \\mathbf{x}^l)$ is the true distribution, approximated by one-hot encoding of the ground truth label, $y$. $p_\\theta(y \\mid \\mathbf{x}^l)$ is the model prediction. $D[.,.]$ is a distance function measuring the divergence between two distributions.\nVirtual Adversarial Training (VAT; Miyato et al. 2018) extends the idea to work in semi-supervised learning. Because $q(y \\mid \\mathbf{x}^l)$ is unknown, VAT replaces it with the current model prediction for the original input with the current weights $\\hat{\\theta}$. Note that $\\hat{\\theta}$ is a fixed copy of model weights, so there is no gradient update on $\\hat{\\theta}$.\n $$ \\begin{aligned} \\mathcal{L}_u^\\text{VAT}(\\mathbf{x}, \\theta) \u0026= D[p_{\\hat{\\theta}}(y\\mid \\mathbf{x}), p_\\theta(y\\mid \\mathbf{x} + r_\\text{vadv})] \\\\ r_\\text{vadv} \u0026= {\\arg\\max}_{r; \\|r\\| \\leq \\epsilon} D[p_{\\hat{\\theta}}(y\\mid \\mathbf{x}), p_\\theta(y\\mid \\mathbf{x} + r)] \\end{aligned} $$ The VAT loss applies to both labeled and unlabeled samples. It is a negative smoothness measure of the current model\u0026rsquo;s prediction manifold at each data point. The optimization of such loss motivates the manifold to be smoother.\nInterpolation Consistency Training (ICT; Verma et al. 2019) enhances the dataset by adding more interpolations of data points and expects the model prediction to be consistent with interpolations of the corresponding labels. MixUp (Zheng et al. 2018) operation mixes two images via a simple weighted sum and combines it with label smoothing. Following the idea of MixUp, ICT expects the prediction model to produce a label on a mixup sample to match the interpolation of predictions of corresponding inputs:\n $$ \\begin{aligned} \\text{mixup}_\\lambda (\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\lambda \\mathbf{x}_i + (1-\\lambda)\\mathbf{x}_j \\\\ p(\\text{mixup}_\\lambda (y \\mid \\mathbf{x}_i, \\mathbf{x}_j)) \u0026\\approx \\lambda p(y \\mid \\mathbf{x}_i) + (1-\\lambda) p(y \\mid \\mathbf{x}_j) \\end{aligned} $$ where $\\theta'$ is a moving average of $\\theta$, which is a mean teacher.\nFig. 6. Overview of Interpolation Consistency Training. MixUp is applied to produce more interpolated samples with interpolated labels as learning targets. (Image source: Verma et al. 2019) Because the probability of two randomly selected unlabeled samples belonging to different classes is high (e.g. There are 1000 object classes in ImageNet), the interpolation by applying a mixup between two random unlabeled samples is likely to happen around the decision boundary. According to the low-density separation assumptions, the decision boundary tends to locate in the low density regions.\n $$ \\mathcal{L}^\\text{ICT}_{u} = \\mathbb{E}_{\\mathbf{u}_i, \\mathbf{u}_j \\sim \\mathcal{U}} \\mathbb{E}_{\\lambda \\sim \\text{Beta}(\\alpha, \\alpha)} D[p_\\theta(y \\mid \\text{mixup}_\\lambda (\\mathbf{u}_i, \\mathbf{u}_j)), \\text{mixup}_\\lambda(p_{\\theta’}(y \\mid \\mathbf{u}_i), p_{\\theta'}(y \\mid \\mathbf{u}_j)] $$ where $\\theta'$ is a moving average of $\\theta$.\nSimilar to VAT, Unsupervised Data Augmentation (UDA; Xie et al. 2020) learns to predict the same output for an unlabeled example and the augmented one. UDA especially focuses on studying how the \u0026ldquo;quality\u0026rdquo; of noise can impact the semi-supervised learning performance with consistency training. It is crucial to use advanced data augmentation methods for producing meaningful and effective noisy samples. Good data augmentation should produce valid (i.e. does not change the label) and diverse noise, and carry targeted inductive biases.\nFor images, UDA adopts RandAugment (Cubuk et al. 2019) which uniformly samples augmentation operations available in PIL, no learning or optimization, so it is much cheaper than AutoAugment.\nFig. 7. Comparison of various semi-supervised learning methods on CIFAR-10 classification. Fully supervised Wide-ResNet-28-2 and PyramidNet+ShakeDrop have an error rate of **5.4** and **2.7** respectively when trained on 50,000 examples without RandAugment. (Image source: Xie et al. 2020) For language, UDA combines back-translation and TF-IDF based word replacement. Back-translation preserves the high-level meaning but may not retain certain words, while TF-IDF based word replacement drops uninformative words with low TF-IDF scores. In the experiments on language tasks, they found UDA to be complementary to transfer learning and representation learning; For example, BERT fine-tuned (i.e. $\\text{BERT}_\\text{FINETUNE}$ in Fig. 8.) on in-domain unlabeled data can further improve the performance.\nFig. 8. Comparison of UDA with different initialization configurations on various text classification tasks. (Image source: Xie et al. 2020) When calculating $\\mathcal{L}_u$, UDA found two training techniques to help improve the results.\n Low confidence masking: Mask out examples with low prediction confidence if lower than a threshold $\\tau$. Sharpening prediction distribution: Use a low temperature $T$ in softmax to sharpen the predicted probability distribution. In-domain data filtration: In order to extract more in-domain data from a large out-of-domain dataset, they trained a classifier to predict in-domain labels and then retain samples with high confidence predictions as in-domain candidates. $$ \\begin{aligned} \u0026\\mathcal{L}_u^\\text{UDA} = \\mathbb{1}[\\max_{y'} p_{\\hat{\\theta}}(y'\\mid \\mathbf{x}) \\tau ] \\cdot D[p^\\text{(sharp)}_{\\hat{\\theta}}(y \\mid \\mathbf{x}; T), p_\\theta(y \\mid \\bar{\\mathbf{x}})] \\\\ \u0026\\text{where } p_{\\hat{\\theta}}^\\text{(sharp)}(y \\mid \\mathbf{x}; T) = \\frac{\\exp(z^{(y)} / T)}{ \\sum_{y'} \\exp(z^{(y')} / T) } \\end{aligned} $$ where $\\hat{\\theta}$ is a fixed copy of model weights, same as in VAT, so no gradient update, and $\\bar{\\mathbf{x}}$ is the augmented data point. $\\tau$ is the prediction confidence threshold and $T$ is the distribution sharpening temperature.\nPseudo Labeling Pseudo Labeling (Lee 2013) assigns fake labels to unlabeled samples based on the maximum softmax probabilities predicted by the current model and then trains the model on both labeled and unlabeled samples simultaneously in a pure supervised setup.\nWhy could pseudo labels work? Pseudo label is in effect equivalent to Entropy Regularization (Grandvalet \u0026amp; Bengio 2004), which minimizes the conditional entropy of class probabilities for unlabeled data to favor low density separation between classes. In other words, the predicted class probabilities is in fact a measure of class overlap, minimizing the entropy is equivalent to reduced class overlap and thus low density separation.\nFig. 9. t-SNE visualization of outputs on MNIST test set by models training (a) without and (b) with pseudo labeling on 60000 unlabeled samples, in addition to 600 labeled data. Pseudo labeling leads to better segregation in the learned embedding space. (Image source: Lee 2013) Training with pseudo labeling naturally comes as an iterative process. We refer to the model that produces pseudo labels as teacher and the model that learns with pseudo labels as student.\nLabel propagation Label Propagation (Iscen et al. 2019) is an idea to construct a similarity graph among samples based on feature embedding. Then the pseudo labels are \u0026ldquo;diffused\u0026rdquo; from known samples to unlabeled ones where the propagation weights are proportional to pairwise similarity scores in the graph. Conceptually it is similar to a k-NN classifier and both suffer from the problem of not scaling up well with a large dataset.\nFig. 10. Illustration of how Label Propagation works. (Image source: Iscen et al. 2019) Self-Training Self-Training is not a new concept (Scudder 1965, Nigram \u0026amp; Ghani CIKM 2000). It is an iterative algorithm, alternating between the following two steps until every unlabeled sample has a label assigned:\n Initially it builds a classifier on labeled data. Then it uses this classifier to predict labels for the unlabeled data and converts the most confident ones into labeled samples. Xie et al. (2020) applied self-training in deep learning and achieved great results. On the ImageNet classification task, they first trained an EfficientNet (Tan \u0026amp; Le 2019) model as teacher to generate pseudo labels for 300M unlabeled images and then trained a larger EfficientNet as student to learn with both true labeled and pseudo labeled images. One critical element in their setup is to have noise during student model training but have no noise for the teacher to produce pseudo labels. Thus their method is called Noisy Student. They applied stochastic depth (Huang et al. 2016), dropout and RandAugment to noise the student. Noise is important for the student to perform better than the teacher. The added noise has a compound effect to encourage the model\u0026rsquo;s decision making frontier to be smooth, on both labeled and unlabeled data.\nA few other important technical configs in noisy student self-training are:\n The student model should be sufficiently large (i.e. larger than the teacher) to fit more data. Noisy student should be paired with data balancing, especially important to balance the number of pseudo labeled images in each class. Soft pseudo labels work better than hard ones. Noisy student also improves adversarial robustness against an FGSM (Fast Gradient Sign Attack = The attack uses the gradient of the loss w.r.t the input data and adjusts the input data to maximize the loss) attack though the model is not optimized for adversarial robustness.\nSentAugment, proposed by Du et al. (2020), aims to solve the problem when there is not enough in-domain unlabeled data for self-training in the language domain. It relies on sentence embedding to find unlabeled in-domain samples from a large corpus and uses the retrieved sentences for self-training.\nReducing confirmation bias Confirmation bias is a problem with incorrect pseudo labels provided by an imperfect teacher model. Overfitting to wrong labels may not give us a better student model.\nTo reduce confirmation bias, Arazo et al. (2019) proposed two techniques. One is to adopt MixUp with soft labels. Given two samples, $(\\mathbf{x}_i, \\mathbf{x}_j)$ and their corresponding true or pseudo labels $(y_i, y_j)$, the interpolated label equation can be translated to a cross entropy loss with softmax outputs:\n $$ \\begin{aligned} \u0026\\bar{\\mathbf{x}} = \\lambda \\mathbf{x}_i + (1-\\lambda) \\mathbf{x}_j \\\\ \u0026\\bar{y} = \\lambda y_i + (1-\\lambda) y_j \\Leftrightarrow \\mathcal{L} = \\lambda [y_i^\\top \\log f_\\theta(\\bar{\\mathbf{x}})] + (1-\\lambda) [y_j^\\top \\log f_\\theta(\\bar{\\mathbf{x}})] \\end{aligned} $$ Mixup is insufficient if there are too few labeled samples. They further set a minimum number of labeled samples in every mini batch by oversampling the labeled samples. This works better than upweighting labeled samples, because it leads to more frequent updates rather than few updates of larger magnitude which could be less stable. Like consistency regularization, data augmentation and dropout are also important for pseudo labeling to work well.\nMeta Pseudo Labels (Pham et al. 2021) adapts the teacher model constantly with the feedback of how well the student performs on the labeled dataset. The teacher and the student are trained in parallel, where the teacher learns to generate better pseudo labels and the student learns from the pseudo labels.\nLet the teacher and student model weights be $\\theta_T$ and $\\theta_S$, respectively. The student model\u0026rsquo;s loss on the labeled samples is defined as a function $\\theta^\\text{PL}_S(.)$ of $\\theta_T$ and we would like to minimize this loss by optimizing the teacher model accordingly.\n $$ \\begin{aligned} \\min_{\\theta_T} \u0026\\mathcal{L}_s(\\theta^\\text{PL}_S(\\theta_T)) = \\min_{\\theta_T} \\mathbb{E}_{(\\mathbf{x}^l, y) \\in \\mathcal{X}} \\text{CE}[y, f_{\\theta_S}(\\mathbf{x}^l)] \\\\ \\text{where } \u0026\\theta^\\text{PL}_S(\\theta_T) = \\arg\\min_{\\theta_S} \\mathcal{L}_u (\\theta_T, \\theta_S) = \\arg\\min_{\\theta_S} \\mathbb{E}_{\\mathbf{u} \\sim \\mathcal{U}} \\text{CE}[(f_{\\theta_T}(\\mathbf{u}), f_{\\theta_S}(\\mathbf{u}))] \\end{aligned} $$ However, it is not trivial to optimize the above equation. Borrowing the idea of MAML, it approximates the multi-step $\\arg\\min_{\\theta_S}$ with the one-step gradient update of $\\theta_S$,\n $$ \\begin{aligned} \\theta^\\text{PL}_S(\\theta_T) \u0026\\approx \\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S) \\\\ \\min_{\\theta_T} \\mathcal{L}_s (\\theta^\\text{PL}_S(\\theta_T)) \u0026\\approx \\min_{\\theta_T} \\mathcal{L}_s \\big( \\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S) \\big) \\end{aligned} $$ With soft pseudo labels, the above objective is differentiable. But if using hard pseudo labels, it is not differentiable and thus we need to use RL, e.g. REINFORCE.\nThe optimization procedure is alternative between training two models:\n Student model update: Given a batch of unlabeled samples $\\{ \\mathbf{u} \\}$, we generate pseudo labels by $f_{\\theta_T}(\\mathbf{u})$ and optimize $\\theta_S$ with one step SGD: $\\theta’_S = \\color{green}{\\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S)}$. Teacher model update: Given a batch of labeled samples $\\{(\\mathbf{x}^l, y)\\}$, we reuse the student’s update to optimize $\\theta_T$: $\\theta’_T = \\theta_T - \\eta_T \\cdot \\nabla_{\\theta_T} \\mathcal{L}_s ( \\color{green}{\\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S)} )$. In addition, the UDA objective is applied to the teacher model to incorporate consistency regularization. Fig. 11. Comparison of Meta Pseudo Labels with other semi- or self-supervised learning methods on image classification tasks. (Image source: Pham et al. 2021) Pseudo Labeling with Consistency Regularization It is possible to combine the above two approaches together, running semi-supervised learning with both pseudo labeling and consistency training.\nMixMatch MixMatch (Berthelot et al. 2019), as a holistic approach to semi-supervised learning, utilizes unlabeled data by merging the following techniques:\n Consistency regularization: Encourage the model to output the same predictions on perturbed unlabeled samples. Entropy minimization: Encourage the model to output confident predictions on unlabeled data. MixUp augmentation: Encourage the model to have linear behaviour between samples. Given a batch of labeled data $\\mathcal{X}$ and unlabeled data $\\mathcal{U}$, we create augmented versions of them via $\\text{MixMatch}(.)$, $\\bar{\\mathcal{X}}$ and $\\bar{\\mathcal{U}}$, containing augmented samples and guessed labels for unlabeled examples.\n $$ \\begin{aligned} \\bar{\\mathcal{X}}, \\bar{\\mathcal{U}} \u0026= \\text{MixMatch}(\\mathcal{X}, \\mathcal{U}, T, K, \\alpha) \\\\ \\mathcal{L}^\\text{MM}_s \u0026= \\frac{1}{\\vert \\bar{\\mathcal{X}} \\vert} \\sum_{(\\bar{\\mathbf{x}}^l, y)\\in \\bar{\\mathcal{X}}} D[y, p_\\theta(y \\mid \\bar{\\mathbf{x}}^l)] \\\\ \\mathcal{L}^\\text{MM}_u \u0026= \\frac{1}{L\\vert \\bar{\\mathcal{U}} \\vert} \\sum_{(\\bar{\\mathbf{u}}, \\hat{y})\\in \\bar{\\mathcal{U}}} \\| \\hat{y} - p_\\theta(y \\mid \\bar{\\mathbf{u}}) \\|^2_2 \\\\ \\end{aligned} $$ where $T$ is the sharpening temperature to reduce the guessed label overlap; $K$ is the number of augmentations generated per unlabeled example; $\\alpha$ is the parameter in MixUp.\nFor each $\\mathbf{u}$, MixMatch generates $K$ augmentations, $\\bar{\\mathbf{u}}^{(k)} = \\text{Augment}(\\mathbf{u})$ for $k=1, \\dots, K$ and the pseudo label is guessed based on the average: $\\hat{y} = \\frac{1}{K} \\sum_{k=1}^K p_\\theta(y \\mid \\bar{\\mathbf{u}}^{(k)})$.\nFig. 12. The process of \"label guessing\" in MixMatch: averaging $K$ augmentations, correcting the predicted marginal distribution and finally sharpening the distribution. (Image source: Berthelot et al. 2019) According to their ablation studies, it is critical to have MixUp especially on the unlabeled data. Removing temperature sharpening on the pseudo label distribution hurts the performance quite a lot. Average over multiple augmentations for label guessing is also necessary.\nReMixMatch (Berthelot et al. 2020) improves MixMatch by introducing two new mechanisms:\nFig. 13. Illustration of two improvements introduced in ReMixMatch over MixMatch. (Image source: Berthelot et al. 2020) Distribution alignment. It encourages the marginal distribution $p(y)$ to be close to the marginal distribution of the ground truth labels. Let $p(y)$ be the class distribution in the true labels and $\\tilde{p}(\\hat{y})$ be a running average of the predicted class distribution among the unlabeled data. The model prediction on an unlabeled sample $p_\\theta(y \\vert \\mathbf{u})$ is normalized to be $\\text{Normalize}\\big( \\frac{p_\\theta(y \\vert \\mathbf{u}) p(y)}{\\tilde{p}(\\hat{y})} \\big)$ to match the true marginal distribution. Note that entropy minimization is not a useful objective if the marginal distribution is not uniform. I do feel the assumption that the class distributions on the labeled and unlabeled data should match is too strong and not necessarily to be true in the real-world setting. Augmentation anchoring. Given an unlabeled sample, it first generates an \u0026ldquo;anchor\u0026rdquo; version with weak augmentation and then averages $K$ strongly augmented versions using CTAugment (Control Theory Augment). CTAugment only samples augmentations that keep the model predictions within the network tolerance. The ReMixMatch loss is a combination of several terms,\n a supervised loss with data augmentation and MixUp applied; an unsupervised loss with data augmentation and MixUp applied, using pseudo labels as targets; a CE loss on a single heavily-augmented unlabeled image without MixUp; a rotation loss as in self-supervised learning. DivideMix DivideMix (Junnan Li et al. 2020) combines semi-supervised learning with Learning with noisy labels (LNL). It models the per-sample loss distribution via a GMM to dynamically divide the training data into a labeled set with clean examples and an unlabeled set with noisy ones. Following the idea in Arazo et al. 2019, they fit a two-component GMM on the per-sample cross entropy loss $\\ell_i = y_i^\\top \\log f_\\theta(\\mathbf{x}_i)$. Clean samples are expected to get lower loss faster than noisy samples. The component with smaller mean is the cluster corresponding to clean labels and let’s denote it as $c$. If the GMM posterior probability $w_i = p_\\text{GMM}(c \\mid \\ell_i)$ (i.e. the probability of the sampling belonging to the clean sample set) is larger than the threshold $\\tau$, this sample is considered as a clean sample and otherwise a noisy one.\nThe data clustering step is named co-divide. To avoid confirmation bias, DivideMix simultaneously trains two diverged networks where each network uses the dataset division from the other network; e.g. thinking about how Double Q Learning works.\nFig. 14. DivideMix trains two networks independently to reduce confirmation bias. They run co-divide, co-refinement, and co-guessing together. (Image source: Junnan Li et al. 2020) Compared to MixMatch, DivideMix has an additional co-divide stage for handling noisy samples, as well as the following improvements during training:\n Label co-refinement: It linearly combines the ground-truth label $y_i$ with the network’s prediction $\\hat{y}_i$, which is averaged across multiple augmentations of $\\mathbf{x}_i$, guided by the clean set probability $w_i$ produced by the other network. Label co-guessing: It averages the predictions from two models for unlabelled data samples. Fig. 15. The algorithm of DivideMix. (Image source: Junnan Li et al. 2020) FixMatch FixMatch (Sohn et al. 2020) generates pseudo labels on unlabeled samples with weak augmentation and only keeps predictions with high confidence. Here both weak augmentation and high confidence filtering help produce high-quality trustworthy pseudo label targets. Then FixMatch learns to predict these pseudo labels given a heavily-augmented sample.\nFig. 16. Illustration of how FixMatch works. (Image source: Sohn et al. 2020) $$ \\begin{aligned} \\mathcal{L}_s \u0026= \\frac{1}{B} \\sum^B_{b=1} \\text{CE}[y_b, p_\\theta(y \\mid \\mathcal{A}_\\text{weak}(\\mathbf{x}_b))] \\\\ \\mathcal{L}_u \u0026= \\frac{1}{\\mu B} \\sum_{b=1}^{\\mu B} \\mathbb{1}[\\max(\\hat{y}_b) \\geq \\tau]\\;\\text{CE}(\\hat{y}_b, p_\\theta(y \\mid \\mathcal{A}_\\text{strong}(\\mathbf{u}_b))) \\end{aligned} $$ where $\\hat{y}_b$ is the pseudo label for an unlabeled example; $\\mu$ is a hyperparameter that determines the relative sizes of $\\mathcal{X}$ and $\\mathcal{U}$.\n Weak augmentation $\\mathcal{A}_\\text{weak}(.)$: A standard flip-and-shift augmentation Strong augmentation $\\mathcal{A}_\\text{strong}(.)$ : AutoAugment, Cutout, RandAugment, CTAugment Fig. 17. Performance of FixMatch and several other semi-supervised learning methods on image classification tasks. (Image source: Sohn et al. 2020) According to the ablation studies of FixMatch,\n Sharpening the predicted distribution with a temperature parameter $T$ does not have a significant impact when the threshold $\\tau$ is used. Cutout and CTAugment as part of strong augmentations are necessary for good performance. When the weak augmentation for label guessing is replaced with strong augmentation, the model diverges early in training. If discarding weak augmentation completely, the model overfit the guessed labels. Using weak instead of strong augmentation for pseudo label prediction leads to unstable performance. Strong data augmentation is critical. Combined with Powerful Pre-Training It is a common paradigm, especially in language tasks, to first pre-train a task-agnostic model on a large unsupervised data corpus via self-supervised learning and then fine-tune it on the downstream task with a small labeled dataset. Research has shown that we can obtain extra gain if combining semi-supervised learning with pretraining.\nZoph et al. (2020) studied to what degree self-training can work better than pre-training. Their experiment setup was to use ImageNet for pre-training or self-training to improve COCO. Note that when using ImageNet for self-training, it discards labels and only uses ImageNet samples as unlabeled data points. He et al. (2018) has demonstrated that ImageNet classification pre-training does not work well if the downstream task is very different, such as object detection.\nFig. 18. The effect of (a) data augment (from weak to strong) and (b) the labeled dataset size on the object detection performance. In the legend: `Rand Init` refers to a model initialized w/ random weights; `ImageNet` is initialized with a pre-trained checkpoint at 84.5% top-1 ImageNet accuracy; `ImageNet++` is initialized with a checkpoint with a higher accuracy 86.9%. (Image source: Zoph et al. 2020) Their experiments demonstrated a series of interesting findings:\n The effectiveness of pre-training diminishes with more labeled samples available for the downstream task. Pre-training is helpful in the low-data regimes (20%) but neutral or harmful in the high-data regime. Self-training helps in high data/strong augmentation regimes, even when pre-training hurts. Self-training can bring in additive improvement on top of pre-training, even using the same data source. Self-supervised pre-training (e.g. via SimCLR) hurts the performance in a high data regime, similar to how supervised pre-training does. Joint-training supervised and self-supervised objectives help resolve the mismatch between the pre-training and downstream tasks. Pre-training, joint-training and self-training are all additive. Noisy labels or un-targeted labeling (i.e. pre-training labels are not aligned with downstream task labels) is worse than targeted pseudo labeling. Self-training is computationally more expensive than fine-tuning on a pre-trained model. Chen et al. (2020) proposed a three-step procedure to merge the benefits of self-supervised pretraining, supervised fine-tuning and self-training together:\n Unsupervised or self-supervised pretrain a big model. Supervised fine-tune it on a few labeled examples. It is important to use a big (deep and wide) neural network. Bigger models yield better performance with fewer labeled samples. Distillation with unlabeled examples by adopting pseudo labels in self-training. It is possible to distill the knowledge from a large model into a small one because the task-specific use does not require extra capacity of the learned representation. The distillation loss is formatted as the following, where the teacher network is fixed with weights $\\hat{\\theta}_T$. $$ \\mathcal{L}_\\text{distill} = - (1-\\alpha) \\underbrace{\\sum_{(\\mathbf{x}^l_i, y_i) \\in \\mathcal{X}} \\big[ \\log p_{\\theta_S}(y_i \\mid \\mathbf{x}^l_i) \\big]}_\\text{Supervised loss} - \\alpha \\underbrace{\\sum_{\\mathbf{u}_i \\in \\mathcal{U}} \\Big[ \\sum_{i=1}^L p_{\\hat{\\theta}_T}(y^{(i)} \\mid \\mathbf{u}_i; T) \\log p_{\\theta_S}(y^{(i)} \\mid \\mathbf{u}_i; T) \\Big]}_\\text{Distillation loss using unlabeled data} $$ Fig. 19. A semi-supervised learning framework leverages unlabeled data corpus by (Left) task-agnostic unsupervised pretraining and (Right) task-specific self-training and distillation. (Image source: Chen et al. 2020) They experimented on the ImageNet classification task. The self-supervised pre-training uses SimCLRv2, a directly improved version of SimCLR. Observations in their empirical studies confirmed several learnings, aligned with Zoph et al. 2020:\n Bigger models are more label-efficient; Bigger/deeper project heads in SimCLR improve representation learning; Distillation using unlabeled data improves semi-supervised learning. Fig. 20. Comparison of performance by SimCLRv2 + semi-supervised distillation on ImageNet classification. (Image source: Chen et al. 2020) 💡 Quick summary of common themes among recent semi-supervised learning methods, many aiming to reduce confirmation bias:\n Apply valid and diverse noise to samples by advanced data augmentation methods. When dealing with images, MixUp is an effective augmentation. Mixup could work on language too, resulting in a small incremental improvement (Guo et al. 2019). Set a threshold and discard pseudo labels with low confidence. Set a minimum number of labeled samples per mini-batch. Sharpen the pseudo label distribution to reduce the class overlap. Citation Cited as:\n Weng, Lilian. (Dec 2021). Learning with not enough data part 1: semi-supervised learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-12-05-semi-supervised/.\n Or\n@article{weng2021semi, title = \u0026quot;Learning with not Enough Data Part 1: Semi-Supervised Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Dec\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-12-05-semi-supervised/\u0026quot; } References [1] Ouali, Hudelot \u0026amp; Tami. “An Overview of Deep Semi-Supervised Learning” arXiv preprint arXiv:2006.05278 (2020).\n[2] Sajjadi, Javanmardi \u0026amp; Tasdizen “Regularization With Stochastic Transformations and Perturbations for Deep Semi-Supervised Learning.” arXiv preprint arXiv:1606.04586 (2016).\n[3] Pham et al. “Meta Pseudo Labels.” CVPR 2021.\n[4] Laine \u0026amp; Aila. “Temporal Ensembling for Semi-Supervised Learning” ICLR 2017.\n[5] Tarvaninen \u0026amp; Valpola. “Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.” NeuriPS 2017\n[6] Xie et al. “Unsupervised Data Augmentation for Consistency Training.” NeuriPS 2020.\n[7] Miyato et al. “Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning.” IEEE transactions on pattern analysis and machine intelligence 41.8 (2018).\n[8] Verma et al. “Interpolation consistency training for semi-supervised learning.” IJCAI 2019\n[9] Lee. “Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks.” ICML 2013 Workshop: Challenges in Representation Learning.\n[10] Iscen et al. “Label propagation for deep semi-supervised learning.” CVPR 2019.\n[11] Xie et al. “Self-training with Noisy Student improves ImageNet classification” CVPR 2020.\n[12] Jingfei Du et al. “Self-training Improves Pre-training for Natural Language Understanding.” 2020\n[13] Iscen et al. “Label propagation for deep semi-supervised learning.” CVPR 2019\n[14] Arazo et al. “Pseudo-labeling and confirmation bias in deep semi-supervised learning.” IJCNN 2020.\n[15] Berthelot et al. “MixMatch: A holistic approach to semi-supervised learning.” NeuriPS 2019\n[16] Berthelot et al. “ReMixMatch: Semi-supervised learning with distribution alignment and augmentation anchoring.” ICLR 2020\n[17] Sohn et al. “FixMatch: Simplifying semi-supervised learning with consistency and confidence.” CVPR 2020\n[18] Junnan Li et al. “DivideMix: Learning with Noisy Labels as Semi-supervised Learning.” 2020 [code]\n[19] Zoph et al. “Rethinking pre-training and self-training.” 2020.\n[20] Chen et al. “Big Self-Supervised Models are Strong Semi-Supervised Learners” 2020\n","permalink":"https://lilianweng.github.io/posts/2021-12-05-semi-supervised/","summary":"When facing a limited amount of labeled data for supervised learning tasks, four approaches are commonly discussed.\n Pre-training + fine-tuning: Pre-train a powerful task-agnostic model on a large unsupervised data corpus, e.g. pre-training LMs on free text, or pre-training vision models on unlabelled images via self-supervised learning, and then fine-tune it on the downstream task with a small set of labeled samples. Semi-supervised learning: Learn from the labelled and unlabeled samples together.","title":"Learning with not Enough Data Part 1: Semi-Supervised Learning"},{"content":"[Updated on 2022-03-13: add expert choice routing.] [Updated on 2022-06-10]: Greg and I wrote a shorted and upgraded version of this post, published on OpenAI Blog: \u0026ldquo;Techniques for Training Large Neural Networks\u0026rdquo;\nIn recent years, we are seeing better results on many NLP benchmark tasks with larger pre-trained language models. How to train large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time.\nHowever an individual GPU worker has limited memory and the sizes of many large models have grown beyond a single GPU. There are several parallelism paradigms to enable model training across multiple GPUs, as well as a variety of model architecture and memory saving designs to help make it possible to train very large neural networks.\nTraining Parallelism The main bottleneck for training very large neural network models is the intense demand for a large amount of GPU memory, way above what can be hosted on an individual GPU machine. Besides the model weights (e.g. tens of billions of floating point numbers), it is usually even more expensive to store intermediate computation outputs such as gradients and optimizer states (e.g. momentums \u0026amp; variations in Adam). Additionally training a large model often pairs with a large training corpus and thus a single process may just take forever.\nAs a result, parallelism is necessary. Parallelism can happen at different dimensions, including data, model architecture, and tensor operation.\nData Parallelism The most naive way for Data parallelism (DP) is to copy the same model weights into multiple workers and assign a fraction of data to each worker to be processed at the same time.\nNaive DP cannot work well if the model size is larger than a single GPU node’s memory. Methods like GeePS (Cui et al. 2016) offload temporarily unused parameters back to CPU to work with limited GPU memory when the model is too big to fit into one machine. The data swapping transfer should happen at the backend and not interfere with training computation.\nAt the end of each minibatch, workers need to synchronize gradients or weights to avoid staleness. There are two main synchronization approaches and both have clear pros \u0026amp; cons.\n Bulk synchronous parallels (BSP): Workers sync data at the end of every minibatch. It prevents model weights staleness and good learning efficiency but each machine has to halt and wait for others to send gradients. Asynchronous parallel (ASP): Every GPU worker processes the data asynchronously, no waiting or stalling. However, it can easily lead to stale weights being used and thus lower the statistical learning efficiency. Even though it increases the computation time, it may not speed up training time to convergence. Somewhere in the middle is to synchronize gradients globally once every $x$ iterations ($x \u0026gt; 1$). This feature is called “gradient accumulation” in Distribution Data Parallel (DDP) since Pytorch v1.5 (Li et al. 2021). Bucketing gradients avoid immediate AllReduce operations but instead buckets multiple gradients into one AllReduce to improve throughput. Computation and communication scheduling optimization can be made based on the computation graph.\nFig. 1. Pseudo code for Pytorch DDP. (Image source: Li et al. 2021) Model Parallelism Model parallelism (MP) aims to solve the case when the model weights cannot fit into a single node. The computation and model parameters are partitioned across multiple machines. Different from data parallelism where each worker hosts a full copy of the entire model, MP only allocates a fraction of model parameters on one worker and thus both the memory usage and the computation are reduced.\nSince deep neural networks usually contain a stack of vertical layers, it feels straightforward to split a large model by layer, where a small consecutive set of layers are grouped into one partition on one worker. However, a naive implementation for running every data batch through multiple such workers with sequential dependency leads to big bubbles of waiting time and severe under-utilization of computation resources.\nFig. 2. A naive model parallelism setup where the model is vertically split into 4 partitions. Data is processed by one worker at a time due to sequential dependency, leading to large “bubbles” of idle time. (Image source: Huang et al. 2019) Pipeline Parallelism Pipeline parallelism (PP) combines model parallelism with data parallelism to reduce inefficient time “bubbles''. The main idea is to split one minibatch into multiple microbatches and enable each stage worker to process one microbatch simultaneously. Note that every microbatch needs two passes, one forward and one backward. Inter-worker communication only transfers activations (forward) and gradients (backward). How these passes are scheduled and how the gradients are aggregated vary in different approaches. The number of partitions (workers) is also known as pipeline depth.\nIn GPipe (Huang et al. 2019) gradients from multiple microbatches are aggregated and applied synchronously at the end. The synchronous gradient descent guarantees learning consistency and efficiency irrespective of the number of workers. As shown in Fig. 3, bubbles still exist but are much smaller than what’s in Fig. 2. Given $m$ evenly split microbatches and $d$ partitions, assuming both forward and backward per microbatch take one unit of time, the fraction of bubble is:\n $$ 1 - \\frac{2md}{(2m + 2(d-1))d} = \\frac{d-1}{m+d-1} $$ The GPipe paper observed that the bubble overhead is almost negligible if the number of microbatches is more than 4x the number of partitions $m \u0026gt; 4d$ (when activation recomputation is applied).\nFig. 3. Illustration of pipeline parallelism in GPipe with 4 microbatches and 4 partitions. GPipe aggregates and updates gradients across devices synchronously at the end of every batch. (Image source: Huang et al. 2019) GPipe achieves almost linear speedup in throughput with the number of devices, although it is not always guaranteed if the model parameters are not evenly distributed across workers.\nPipeDream (Narayanan et al. 2019) schedules each worker to alternatively process the forward and backward passes (1F1B). PipeDream names each model partition “stage” and each stage worker can have multiple replicas to run data parallelism. In this process, PipeDream uses a deterministic round-robin load balancing strategy to assign work among multiple replicas of stages to ensure that the forward and backward passes for the same minibatch happen on the same replica.\nFig. 4. Illustration of `1F1B` microbatch scheduling in PipeDream. (Image source: Harlap et al. 2018) Since PipeDream does not have an end-of-batch global gradient sync across all the workers, an native implementation of 1F1B can easily lead to the forward and backward passes of one microbatch using different versions of model weights, thus lowering the learning efficiency. PipeDream proposed a few designs to tackle this issue:\n Weight stashing: Each worker keeps track of several model versions and makes sure that the same version of weights are used in the forward and backward passes given one data batch. Vertical sync (Optional): The version of model weights flows between stage workers together with activations and gradients. Then the computation adopts the corresponding stashed version propagated from the previous worker. This process keeps version consistency across workers. Note that it is asynchronous, different from GPipe. At the beginning of a training run, PipeDream first profiles the computation memory cost and time of each layer in the model and then optimizes a solution for partitioning layers into stages, which is a dynamic programming problem.\nFig. 5. Results for VGG16 on ILSVRC12. (Top) Accuracy vs time. The integer marks the number of stage workers. ASP = Asynchronous parallel \u0026 BSP = Bulk synchronous parallels. (Bottom) Training time speedup for different parallelism configurations. Straight pipeline refers to pipeline parallelism without data parallelism. (Image source: Harlap et al. 2018) Two variations of PipeDream were later proposed to reduce the memory footprint by stashed model versions (Narayanan et al. 2021).\nPipeDream-flush adds a globally synchronized pipeline flush periodically, just like GPipe. In this way, it greatly reduces the memory footprint (i.e. only maintain a single version of model weights) by sacrificing a little throughput.\nFig. 6. Illustration of pipeline scheduling in PipeDream-flush. (Image source: (Narayanan et al. 2021) PipeDream-2BW maintains only two versions of model weights, where “2BW” is short for “double-buffered weights”. It generates a new model version every $k$ microbatches and $k$ should be larger than the pipeline depth $d$, $k \u0026gt; d$. A newly updated model version cannot fully replace the old version immediately since some leftover backward passes still depend on the old version. In total only two versions need to be saved so the memory cost is much reduced.\nFig. 7. Illustration of pipeline scheduling in PipeDream-2BW. (Image source: (Narayanan et al. 2021) Tensor Parallelism Both model and pipeline parallelisms split a model vertically. OTOH we can horizontally partition the computation for one tensor operation across multiple devices, named Tensor parallelism (TP).\nLet\u0026rsquo;s take the transformer as an example given its popularity. The transformer model mainly consists of layers of MLP and self-attention blocks. Megatron-LM (Shoeybi et al. 2020) adopts a simple way to parallelize intra-layer computation for MLP and self-attention.\nA MLP layer in a transformer contains a GEMM (General matrix multiply) followed by an non-linear GeLU transfer. Let’s split weight matrix $A$ by column:\n $$ \\begin{aligned} \\text{Split }A \u0026= [A_1, A_2] \\\\ Y \u0026=\\text{GeLU}(XA) \\\\ [Y_1, Y_2] \u0026= [\\text{GeLU}(XA_1), \\text{GeLU}(XA_2)] \\end{aligned} $$ The attention block runs GEMM with query ($Q$), key ($K$), and value weights ($V$) according to the above partitioning in parallel and then combines them with another GEMM to produce the attention head results.\n $$ \\text{Attention}(X, Q, K, V) = \\text{softmax}(\\frac{(XQ) (XK)^\\top}{\\sqrt{d_k}}) XV $$ Fig. 8. Illustration of tensor parallelism for key transformer components proposed in Megatron-LM. (Image source: Shoeybi et al. 2020) Narayanan et al. (2021) combined pipeline, tensor and data parallelism with a new pipeline scheduling strategy and named their approach PTD-P. Instead of only positioning a continuous set of layers (“model chunk”) on a device, each worker can be assigned with multiple chunks of smaller continuous subsets of layers (e.g. device 1 has layers 1, 2, 9, 10; device 2 has layers 3, 4, 11, 12; each has two model chunks). The number of microbatches in one batch should be exactly divided by the number of workers ($m % d = 0$). If there are $v$ model chunks per worker, the pipeline bubble time can be reduced by a multiplier of $v$ compared to a GPipe scheduling.\nFig. 9. (Top) Default `1F1B` pipeline schedule as in PipeDream-flush. (Bottom) Interleaved 1F1B pipeline schedule. First model chunks are in dark colors and second chunks are in light colors. (Image source: Narayanan et al. 202)) Mixture-of-Experts (MoE) The Mixture-of-Experts (MoE) approach attracts a lot of attention recently as researchers (mainly from Google) try to push the limit of model size. The core of the idea is ensembling learning: Combination of multiple weak learners gives you a strong learner!\nWithin one deep neural network, ensembling can be implemented with a gating mechanism connecting multiple experts (Shazeer et al., 2017). The gating mechanism controls which subset of the network (e.g. which experts) should be activated to produce outputs. The paper named it \u0026ldquo;sparsely gated mixture-of-experts\u0026rdquo; (MoE) layer.\nPrecisely one MoE layer contains\n $n$ feed-forward networks as experts $\\{E_i\\}^n_{i=1}$ A trainable gating network $G$ to learn a probability distribution over $n$ experts so as to route the traffic to a few selected experts. Depending on the gating outputs, not every expert has to be evaluated. When the number of experts is too large, we can consider using a two-level hierarchical MoE.\nFig. 10. Illustration of a mixture-of-experts (MoE) layer. Only 2 out of $n$ experts are selected and activated by the gating network. (Image source: Shazeer et al., 2017) A simple choice of $G$ is to multiply the input with a trainable weight matrix $G_g$ and then do softmax: $G_\\sigma (x) = \\text{softmax}(x W_g)$. However, this produces a dense control vector for gating and does not help save computation resources because we don\u0026rsquo;t need to evaluate an expert only when $G^{(i)}(x)=0$. Thus the MoE layer only keeps the top $k$ values. It also adds tunable Gaussian noise into $G$ to improve load balancing. This mechanism is called noisy top-k gating.\n $$ \\begin{aligned} G(x) \u0026= \\text{softmax}( \\text{topk}(H(x), k)) \\\\ H^{(i)}(x) \u0026= (xW_g)^{(i)} + \\epsilon \\cdot \\text{softplus}((xW_\\text{noise})^{(i)} ); \\quad \\epsilon \\sim \\mathcal{N}(0, \\mathbf{1}) \\\\ \\text{topk}^{(i)}(v, k) \u0026= \\begin{cases} v^{(i)} \u0026 \\text{if }v^{(i)}\\text{ is in the top }k\\text{ elements of }v \\\\ -\\infty \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ where the superscript $v^{(i)}$ denotes the i-th dimension of the vector $v$. The function $\\text{topk}(., k)$ selected the top $k$ dimensions with highest values by setting other dimensions to $-\\infty$.\nTo avoid the self-reinforcing effect that the gating network may favor a few strong experts all the time, Shazeer et al. (2017) proposed a soft constraint via an additional importance loss to encourage all the experts to have the same weights. It is equivalent to the square of the coefficient of variation of batchwise average value per expert.\n $$ L_\\text{aux} = w_\\text{aux} \\cdot \\text{CV}(\\sum_{x \\in X} G(x))^2 $$ where $ \\text{CV}$ is the coefficient of variation and the loss weight $w_\\text{aux}$ is a hyperparameter to tune.\nBecause every expert network only gets a fraction of training samples (\u0026ldquo;The shrinking batch problem\u0026rdquo;), we should try to use a batch size as large as possible in MoE. However, it is restricted by GPU memory. Data parallelism and model parallelism can be applied to improve the throughput.\nFig. 11. Test perplexity on 1-Billion-Word language modeling benchmark. (Left) The model capacity increases from left to right, containing 4, 32, 256, 256, 1024 and 4096 experts. (Right) Performance of the 4 billion parameters MoE model, the largest one in the left figure, under different computation budgets. (Image source: Shazeer et al., 2017) GShard (Lepikhin et al., 2020) scales the MoE transformer model up to 600 billion parameters with sharding. The MoE transformer replaces every other feed forward layer with a MoE layer. The sharded MoE transformer only has the MoE layers sharded across multiple machines, while other layers are simply duplicated.\nThere are several improved designs for the gating function $G$ in GShard:\n Expert capacity: The amount of tokens going through one expert should not go above a threshold, named “expert capacity”. If a token is routed to experts that have reached their capacity, the token would be marked “overflowed” and the gating output is changed to a zero vector. Local group dispatching: Tokens are evenly partitioned into multiple local groups and the expert capacity is enforced on the group level. Auxiliary loss: The motivation is similar to the original MoE aux loss. They add an auxiliary loss to minimize the mean square of the fraction of data routed to each expert. Random routing: The 2nd-best expert is selected with a probability proportional to its weight; otherwise, GShard follows a random routing, so as to add some randomness. Fig. 12. Pseudo code of the group-level top-2 gating mechanism with auxiliary loss in GShard. (Image source: Lepikhin et al., 2020) Switch Transformer (Fedus et al. 2021) scales the model size up to trillions of parameters (!!) by replacing the dense feed forward layer with a sparse switch FFN layer in which each input is only routed to one expert network. The auxiliary loss for load balancing is $\\text{loss}_\\text{aux} = w_\\text{aux} \\sum_{i=1}^n f_i p_i$ given $n$ experts, where $f_i$ is the fraction of tokens routed to the $i$-th expert and $p_i$ is the routing probability for expert $i$ predicted by the gating network.\nFig. 13. Switch transformer. The sparse switch FFN layer is in the blue boxes. (Image source: Fedus et al. 2021) To improve training stability, switch transformer incorporates the following designs:\n Selective precision. They showed that selectively casting only a local part of the model to FP32 precision improves stability, while avoiding the expensive communication cost of FP32 tensors. The FP32 precision is only used within the body of the router function and the results are recast to FP16. Smaller initialization. The initialization of weight matrices is sampled from a truncated normal distribution with mean $\\mu=0$ and stdev $\\sigma = \\sqrt{s/n}$. They also recommended reducing the transformer initialization scale parameter $s=1$ to $s=0.1$. Use higher expert dropout. Fine-tuning often works with a small dataset. To avoid overfitting, the dropout rate within each expert is increased by a significant amount. Interestingly they found that increasing dropout across all layers lead to poor performance. In the paper, they used a dropout rate 0.1 at non-expert layers but 0.4 within expert FF layers. The switch transformer paper summarized different data and model parallelism strategies for training large models with a nice illustration:\nFig. 14. An illustration of various parallelism strategies on how (Top) model weights and (Bottom) data are split over multiple GPU cores. In the top row, each color denotes a unique weight matrix. In the bottom row, different colors indicate different sets of tokens. (Image source: Fedus et al. 2021) Both GShard top-2 and Switch Transformer top-1 depend on token choice, where each token picks the best one or two experts to route through. They both adopt an auxiliary loss to encourage more balanced load allocation but it does not guarantee the best performance. Furthermore, the expert capacity limit may lead to wasted tokens as they would be discarded if an expert reaches its capacity limit.\nExport Choice (EC) (Zhou et al. 2022) routing instead enables each expert to select the top-$k$ tokens. In this way, each expert naturally guarantees a fixed capacity and each token may be routed to multiple experts. EC can achieve perfect load balancing and is shown to improve training convergence by 2x.\nGiven $e$ experts and an input matrix $X \\in \\mathbb{R}^{n \\times d}$, the token-to-expert affinity scores are computed by: $$ S = \\text{softmax}(X \\cdot W_g), \\text{where } W_g \\in \\mathbb{R}^{d \\times e}, S \\in \\mathbb{R}^{n \\times e} $$\nA token-to-expert assignment is represented by three matrices, $I, G \\in \\mathbb{R}^{e\\times k}$ and $P \\in \\mathbb{R}^{e \\times k \\times n}$. $I[i,j]$ annotates which token is the $j$-th selection by the $i$-th expert. The gating matrix $G$ stores the routing weights of selected tokens. $P$ is the one-hot version of $I$, used to produce the input matrix ($P \\cdot X \\in \\mathbb{R}^{e \\times k \\times d}$) for the gated FFN layer. $$ G, I = \\text{top-k}(S^\\top, k)\\quad P = \\text{one-hot}(I) $$\nOne regularization that export choice routing explored is to limit the maximum number of experts per token.\n $$ \\begin{aligned} \u0026 \\max_A \\langle S^\\top, A\\rangle + \\lambda H(A) \\\\ \\text{s.t.} \u0026 \\forall i: \\sum_{j'} A[i, j'] = k,\\quad \\forall j: \\sum_{i'} A[i', j] \\leq b,\\quad \\forall i,j: 0 \\leq A[i,j] \\leq 1 \\end{aligned} $$ where each entry $A[i,j]$ in $A \\in \\mathbb{R}^{e \\times n}$ marks whether the $i$-the expert selects the $j$-th token. Solving this is non-trivial. The paper used Dykstra\u0026rsquo;s algorithm that runs a sequence of multiple iterative computation steps. Capped expert choice results in a slight decrease in the fine-tuning performance in the experiments.\nThe parameter $k$ is determined by $k=nc/e$, where $n$ is the total number of tokens in one batch and $c$ is a capacity factor indicating the average number of experts used by one token. The paper used $c=2$ in most experiments, but EC with $c=1$ still outperforms the top-1 token choice gating. Interestingly, $c=0.5$ only marginally hurts the training performance.\nOne big drawback of EC is that it does not work when the batch size is too small, neither for auto-regressive text generation, because it needs to know the future tokens to do the top-$k$ selection.\nOther Memory Saving Designs CPU Offloading When the GPU memory is full, one option is to offload temporarily unused data to CPU and read them back when needed later (Rhu et al. 2016). The idea of CPU offloading is straightforward but is less popular in recent years due to the slowdown it brings into the training time.\nActivation Recomputation Activation recomputation (also known as “activation checkpointing” or “gradient checkpointing”; Chen et al. 2016) is a smart yet simple idea to reduce memory footprint at the cost of computation time. It reduces the memory cost of training a $\\ell$ layer deep neural net to $O(\\sqrt{\\ell})$, which only additionally consumes an extra forward pass computation per batch.\nLet\u0026rsquo;s say, we evenly divide an $\\ell$-layer network into $d$ partitions. Only activations at partition boundaries are saved and communicated between workers. Intermediate activations at intra-partition layers are still needed for computing gradients so they are recomputed during backward passes. With activation recomputation, the memory cost for training $M(\\ell)$ is:\n $$ M(\\ell) =\\max_{i=1,\\dots,k} \\underbrace{\\text{cost-of-one-partition}(i)}_\\text{cost of back-propagation on the i-th partition} + \\underbrace{O(d)}_\\text{store intermediate outputs} = O(\\frac{\\ell}{d}) + O(d) $$ The minimum cost is $O(\\sqrt{\\ell})$ at $d=\\sqrt{\\ell}$.\nActivation recompuation trick can give sublinear memory cost with respect to the model size.\nFig. 15. The memory cost of different memory saving algorithms. Sharing: Memory used by intermediate results is recycled when no longer needed. Inplace: Save the output directly into memory of an input value. (Image source: Chen et al. 2016) Mixed Precision Training Narang \u0026amp; Micikevicius et al. (2018) introduced a method to train models using half-precision floating point (FP16) numbers without losing model accuracy.\nFig. 16. The procedure of mixed precision training at one layer. (Image source: Narang \u0026 Micikevicius, et al. 2018) Three techniques to avoid losing critical information at half-precision:\n Full-precision master copy of weights. Maintain a full precision (FP32) copy of model weights that accumulates gradients. The numbers are rounded up to half-precision for forward \u0026amp; backward passes. The motivation is that each gradient update (i.e. gradient times the learning rate) might be too small to be fully contained within the FP16 range (i.e. $2^{-24}$ becomes zero in FP16). Loss scaling. Scale up the loss to better handle gradients with small magnitudes (See Fig. 16). Scaling up the gradients helps shift them to occupy a larger section towards the right section (containing larger values) of the representable range, preserving values that are otherwise lost. Arithmetic precision. For common network arithmetic (e.g. vector dot-product, reduction by summing up vector elements), we can accumulate the partial results in FP32 and then save the final output as FP16 before saving into memory. Point-wise operations can be executed in either FP16 or FP32. Fig. 17. The histogram of gradients in full precision. The left part up to $2^{-24}$ will be zero-ed off once the model switches to FP16. (Image source: Narang \u0026 Micikevicius, et al. 2018) In their experiments, loss scaling is not needed for some networks (e.g. image classification, Faster R-CNN), but necessary for others (e.g. Multibox SSD, big LSTM language model).\nCompression Intermediate results often consume a lot of memory, although they are only needed in one forward pass and one backward pass. There is a noticeable temporal gap between these two uses. Thus Jain et al. (2018) proposed a data encoding strategy to compress the intermediate results after the first use in the first pass and then decode it back for back-propagation later.\nTheir system Gist incorporates two encoding schemes: Layer-specific lossless encoding; focus on ReLU-Pool (“Binarize”) and ReLU-Conv (“Sparse storage and dense computation”) patterns. Aggressive lossy encoding; use delayed precision reduction (DPR). They observed that the first immediate use of feature maps should be kept at high precision but the second use can tolerate lower precision.\nThe experiments showed that Gist can reduce the memory cost by 2x across 5 SOTA image classification DNNs, with an average of 1.8x with only 4% performance overhead.\nMemory Efficient Optimizer Optimizers are eager for memory consumption. Take the popular Adam optimizer as an example, it internally needs to maintain momentums and variances, both at the same scale as gradients and model parameters. All out of a sudden, we need to save 4x the memory of model weights.\nSeveral optimizers have been proposed to reduce the memory footprint. For example, instead of storing the full momentums and variations as in Adam, Adafactor (Shazeer et al. 2018) only tracks the per-row and per-column sums of the moving averages and then estimates the second moments based on these sums. SM3 (Anil et al. 2019) describes a different adaptive optimization method, leading to largely reduced memory as well.\nZeRO (Zero Redundancy Optimizer; Rajbhandari et al. 2019) optimizes the memory used for training large models based on the observation about two major memory consumption of large model training:\n The majority is occupied by model states, including optimizer states (e.g. Adam momentums and variances), gradients and parameters. Mixed-precision training demands a lot of memory since the optimizer needs to keep a copy of FP32 parameters and other optimizer states, besides the FP16 version. The remaining is consumed by activations, temporary buffers and unusable fragmented memory (named residual states in the paper). ZeRO combines two approaches, ZeRO-DP and ZeRO-R. ZeRO-DP is an enhanced data parallelism to avoid simple redundancy over model states. It partitions optimizer state, gradients and parameters across multiple data parallel processes via a dynamic communication schedule to minimize the communication volume. ZeRO-R optimizes the memory consumption of residual states, using partitioned activation recomputation, constant buffer size and on-the-fly memory defragmentation.\nCitation Cited as:\n Weng, Lilian. (Sep 2021). How to train really large models on many GPUs? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-09-25-train-large/.\n Or\n@article{weng2021large, title = \u0026quot;How to Train Really Large Models on Many GPUs?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Sep\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-09-25-train-large/\u0026quot; } References [1] Li et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” VLDB 2020.\n[2] Cui et al. “GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server” EuroSys 2016\n[3] Shoeybi et al. “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv preprint arXiv:1909.08053 (2019).\n[4] Narayanan et al. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” arXiv preprint arXiv:2104.04473 (2021).\n[5] Huang et al. “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.” arXiv preprint arXiv:1811.06965 (2018).\n[6] Narayanan et al. \u0026ldquo;PipeDream: Generalized Pipeline Parallelism for DNN Training.\u0026quot; SOSP 2019.\n[7] Narayanan et al. “Memory-Efficient Pipeline-Parallel DNN Training.” ICML 2021.\n[8] Shazeer et al. “The Sparsely-Gated Mixture-of-Experts Layer Noam.” arXiv preprint arXiv:1701.06538 (2017).\n[9] Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” arXiv preprint arXiv:2006.16668 (2020).\n[10] Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” arXiv preprint arXiv:2101.03961 (2021).\n[11] Narang \u0026amp; Micikevicius, et al. “Mixed precision training.” ICLR 2018.\n[12] Chen et al. 2016 “Training Deep Nets with Sublinear Memory Cost.” arXiv preprint arXiv:1604.06174 (2016).\n[13] Jain et al. “Gist: Efficient data encoding for deep neural network training.” ISCA 2018.\n[14] Shazeer \u0026amp; Stern. “Adafactor: Adaptive learning rates with sublinear memory cost.” arXiv preprint arXiv:1804.04235 (2018).\n[15] Anil et al. “Memory-Efficient Adaptive Optimization.” arXiv preprint arXiv:1901.11150 (2019).\n[16] Rajbhandari et al. “ZeRO: Memory Optimization Towards Training A Trillion Parameter Models Samyam.” arXiv preprint arXiv:1910.02054 (2019).\n[17] Zhou et al. “Mixture-of-Experts with Expert Choice Routing” arXiv preprint arXiv:2202.09368 (2022).\n","permalink":"https://lilianweng.github.io/posts/2021-09-25-train-large/","summary":"[Updated on 2022-03-13: add expert choice routing.] [Updated on 2022-06-10]: Greg and I wrote a shorted and upgraded version of this post, published on OpenAI Blog: \u0026ldquo;Techniques for Training Large Neural Networks\u0026rdquo;\nIn recent years, we are seeing better results on many NLP benchmark tasks with larger pre-trained language models. How to train large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time.","title":"How to Train Really Large Models on Many GPUs?"},{"content":"[Updated on 2021-09-19: Highly recommend this blog post on score-based generative modeling by Yang Song (author of several key papers in the references)]. [Updated on 2022-08-27: Added classifier-free guidance, GLIDE, unCLIP and Imagen. [Updated on 2022-08-31: Added latent diffusion model.\nSo far, I\u0026rsquo;ve written about three types of generative models, GAN, VAE, and Flow-based models. They have shown great success in generating high-quality samples, but each has some limitations of its own. GAN models are known for potentially unstable training and less diversity in generation due to their adversarial training nature. VAE relies on a surrogate loss. Flow models have to use specialized architectures to construct reversible transform.\nDiffusion models are inspired by non-equilibrium thermodynamics. They define a Markov chain of diffusion steps to slowly add random noise to data and then learn to reverse the diffusion process to construct desired data samples from the noise. Unlike VAE or flow models, diffusion models are learned with a fixed procedure and the latent variable has high dimensionality (same as the original data).\nFig. 1. Overview of different types of generative models. What are Diffusion Models? Several diffusion-based generative models have been proposed with similar ideas underneath, including diffusion probabilistic models (Sohl-Dickstein et al., 2015), noise-conditioned score network (NCSN; Yang \u0026amp; Ermon, 2019), and denoising diffusion probabilistic models (DDPM; Ho et al. 2020).\nForward diffusion process Given a data point sampled from a real data distribution $\\mathbf{x}_0 \\sim q(\\mathbf{x})$, let us define a forward diffusion process in which we add small amount of Gaussian noise to the sample in $T$ steps, producing a sequence of noisy samples $\\mathbf{x}_1, \\dots, \\mathbf{x}_T$. The step sizes are controlled by a variance schedule $\\{\\beta_t \\in (0, 1)\\}_{t=1}^T$.\n $$ q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}) = \\mathcal{N}(\\mathbf{x}_t; \\sqrt{1 - \\beta_t} \\mathbf{x}_{t-1}, \\beta_t\\mathbf{I}) \\quad q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0) = \\prod^T_{t=1} q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}) $$ The data sample $\\mathbf{x}_0$ gradually loses its distinguishable features as the step $t$ becomes larger. Eventually when $T \\to \\infty$, $\\mathbf{x}_T$ is equivalent to an isotropic Gaussian distribution.\nFig. 2. The Markov chain of forward (reverse) diffusion process of generating a sample by slowly adding (removing) noise. (Image source: Ho et al. 2020 with a few additional annotations) A nice property of the above process is that we can sample $\\mathbf{x}_t$ at any arbitrary time step $t$ in a closed form using reparameterization trick. Let $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{i=1}^t \\alpha_i$:\n $$ \\begin{aligned} \\mathbf{x}_t \u0026= \\sqrt{\\alpha_t}\\mathbf{x}_{t-1} + \\sqrt{1 - \\alpha_t}\\boldsymbol{\\epsilon}_{t-1} \u0026 \\text{ ;where } \\boldsymbol{\\epsilon}_{t-1}, \\boldsymbol{\\epsilon}_{t-2}, \\dots \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I}) \\\\ \u0026= \\sqrt{\\alpha_t \\alpha_{t-1}} \\mathbf{x}_{t-2} + \\sqrt{1 - \\alpha_t \\alpha_{t-1}} \\bar{\\boldsymbol{\\epsilon}}_{t-2} \u0026 \\text{ ;where } \\bar{\\boldsymbol{\\epsilon}}_{t-2} \\text{ merges two Gaussians (*).} \\\\ \u0026= \\dots \\\\ \u0026= \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon} \\\\ q(\\mathbf{x}_t \\vert \\mathbf{x}_0) \u0026= \\mathcal{N}(\\mathbf{x}_t; \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0, (1 - \\bar{\\alpha}_t)\\mathbf{I}) \\end{aligned} $$ (*) Recall that when we merge two Gaussians with different variance, $\\mathcal{N}(\\mathbf{0}, \\sigma_1^2\\mathbf{I})$ and $\\mathcal{N}(\\mathbf{0}, \\sigma_2^2\\mathbf{I})$, the new distribution is $\\mathcal{N}(\\mathbf{0}, (\\sigma_1^2 + \\sigma_2^2)\\mathbf{I})$. Here the merged standard deviation is $\\sqrt{(1 - \\alpha_t) + \\alpha_t (1-\\alpha_{t-1})} = \\sqrt{1 - \\alpha_t\\alpha_{t-1}}$.\nUsually, we can afford a larger update step when the sample gets noisier, so $\\beta_1 \u0026lt; \\beta_2 \u0026lt; \\dots \u0026lt; \\beta_T$ and therefore $\\bar{\\alpha}_1 \u0026gt; \\dots \u0026gt; \\bar{\\alpha}_T$.\nConnection with stochastic gradient Langevin dynamics Langevin dynamics is a concept from physics, developed for statistically modeling molecular systems. Combined with stochastic gradient descent, stochastic gradient Langevin dynamics (Welling \u0026amp; Teh 2011) can produce samples from a probability density $p(\\mathbf{x})$ using only the gradients $\\nabla_\\mathbf{x} \\log p(\\mathbf{x})$ in a Markov chain of updates:\n $$ \\mathbf{x}_t = \\mathbf{x}_{t-1} + \\frac{\\delta}{2} \\nabla_\\mathbf{x} \\log p(\\mathbf{x}_{t-1}) + \\sqrt{\\delta} \\boldsymbol{\\epsilon}_t ,\\quad\\text{where } \\boldsymbol{\\epsilon}_t \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I}) $$ where $\\delta$ is the step size. When $T \\to \\infty, \\epsilon \\to 0$, $\\mathbf{x}_T$ equals to the true probability density $p(\\mathbf{x})$.\nCompared to standard SGD, stochastic gradient Langevin dynamics injects Gaussian noise into the parameter updates to avoid collapses into local minima.\nReverse diffusion process If we can reverse the above process and sample from $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$, we will be able to recreate the true sample from a Gaussian noise input, $\\mathbf{x}_T \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I})$. Note that if $\\beta_t$ is small enough, $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$ will also be Gaussian. Unfortunately, we cannot easily estimate $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$ because it needs to use the entire dataset and therefore we need to learn a model $p_\\theta$ to approximate these conditional probabilities in order to run the reverse diffusion process.\n $$ p_\\theta(\\mathbf{x}_{0:T}) = p(\\mathbf{x}_T) \\prod^T_{t=1} p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) \\quad p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)) $$ Fig. 3. An example of training a diffusion model for modeling a 2D swiss roll data. (Image source: Sohl-Dickstein et al., 2015) It is noteworthy that the reverse conditional probability is tractable when conditioned on $\\mathbf{x}_0$:\n $$ q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\color{blue}{\\tilde{\\boldsymbol{\\mu}}}(\\mathbf{x}_t, \\mathbf{x}_0), \\color{red}{\\tilde{\\beta}_t} \\mathbf{I}) $$ Using Bayes' rule, we have:\n $$ \\begin{aligned} q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \u0026= q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}, \\mathbf{x}_0) \\frac{ q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_0) }{ q(\\mathbf{x}_t \\vert \\mathbf{x}_0) } \\\\ \u0026\\propto \\exp \\Big(-\\frac{1}{2} \\big(\\frac{(\\mathbf{x}_t - \\sqrt{\\alpha_t} \\mathbf{x}_{t-1})^2}{\\beta_t} + \\frac{(\\mathbf{x}_{t-1} - \\sqrt{\\bar{\\alpha}_{t-1}} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_{t-1}} - \\frac{(\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_t} \\big) \\Big) \\\\ \u0026= \\exp \\Big(-\\frac{1}{2} \\big(\\frac{\\mathbf{x}_t^2 - 2\\sqrt{\\alpha_t} \\mathbf{x}_t \\color{blue}{\\mathbf{x}_{t-1}} \\color{black}{+ \\alpha_t} \\color{red}{\\mathbf{x}_{t-1}^2} }{\\beta_t} + \\frac{ \\color{red}{\\mathbf{x}_{t-1}^2} \\color{black}{- 2 \\sqrt{\\bar{\\alpha}_{t-1}} \\mathbf{x}_0} \\color{blue}{\\mathbf{x}_{t-1}} \\color{black}{+ \\bar{\\alpha}_{t-1} \\mathbf{x}_0^2} }{1-\\bar{\\alpha}_{t-1}} - \\frac{(\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_t} \\big) \\Big) \\\\ \u0026= \\exp\\Big( -\\frac{1}{2} \\big( \\color{red}{(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}})} \\mathbf{x}_{t-1}^2 - \\color{blue}{(\\frac{2\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{2\\sqrt{\\bar{\\alpha}_{t-1}}}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0)} \\mathbf{x}_{t-1} \\color{black}{ + C(\\mathbf{x}_t, \\mathbf{x}_0) \\big) \\Big)} \\end{aligned} $$ where $C(\\mathbf{x}_t, \\mathbf{x}_0)$ is some function not involving $\\mathbf{x}_{t-1}$ and details are omitted. Following the standard Gaussian density function, the mean and variance can be parameterized as follows (recall that $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{i=1}^T \\alpha_i$):\n $$ \\begin{aligned} \\tilde{\\beta}_t \u0026= 1/(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}}) = 1/(\\frac{\\alpha_t - \\bar{\\alpha}_t + \\beta_t}{\\beta_t(1 - \\bar{\\alpha}_{t-1})}) = \\color{green}{\\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t} \\\\ \\tilde{\\boldsymbol{\\mu}}_t (\\mathbf{x}_t, \\mathbf{x}_0) \u0026= (\\frac{\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1} }}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0)/(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}}) \\\\ \u0026= (\\frac{\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1} }}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0) \\color{green}{\\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t} \\\\ \u0026= \\frac{\\sqrt{\\alpha_t}(1 - \\bar{\\alpha}_{t-1})}{1 - \\bar{\\alpha}_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1}}\\beta_t}{1 - \\bar{\\alpha}_t} \\mathbf{x}_0\\\\ \\end{aligned} $$ Thanks to the nice property, we can represent $\\mathbf{x}_0 = \\frac{1}{\\sqrt{\\bar{\\alpha}_t}}(\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t)$ and plug it into the above equation and obtain:\n $$ \\begin{aligned} \\tilde{\\boldsymbol{\\mu}}_t \u0026= \\frac{\\sqrt{\\alpha_t}(1 - \\bar{\\alpha}_{t-1})}{1 - \\bar{\\alpha}_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1}}\\beta_t}{1 - \\bar{\\alpha}_t} \\frac{1}{\\sqrt{\\bar{\\alpha}_t}}(\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t) \\\\ \u0026= \\color{cyan}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)} \\end{aligned} $$ As demonstrated in Fig. 2., such a setup is very similar to VAE and thus we can use the variational lower bound to optimize the negative log-likelihood.\n $$ \\begin{aligned} - \\log p_\\theta(\\mathbf{x}_0) \u0026\\leq - \\log p_\\theta(\\mathbf{x}_0) + D_\\text{KL}(q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0) \\| p_\\theta(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0) ) \\\\ \u0026= -\\log p_\\theta(\\mathbf{x}_0) + \\mathbb{E}_{\\mathbf{x}_{1:T}\\sim q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0)} \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T}) / p_\\theta(\\mathbf{x}_0)} \\Big] \\\\ \u0026= -\\log p_\\theta(\\mathbf{x}_0) + \\mathbb{E}_q \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} + \\log p_\\theta(\\mathbf{x}_0) \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ \\log \\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\\\ \\text{Let }L_\\text{VLB} \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\Big[ \\log \\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\geq - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log p_\\theta(\\mathbf{x}_0) \\end{aligned} $$ It is also straightforward to get the same result using Jensen\u0026rsquo;s inequality. Say we want to minimize the cross entropy as the learning objective,\n $$ \\begin{aligned} L_\\text{CE} \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log p_\\theta(\\mathbf{x}_0) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\int p_\\theta(\\mathbf{x}_{0:T}) d\\mathbf{x}_{1:T} \\Big) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\int q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0) \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} d\\mathbf{x}_{1:T} \\Big) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\mathbb{E}_{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0)} \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} \\Big) \\\\ \u0026\\leq - \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\log \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} \\\\ \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})}\\Big[\\log \\frac{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] = L_\\text{VLB} \\end{aligned} $$ To convert each term in the equation to be analytically computable, the objective can be further rewritten to be a combination of several KL-divergence and entropy terms (See the detailed step-by-step process in Appendix B in Sohl-Dickstein et al., 2015):\n $$ \\begin{aligned} L_\\text{VLB} \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ \\log\\frac{\\prod_{t=1}^T q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{ p_\\theta(\\mathbf{x}_T) \\prod_{t=1}^T p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t) } \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=1}^T \\log \\frac{q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\log\\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\Big( \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)}\\cdot \\frac{q(\\mathbf{x}_t \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_{t-1}\\vert\\mathbf{x}_0)} \\Big) + \\log \\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_t \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_0)} + \\log\\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\log\\frac{q(\\mathbf{x}_T \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)} + \\log \\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big]\\\\ \u0026= \\mathbb{E}_q \\Big[ \\log\\frac{q(\\mathbf{x}_T \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_T)} + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} - \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1) \\Big] \\\\ \u0026= \\mathbb{E}_q [\\underbrace{D_\\text{KL}(q(\\mathbf{x}_T \\vert \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_T))}_{L_T} + \\sum_{t=2}^T \\underbrace{D_\\text{KL}(q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t))}_{L_{t-1}} \\underbrace{- \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)}_{L_0} ] \\end{aligned} $$ Let\u0026rsquo;s label each component in the variational lower bound loss separately:\n $$ \\begin{aligned} L_\\text{VLB} \u0026= L_T + L_{T-1} + \\dots + L_0 \\\\ \\text{where } L_T \u0026= D_\\text{KL}(q(\\mathbf{x}_T \\vert \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_T)) \\\\ L_t \u0026= D_\\text{KL}(q(\\mathbf{x}_t \\vert \\mathbf{x}_{t+1}, \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_t \\vert\\mathbf{x}_{t+1})) \\text{ for }1 \\leq t \\leq T-1 \\\\ L_0 \u0026= - \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1) \\end{aligned} $$ Every KL term in $L_\\text{VLB}$ (except for $L_0$) compares two Gaussian distributions and therefore they can be computed in closed form. $L_T$ is constant and can be ignored during training because $q$ has no learnable parameters and $\\mathbf{x}_T$ is a Gaussian noise. Ho et al. 2020 models $L_0$ using a separate discrete decoder derived from $\\mathcal{N}(\\mathbf{x}_0; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_1, 1), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_1, 1))$.\nParameterization of $L_t$ for Training Loss Recall that we need to learn a neural network to approximate the conditioned probability distributions in the reverse diffusion process, $p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t))$. We would like to train $\\boldsymbol{\\mu}_\\theta$ to predict $\\tilde{\\boldsymbol{\\mu}}_t = \\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)$. Because $\\mathbf{x}_t$ is available as input at training time, we can reparameterize the Gaussian noise term instead to make it predict $\\boldsymbol{\\epsilon}_t$ from the input $\\mathbf{x}_t$ at time step $t$:\n $$ \\begin{aligned} \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t) \u0026= \\color{cyan}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big)} \\\\ \\text{Thus }\\mathbf{x}_{t-1} \u0026= \\mathcal{N}(\\mathbf{x}_{t-1}; \\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)) \\end{aligned} $$ The loss term $L_t$ is parameterized to minimize the difference from $\\tilde{\\boldsymbol{\\mu}}$ :\n $$ \\begin{aligned} L_t \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{1}{2 \\| \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) \\|^2_2} \\| \\color{blue}{\\tilde{\\boldsymbol{\\mu}}_t(\\mathbf{x}_t, \\mathbf{x}_0)} - \\color{green}{\\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t)} \\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{1}{2 \\|\\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\| \\color{blue}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)} - \\color{green}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) \\Big)} \\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{ (1 - \\alpha_t)^2 }{2 \\alpha_t (1 - \\bar{\\alpha}_t) \\| \\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)\\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{ (1 - \\alpha_t)^2 }{2 \\alpha_t (1 - \\bar{\\alpha}_t) \\| \\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t, t)\\|^2 \\Big] \\end{aligned} $$ Simplification Empirically, Ho et al. (2020) found that training the diffusion model works better with a simplified objective that ignores the weighting term:\n $$ \\begin{aligned} L_t^\\text{simple} \u0026= \\mathbb{E}_{t \\sim [1, T], \\mathbf{x}_0, \\boldsymbol{\\epsilon}_t} \\Big[\\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)\\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{t \\sim [1, T], \\mathbf{x}_0, \\boldsymbol{\\epsilon}_t} \\Big[\\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t, t)\\|^2 \\Big] \\end{aligned} $$ The final simple objective is:\n $$ L_\\text{simple} = L_t^\\text{simple} + C $$ where $C$ is a constant not depending on $\\theta$.\nFig. 4. The training and sampling algorithms in DDPM (Image source: Ho et al. 2020) Connection with noise-conditioned score networks (NCSN) Song \u0026amp; Ermon (2019) proposed a score-based generative modeling method where samples are produced via Langevin dynamics using gradients of the data distribution estimated with score matching. The score of each sample $\\mathbf{x}$\u0026rsquo;s density probability is defined as its gradient $\\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$. A score network $\\mathbf{s}_\\theta: \\mathbb{R}^D \\to \\mathbb{R}^D$ is trained to estimate it, $\\mathbf{s}_\\theta(\\mathbf{x}) \\approx \\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$.\nTo make it scalable with high-dimensional data in the deep learning setting, they proposed to use either denoising score matching (Vincent, 2011) or sliced score matching (use random projections; Song et al., 2019). Denosing score matching adds a pre-specified small noise to the data $q(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$ and estimates $q(\\tilde{\\mathbf{x}})$ with score matching.\nRecall that Langevin dynamics can sample data points from a probability density distribution using only the score $\\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$ in an iterative process.\nHowever, according to the manifold hypothesis, most of the data is expected to concentrate in a low dimensional manifold, even though the observed data might look only arbitrarily high-dimensional. It brings a negative effect on score estimation since the data points cannot cover the whole space. In regions where data density is low, the score estimation is less reliable. After adding a small Gaussian noise to make the perturbed data distribution cover the full space $\\mathbb{R}^D$, the training of the score estimator network becomes more stable. Song \u0026amp; Ermon (2019) improved it by perturbing the data with the noise of different levels and train a noise-conditioned score network to jointly estimate the scores of all the perturbed data at different noise levels.\nThe schedule of increasing noise levels resembles the forward diffusion process. If we use the diffusion process annotation, the score approximates $\\mathbf{s}_\\theta(\\mathbf{x}_t, t) \\approx \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t)$. Given a Gaussian distribution $\\mathbf{x} \\sim \\mathcal{N}(\\mathbf{\\mu}, \\sigma^2 \\mathbf{I})$, we can write the derivative of the logarithm of its density function as $\\nabla_{\\mathbf{x}}\\log p(\\mathbf{x}) = \\nabla_{\\mathbf{x}} \\Big(-\\frac{1}{2\\sigma^2}(\\mathbf{x} - \\boldsymbol{\\mu})^2 \\Big) = - \\frac{\\mathbf{x} - \\boldsymbol{\\mu}}{\\sigma^2} = - \\frac{\\boldsymbol{\\epsilon}}{\\sigma}$ where $\\boldsymbol{\\epsilon} \\sim \\mathcal{N}(\\boldsymbol{0}, \\mathbf{I})$. Recall that $q(\\mathbf{x}_t \\vert \\mathbf{x}_0) \\sim \\mathcal{N}(\\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0, (1 - \\bar{\\alpha}_t)\\mathbf{I})$ and therefore,\n $$ \\mathbf{s}_\\theta(\\mathbf{x}_t, t) \\approx \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) = \\mathbb{E}_{q(\\mathbf{x}_0)} [\\nabla_{\\mathbf{x}_t} q(\\mathbf{x}_t \\vert \\mathbf{x}_0)] = \\mathbb{E}_{q(\\mathbf{x}_0)} \\Big[ - \\frac{\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)}{\\sqrt{1 - \\bar{\\alpha}_t}} \\Big] = - \\frac{\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)}{\\sqrt{1 - \\bar{\\alpha}_t}} $$ Parameterization of $\\beta_t$ The forward variances are set to be a sequence of linearly increasing constants in Ho et al. (2020), from $\\beta_1=10^{-4}$ to $\\beta_T=0.02$. They are relatively small compared to the normalized image pixel values between $[-1, 1]$. Diffusion models in their experiments showed high-quality samples but still could not achieve competitive model log-likelihood as other generative models.\nNichol \u0026amp; Dhariwal (2021) proposed several improvement techniques to help diffusion models to obtain lower NLL. One of the improvements is to use a cosine-based variance schedule. The choice of the scheduling function can be arbitrary, as long as it provides a near-linear drop in the middle of the training process and subtle changes around $t=0$ and $t=T$.\n $$ \\beta_t = \\text{clip}(1-\\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}, 0.999) \\quad\\bar{\\alpha}_t = \\frac{f(t)}{f(0)}\\quad\\text{where }f(t)=\\cos\\Big(\\frac{t/T+s}{1+s}\\cdot\\frac{\\pi}{2}\\Big)^2 $$ where the small offset $s$ is to prevent $\\beta_t$ from being too small when close to $t=0$.\nFig. 5. Comparison of linear and cosine-based scheduling of $\\beta\\_t$ during training. (Image source: Nichol \u0026 Dhariwal, 2021) Parameterization of reverse process variance $\\boldsymbol{\\Sigma}_\\theta$ Ho et al. (2020) chose to fix $\\beta_t$ as constants instead of making them learnable and set $\\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) = \\sigma^2_t \\mathbf{I}$ , where $\\sigma_t$ is not learned but set to $\\beta_t$ or $\\tilde{\\beta}_t = \\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t$. Because they found that learning a diagonal variance $\\boldsymbol{\\Sigma}_\\theta$ leads to unstable training and poorer sample quality.\nNichol \u0026amp; Dhariwal (2021) proposed to learn $\\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)$ as an interpolation between $\\beta_t$ and $\\tilde{\\beta}_t$ by model predicting a mixing vector $\\mathbf{v}$ :\n $$ \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) = \\exp(\\mathbf{v} \\log \\beta_t + (1-\\mathbf{v}) \\log \\tilde{\\beta}_t) $$ However, the simple objective $L_\\text{simple}$ does not depend on $\\boldsymbol{\\Sigma}_\\theta$ . To add the dependency, they constructed a hybrid objective $L_\\text{hybrid} = L_\\text{simple} + \\lambda L_\\text{VLB}$ where $\\lambda=0.001$ is small and stop gradient on $\\boldsymbol{\\mu}_\\theta$ in the $L_\\text{VLB}$ term such that $L_\\text{VLB}$ only guides the learning of $\\boldsymbol{\\Sigma}_\\theta$. Empirically they observed that $L_\\text{VLB}$ is pretty challenging to optimize likely due to noisy gradients, so they proposed to use a time-averaging smoothed version of $L_\\text{VLB}$ with importance sampling.\nFig. 6. Comparison of negative log-likelihood of improved DDPM with other likelihood-based generative models. NLL is reported in the unit of bits/dim. (Image source: Nichol \u0026 Dhariwal, 2021) Speed up Diffusion Model Sampling It is very slow to generate a sample from DDPM by following the Markov chain of the reverse diffusion process, as $T$ can be up to one or a few thousand steps. One data point from Song et al. 2020: \u0026ldquo;For example, it takes around 20 hours to sample 50k images of size 32 × 32 from a DDPM, but less than a minute to do so from a GAN on an Nvidia 2080 Ti GPU.\u0026rdquo;\nOne simple way is to run a strided sampling schedule (Nichol \u0026amp; Dhariwal, 2021) by taking the sampling update every $\\lceil T/S \\rceil$ steps to reduce the process from $T$ to $S$ steps. The new sampling schedule for generation is $\\{\\tau_1, \\dots, \\tau_S\\}$ where $\\tau_1 \u0026lt; \\tau_2 \u0026lt; \\dots \u0026lt;\\tau_S \\in [1, T]$ and $S \u0026lt; T$.\nFor another approach, let\u0026rsquo;s rewrite $q_\\sigma(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)$ to be parameterized by a desired standard deviation $\\sigma_t$ according to the nice property:\n $$ \\begin{aligned} \\mathbf{x}_{t-1} \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1}}\\boldsymbol{\\epsilon}_{t-1} \\\\ \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\boldsymbol{\\epsilon}_t + \\sigma_t\\boldsymbol{\\epsilon} \\\\ \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}} + \\sigma_t\\boldsymbol{\\epsilon} \\\\ q_\\sigma(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \u0026= \\mathcal{N}(\\mathbf{x}_{t-1}; \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}}, \\sigma_t^2 \\mathbf{I}) \\end{aligned} $$ Recall that in $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\tilde{\\boldsymbol{\\mu}}(\\mathbf{x}_t, \\mathbf{x}_0), \\tilde{\\beta}_t \\mathbf{I})$, therefore we have:\n $$ \\tilde{\\beta}_t = \\sigma_t^2 = \\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t $$ Let $\\sigma_t^2 = \\eta \\cdot \\tilde{\\beta}_t$ such that we can adjust $\\eta \\in \\mathbb{R}^+$ as a hyperparameter to control the sampling stochasticity. The special case of $\\eta = 0$ makes the sampling process deterministic. Such a model is named the denoising diffusion implicit model (DDIM; Song et al., 2020). DDIM has the same marginal noise distribution but deterministically maps noise back to the original data samples.\nDuring generation, we only sample a subset of $S$ diffusion steps $\\{\\tau_1, \\dots, \\tau_S\\}$ and the inference process becomes:\n $$ q_{\\sigma, \\tau}(\\mathbf{x}_{\\tau_{i-1}} \\vert \\mathbf{x}_{\\tau_t}, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{\\tau_{i-1}}; \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_{\\tau_i} - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}}, \\sigma_t^2 \\mathbf{I}) $$ While all the models are trained with $T=1000$ diffusion steps in the experiments, they observed that DDIM ($\\eta=0$) can produce the best quality samples when $S$ is small, while DDPM ($\\eta=1$) performs much worse on small $S$. DDPM does perform better when we can afford to run the full reverse Markov diffusion steps ($S=T=1000$). With DDIM, it is possible to train the diffusion model up to any arbitrary number of forward steps but only sample from a subset of steps in the generative process.\nFig. 7. FID scores on CIFAR10 and CelebA datasets by diffusion models of different settings, including $\\color{cyan}{\\text{DDIM}}$ ($\\eta=0$) and $\\color{orange}{\\text{DDPM}}$ ($\\hat{\\sigma}$). (Image source: Song et al., 2020) Compared to DDPM, DDIM is able to:\n Generate higher-quality samples using a much fewer number of steps. Have \u0026ldquo;consistency\u0026rdquo; property since the generative process is deterministic, meaning that multiple samples conditioned on the same latent variable should have similar high-level features. Because of the consistency, DDIM can do semantically meaningful interpolation in the latent variable. Latent diffusion model (LDM; Rombach \u0026amp; Blattmann, et al. 2022) runs the diffusion process in the latent space instead of pixel space, making training cost lower and inference speed faster. It is motivated by the observation that most bits of an image contribute to perceptual details and the semantic and conceptual composition still remains after aggressive compression. LDM loosely decomposes the perceptual compression and semantic compression with generative modeling learning by first trimming off pixel-level redundancy with autoencoder and then manipulate/generate semantic concepts with diffusion process on learned latent.\nFig. 8. The plot for tradeoff between compression rate and distortion, illustrating two-stage compressions - perceptural and semantic comparession. (Image source: Rombach \u0026 Blattmann, et al. 2022) The perceptual compression process relies on an autoencoder model. An encoder $\\mathcal{E}$ is used to compress the input image $\\mathbf{x} \\in \\mathbb{R}^{H \\times W \\times 3}$ to a smaller 2D latent vector $\\mathbf{z} = \\mathcal{E}(\\mathbf{x}) \\in \\mathbb{R}^{h \\times w \\times c}$ , where the downsampling rate $f=H/h=W/w=2^m, m \\in \\mathbb{N}$. Then an decoder $\\mathcal{D}$ reconstructs the images from the latent vector, $\\tilde{\\mathbf{x}} = \\mathcal{D}(\\mathbf{z})$. The paper explored two types of regularization in autoencoder training to avoid arbitrarily high-variance in the latent spaces.\n KL-reg: A small KL penalty towards a standard normal distribution over the learned latent, similar to VAE. VQ-reg: Uses a vector quantization layer within the decoder, like VQVAE but the quantization layer is absorbed by the decoder. The diffusion and denoising processes happen on the latent vector $\\mathbf{z}$. The denoising model is a time-conditioned U-Net, augmented with the cross-attention mechanism to handle flexible conditioning information for image generation (e.g. class labels, semantic maps, blurred variants of an image). The design is equivalent to fuse representation of different modality into the model with cross-attention mechanism. Each type of conditioning information is paired with a domain-specific encoder $\\tau_\\theta$ to project the conditioning input $y$ to an intermediate representation that can be mapped into cross-attention component, $\\tau_\\theta(y) \\in \\mathbb{R}^{M \\times d_\\tau}$:\n $$ \\begin{aligned} \u0026\\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}\\Big(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d}}\\Big) \\cdot \\mathbf{V} \\\\ \u0026\\text{where }\\mathbf{Q} = \\mathbf{W}^{(i)}_Q \\cdot \\varphi_i(\\mathbf{z}_i),\\; \\mathbf{K} = \\mathbf{W}^{(i)}_K \\cdot \\tau_\\theta(y),\\; \\mathbf{V} = \\mathbf{W}^{(i)}_V \\cdot \\tau_\\theta(y) \\\\ \u0026\\text{and } \\mathbf{W}^{(i)}_Q \\in \\mathbb{R}^{d \\times d^i_\\epsilon},\\; \\mathbf{W}^{(i)}_K, \\mathbf{W}^{(i)}_V \\in \\mathbb{R}^{d \\times d_\\tau},\\; \\varphi_i(\\mathbf{z}_i) \\in \\mathbb{R}^{N \\times d^i_\\epsilon},\\; \\tau_\\theta(y) \\in \\mathbb{R}^{M \\times d_\\tau} \\end{aligned} $$ Fig. 9. The architecture of latent diffusion model. (Image source: Rombach \u0026 Blattmann, et al. 2022) Conditioned Generation While training generative models on images with conditioning information such as ImageNet dataset, it is common to generate samples conditioned on class labels or a piece of descriptive text.\nClassifier Guided Diffusion To explicit incorporate class information into the diffusion process, Dhariwal \u0026amp; Nichol (2021) trained a classifier $f_\\phi(y \\vert \\mathbf{x}_t, t)$ on noisy image $\\mathbf{x}_t$ and use gradients $\\nabla_\\mathbf{x} \\log f_\\phi(y \\vert \\mathbf{x}_t)$ to guide the diffusion sampling process toward the conditioning information $y$ (e.g. a target class label) by altering the noise prediction. Recall that $\\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) = - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)$ and we can write the score function for the joint distribution $q(\\mathbf{x}_t, y)$ as following,\n $$ \\begin{aligned} \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t, y) \u0026= \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) + \\nabla_{\\mathbf{x}_t} \\log q(y \\vert \\mathbf{x}_t) \\\\ \u0026\\approx - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) + \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) \\\\ \u0026= - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} (\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t)) \\end{aligned} $$ Thus, a new classifier-guided predictor $\\bar{\\boldsymbol{\\epsilon}}_\\theta$ would take the form as following,\n $$ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(x_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) $$ To control the strength of the classifier guidance, we can add a weight $w$ to the delta part,\n $$ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(x_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\; w \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) $$ The resulting ablated diffusion model (ADM) and the one with additional classifier guidance (ADM-G) are able to achieve better results than SOTA generative models (e.g. BigGAN).\nFig. 10. The algorithms use guidance from a classifier to run conditioned generation with DDPM and DDIM. (Image source: Dhariwal \u0026 Nichol, 2021]) Additionally with some modifications on the U-Net architecture, Dhariwal \u0026amp; Nichol (2021) showed performance better than GAN with diffusion models. The architecture modifications include larger model depth/width, more attention heads, multi-resolution attention, BigGAN residual blocks for up/downsampling, residual connection rescale by $1/\\sqrt{2}$ and adaptive group normalization (AdaGN).\nClassifier-Free Guidance Without an independent classifier $f_\\phi$, it is still possible to run conditional diffusion steps by incorporating the scores from a conditional and an unconditional diffusion model (Ho \u0026amp; Salimans, 2021). Let unconditional denoising diffusion model $p_\\theta(\\mathbf{x})$ parameterized through a score estimator $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)$ and the conditional model $p_\\theta(\\mathbf{x} \\vert y)$ parameterized through $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y)$. These two models can be learned via a single neural network. Precisely, a conditional diffusion model $p_\\theta(\\mathbf{x} \\vert y)$ is trained on paired data $(\\mathbf{x}, y)$, where the conditioning information $y$ gets discarded periodically at random such that the model knows how to generate images unconditionally as well, i.e. $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y=\\varnothing)$.\nThe gradient of an implicit classifier can be represented with conditional and unconditional score estimators. Once plugged into the classifier-guided modified score, the score contains no dependency on a separate classifier.\n $$ \\begin{aligned} \\nabla_{\\mathbf{x}_t} \\log p(y \\vert \\mathbf{x}_t) \u0026= \\nabla_{\\mathbf{x}_t} \\log p(\\mathbf{x}_t \\vert y) - \\nabla_{\\mathbf{x}_t} \\log p(\\mathbf{x}_t) \\\\ \u0026= - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}}\\Big( \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big) \\\\ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t, y) \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\sqrt{1 - \\bar{\\alpha}_t} \\; w \\nabla_{\\mathbf{x}_t} \\log p(y \\vert \\mathbf{x}_t) \\\\ \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) + w \\big(\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\big) \\\\ \u0026= (w+1) \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - w \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\end{aligned} $$ Their experiments showed that classifier-free guidance can achieve a good balance between FID (distinguish between synthetic and generated images) and IS (quality and diversity).\n$$ q(\\mathbf{x}_t \\vert y) q(y \\vert \\mathbf{x}_t)^w \\propto \\frac{q(y\\vert \\mathbf{x}_t) q(\\mathbf{x}_t)}{q(y)} q(y \\vert \\mathbf{x}_t)^w \\propto q(\\mathbf{x}_t) q(y \\vert \\mathbf{x}_t)^{w+1} $$ Therefore, the classifier-guided noise prediction can be rewritten as $$ \\begin{aligned} \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} (w+1) \\nabla_{x_t} \\log f_\\phi(y\\vert \\mathbf{x}_t) \\\\ \u0026 \\approx - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} [\\log p(\\mathbf{x}_t) + (w+1) \\log f_\\phi (y \\vert \\mathbf{x}_t)] \\\\ \u0026 = - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} [\\log p(\\mathbf{x}_t \\vert y) + w \\log p_\\phi (y \\vert \\mathbf{x}_t)] \\end{aligned} $$ -- The guided diffusion model, GLIDE (Nichol, Dhariwal \u0026amp; Ramesh, et al. 2022), explored both guiding strategies, CLIP guidance and classifier-free guidance, and found that the latter is more preferred. They hypothesized that it is because CLIP guidance exploits the model with adversarial examples towards the CLIP model, rather than optimize the better matched images generation.\nScale up Generation Resolution and Quality To generate high-quality images at high resolution, Ho et al. (2021) proposed to use a pipeline of multiple diffusion models at increasing resolutions. Noise conditioning augmentation between pipeline models is crucial to the final image quality, which is to apply strong data augmentation to the conditioning input $\\mathbf{z}$ of each super-resolution model $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$. The conditioning noise helps reduce compounding error in the pipeline setup. U-net is a common choice of model architecture in diffusion modeling for high-resolution image generation.\nFig. 11. A cascaded pipeline of multiple diffusion models at increasing resolutions. (Image source: Ho et al. 2021]) They found the most effective noise is to apply Gaussian noise at low resolution and Gaussian blur at high resolution. In addition, they also explored two forms of conditioning augmentation that require small modification to the training process. Note that conditioning noise is only applied to training but not at inference.\n Truncated conditioning augmentation stops the diffusion process early at step $t \u0026gt; 0$ for low resolution. Non-truncated conditioning augmentation runs the full low resolution reverse process until step 0 but then corrupt it by $\\mathbf{z}_t \\sim q(\\mathbf{x}_t \\vert \\mathbf{x}_0)$ and then feeds the corrupted $\\mathbf{z}_t$ s into the super-resolution model. The two-stage diffusion model unCLIP (Ramesh et al. 2022) heavily utilizes the CLIP text encoder to produce text-guided images at high quality. Given a pretrained CLIP model $\\mathbf{c}$ and paired training data for the diffusion model, $(\\mathbf{x}, y)$, where $x$ is an image and $y$ is the corresponding caption, we can compute the CLIP text and image embedding, $\\mathbf{c}^t(y)$ and $\\mathbf{c}^i(\\mathbf{x})$, respectively. The unCLIP learns two models in parallel:\n A prior model $P(\\mathbf{c}^i \\vert y)$: outputs CLIP image embedding $\\mathbf{c}^i$ given the text $y$. A decoder $P(\\mathbf{x} \\vert \\mathbf{c}^i, [y])$: generates the image $\\mathbf{x}$ given CLIP image embedding $\\mathbf{c}^i$ and optionally the original text $y$. These two models enable conditional generation, because\n $$ \\underbrace{P(\\mathbf{x} \\vert y) = P(\\mathbf{x}, \\mathbf{c}^i \\vert y)}_{\\mathbf{c}^i\\text{ is deterministic given }\\mathbf{x}} = P(\\mathbf{x} \\vert \\mathbf{c}^i, y)P(\\mathbf{c}^i \\vert y) $$ Fig. 12. The architecture of unCLIP. (Image source: Ramesh et al. 2022]) unCLIP follows a two-stage image generation process:\n Given a text $y$, a CLIP model is first used to generate a text embedding $\\mathbf{c}^t(y)$. Using CLIP latent space enables zero-shot image manipulation via text. A diffusion or autoregressive prior $P(\\mathbf{c}^i \\vert y)$ processes this CLIP text embedding to construct an image prior and then a diffusion decoder $P(\\mathbf{x} \\vert \\mathbf{c}^i, [y])$ generates an image, conditioned on the prior. This decoder can also generate image variations conditioned on an image input, preserving its style and semantics. Instead of CLIP model, Imagen (Saharia et al. 2022) uses a pre-trained large LM (i.e. a frozen T5-XXL text encoder) to encode text for image generation. There is a general trend that larger model size can lead to better image quality and text-image alignment. They found that T5-XXL and CLIP text encoder achieve similar performance on MS-COCO, but human evaluation prefers T5-XXL on DrawBench (a collection of prompts covering 11 categories).\nWhen applying classifier-free guidance, increasing $w$ may lead to better image-text alignment but worse image fidelity. They found that it is due to train-test mismatch, that is saying, because training data $\\mathbf{x}$ stays within the range $[-1, 1]$, the test data should be so too. Two thresholding strategies are introduced:\n Static thresholding: clip $\\mathbf{x}$ prediction to $[-1, 1]$ Dynamic thresholding: at each sampling step, compute $s$ as a certain percentile absolute pixel value; if $s \u0026gt; 1$, clip the prediction to $[-s, s]$ and divide by $s$. Imagen modifies several designs in U-net to make it efficient U-Net.\n Shift model parameters from high resolution blocks to low resolution by adding more residual locks for the lower resolutions; Scale the skip connections by $1/\\sqrt{2}$ Reverse the order of downsampling (move it before convolutions) and upsampling operations (move it after convolution) in order to improve the speed of forward pass. They found that noise conditioning augmentation, dynamic thresholding and efficient U-Net are critical for image quality, but scaling text encoder size is more important than U-Net size.\nQuick Summary Pros: Tractability and flexibility are two conflicting objectives in generative modeling. Tractable models can be analytically evaluated and cheaply fit data (e.g. via a Gaussian or Laplace), but they cannot easily describe the structure in rich datasets. Flexible models can fit arbitrary structures in data, but evaluating, training, or sampling from these models is usually expensive. Diffusion models are both analytically tractable and flexible\n Cons: Diffusion models rely on a long Markov chain of diffusion steps to generate samples, so it can be quite expensive in terms of time and compute. New methods have been proposed to make the process much faster, but the sampling is still slower than GAN.\n Citation Cited as:\n Weng, Lilian. (Jul 2021). What are diffusion models? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/.\n Or\n@article{weng2021diffusion, title = \u0026quot;What are diffusion models?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Jul\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-07-11-diffusion-models/\u0026quot; } References [1] Jascha Sohl-Dickstein et al. “Deep Unsupervised Learning using Nonequilibrium Thermodynamics.” ICML 2015.\n[2] Max Welling \u0026amp; Yee Whye Teh. “Bayesian learning via stochastic gradient langevin dynamics.” ICML 2011.\n[3] Yang Song \u0026amp; Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.” NeurIPS 2019.\n[4] Yang Song \u0026amp; Stefano Ermon. “Improved techniques for training score-based generative models.” NeuriPS 2020.\n[5] Jonathan Ho et al. “Denoising diffusion probabilistic models.” arxiv Preprint arxiv:2006.11239 (2020). [code]\n[6] Jiaming Song et al. “Denoising diffusion implicit models.” arxiv Preprint arxiv:2010.02502 (2020). [code]\n[7] Alex Nichol \u0026amp; Prafulla Dhariwal. “Improved denoising diffusion probabilistic models” arxiv Preprint arxiv:2102.09672 (2021). [code]\n[8] Prafula Dhariwal \u0026amp; Alex Nichol. \u0026ldquo;Diffusion Models Beat GANs on Image Synthesis.\u0026quot; arxiv Preprint arxiv:2105.05233 (2021). [code]\n[9] Jonathan Ho \u0026amp; Tim Salimans. \u0026ldquo;Classifier-Free Diffusion Guidance.\u0026quot; NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications.\n[10] Yang Song, et al. \u0026ldquo;Score-Based Generative Modeling through Stochastic Differential Equations.\u0026quot; ICLR 2021.\n[11] Alex Nichol, Prafulla Dhariwal \u0026amp; Aditya Ramesh, et al. \u0026ldquo;GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models.\u0026quot; ICML 2022.\n[12] Jonathan Ho, et al. \u0026ldquo;Cascaded diffusion models for high fidelity image generation.\u0026quot; J. Mach. Learn. Res. 23 (2022): 47-1.\n[13] Aditya Ramesh et al. \u0026ldquo;Hierarchical Text-Conditional Image Generation with CLIP Latents.\u0026quot; arxiv Preprint arxiv:2204.06125 (2022).\n[14] Chitwan Saharia \u0026amp; William Chan, et al. \u0026ldquo;Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding.\u0026quot; arxiv Preprint arxiv:2205.11487 (2022).\n[15] Rombach \u0026amp; Blattmann, et al. \u0026ldquo;High-Resolution Image Synthesis with Latent Diffusion Models.\u0026quot; CVPR 2022.code\n","permalink":"https://lilianweng.github.io/posts/2021-07-11-diffusion-models/","summary":"[Updated on 2021-09-19: Highly recommend this blog post on score-based generative modeling by Yang Song (author of several key papers in the references)]. [Updated on 2022-08-27: Added classifier-free guidance, GLIDE, unCLIP and Imagen. [Updated on 2022-08-31: Added latent diffusion model.\nSo far, I\u0026rsquo;ve written about three types of generative models, GAN, VAE, and Flow-based models. They have shown great success in generating high-quality samples, but each has some limitations of its own.","title":"What are Diffusion Models?"},{"content":"The goal of contrastive representation learning is to learn such an embedding space in which similar sample pairs stay close to each other while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised settings. When working with unsupervised data, contrastive learning is one of the most powerful approaches in self-supervised learning.\nContrastive Training Objectives In early versions of loss functions for contrastive learning, only one positive and one negative sample are involved. The trend in recent training objectives is to include multiple positive and negative pairs in one batch.\nContrastive Loss Contrastive loss (Chopra et al. 2005) is one of the earliest training objectives used for deep metric learning in a contrastive fashion.\nGiven a list of input samples $\\{ \\mathbf{x}_i \\}$, each has a corresponding label $y_i \\in \\{1, \\dots, L\\}$ among $L$ classes. We would like to learn a function $f_\\theta(.): \\mathcal{X}\\to\\mathbb{R}^d$ that encodes $x_i$ into an embedding vector such that examples from the same class have similar embeddings and samples from different classes have very different ones. Thus, contrastive loss takes a pair of inputs $(x_i, x_j)$ and minimizes the embedding distance when they are from the same class but maximizes the distance otherwise.\n $$ \\mathcal{L}_\\text{cont}(\\mathbf{x}_i, \\mathbf{x}_j, \\theta) = \\mathbb{1}[y_i=y_j] \\| f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\|^2_2 + \\mathbb{1}[y_i\\neq y_j]\\max(0, \\epsilon - \\|f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j)\\|_2)^2 $$ where $\\epsilon$ is a hyperparameter, defining the lower bound distance between samples of different classes.\nTriplet Loss Triplet loss was originally proposed in the FaceNet (Schroff et al. 2015) paper and was used to learn face recognition of the same person at different poses and angles.\nFig. 1. Illustration of triplet loss given one positive and one negative per anchor. (Image source: Schroff et al. 2015) Given one anchor input $\\mathbf{x}$, we select one positive sample $\\mathbf{x}^+$ and one negative $\\mathbf{x}^-$, meaning that $\\mathbf{x}^+$ and $\\mathbf{x}$ belong to the same class and $\\mathbf{x}^-$ is sampled from another different class. Triplet loss learns to minimize the distance between the anchor $\\mathbf{x}$ and positive $\\mathbf{x}^+$ and maximize the distance between the anchor $\\mathbf{x}$ and negative $\\mathbf{x}^-$ at the same time with the following equation:\n $$ \\mathcal{L}_\\text{triplet}(\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-) = \\sum_{\\mathbf{x} \\in \\mathcal{X}} \\max\\big( 0, \\|f(\\mathbf{x}) - f(\\mathbf{x}^+)\\|^2_2 - \\|f(\\mathbf{x}) - f(\\mathbf{x}^-)\\|^2_2 + \\epsilon \\big) $$ where the margin parameter $\\epsilon$ is configured as the minimum offset between distances of similar vs dissimilar pairs.\nIt is crucial to select challenging $\\mathbf{x}^-$ to truly improve the model.\nLifted Structured Loss Lifted Structured Loss (Song et al. 2015) utilizes all the pairwise edges within one training batch for better computational efficiency.\nFig. 2. Illustration compares contrastive loss, triplet loss and lifted structured loss. Red and blue edges connect similar and dissimilar sample pairs respectively. (Image source: Song et al. 2015) Let $D_{ij} = | f(\\mathbf{x}_i) - f(\\mathbf{x}_j) |_2$, a structured loss function is defined as\n $$ \\begin{aligned} \\mathcal{L}_\\text{struct} \u0026= \\frac{1}{2\\vert \\mathcal{P} \\vert} \\sum_{(i,j) \\in \\mathcal{P}} \\max(0, \\mathcal{L}_\\text{struct}^{(ij)})^2 \\\\ \\text{where } \\mathcal{L}_\\text{struct}^{(ij)} \u0026= D_{ij} + \\color{red}{\\max \\big( \\max_{(i,k)\\in \\mathcal{N}} \\epsilon - D_{ik}, \\max_{(j,l)\\in \\mathcal{N}} \\epsilon - D_{jl} \\big)} \\end{aligned} $$ where $\\mathcal{P}$ contains the set of positive pairs and $\\mathcal{N}$ is the set of negative pairs. Note that the dense pairwise squared distance matrix can be easily computed per training batch.\nThe red part in $\\mathcal{L}_\\text{struct}^{(ij)}$ is used for mining hard negatives. However, it is not smooth and may cause the convergence to a bad local optimum in practice. Thus, it is relaxed to be:\n $$ \\mathcal{L}_\\text{struct}^{(ij)} = D_{ij} + \\log \\Big( \\sum_{(i,k)\\in\\mathcal{N}} \\exp(\\epsilon - D_{ik}) + \\sum_{(j,l)\\in\\mathcal{N}} \\exp(\\epsilon - D_{jl}) \\Big) $$ In the paper, they also proposed to enhance the quality of negative samples in each batch by actively incorporating difficult negative samples given a few random positive pairs.\nN-pair Loss Multi-Class N-pair loss (Sohn 2016) generalizes triplet loss to include comparison with multiple negative samples.\nGiven a $(N + 1)$-tuplet of training samples, $\\{ \\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-_1, \\dots, \\mathbf{x}^-_{N-1} \\}$, including one positive and $N-1$ negative ones, N-pair loss is defined as:\n $$ \\begin{aligned} \\mathcal{L}_\\text{N-pair}(\\mathbf{x}, \\mathbf{x}^+, \\{\\mathbf{x}^-_i\\}^{N-1}_{i=1}) \u0026= \\log\\big(1 + \\sum_{i=1}^{N-1} \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-_i) - f(\\mathbf{x})^\\top f(\\mathbf{x}^+))\\big) \\\\ \u0026= -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+))}{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+)) + \\sum_{i=1}^{N-1} \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-_i))} \\end{aligned} $$ If we only sample one negative sample per class, it is equivalent to the softmax loss for multi-class classification.\nNCE Noise Contrastive Estimation, short for NCE, is a method for estimating parameters of a statistical model, proposed by Gutmann \u0026amp; Hyvarinen in 2010. The idea is to run logistic regression to tell apart the target data from noise. Read more on how NCE is used for learning word embedding here.\nLet $\\mathbf{x}$ be the target sample $\\sim P(\\mathbf{x} \\vert C=1; \\theta) = p_\\theta(\\mathbf{x})$ and $\\tilde{\\mathbf{x}}$ be the noise sample $\\sim P(\\tilde{\\mathbf{x}} \\vert C=0) = q(\\tilde{\\mathbf{x}})$. Note that the logistic regression models the logit (i.e. log-odds) and in this case we would like to model the logit of a sample $u$ from the target data distribution instead of the noise distribution:\n $$ \\ell_\\theta(\\mathbf{u}) = \\log \\frac{p_\\theta(\\mathbf{u})}{q(\\mathbf{u})} = \\log p_\\theta(\\mathbf{u}) - \\log q(\\mathbf{u}) $$ After converting logits into probabilities with sigmoid $\\sigma(.)$, we can apply cross entropy loss:\n $$ \\begin{aligned} \\mathcal{L}_\\text{NCE} \u0026= - \\frac{1}{N} \\sum_{i=1}^N \\big[ \\log \\sigma (\\ell_\\theta(\\mathbf{x}_i)) + \\log (1 - \\sigma (\\ell_\\theta(\\tilde{\\mathbf{x}}_i))) \\big] \\\\ \\text{ where }\\sigma(\\ell) \u0026= \\frac{1}{1 + \\exp(-\\ell)} = \\frac{p_\\theta}{p_\\theta + q} \\end{aligned} $$ Here I listed the original form of NCE loss which works with only one positive and one noise sample. In many follow-up works, contrastive loss incorporating multiple negative samples is also broadly referred to as NCE.\nInfoNCE The InfoNCE loss in CPC (Contrastive Predictive Coding; van den Oord, et al. 2018), inspired by NCE, uses categorical cross-entropy loss to identify the positive sample amongst a set of unrelated noise samples.\nGiven a context vector $\\mathbf{c}$, the positive sample should be drawn from the conditional distribution $p(\\mathbf{x} \\vert \\mathbf{c})$, while $N-1$ negative samples are drawn from the proposal distribution $p(\\mathbf{x})$, independent from the context $\\mathbf{c}$. For brevity, let us label all the samples as $X=\\{ \\mathbf{x}_i \\}^N_{i=1}$ among which only one of them $\\mathbf{x}_\\texttt{pos}$ is a positive sample. The probability of we detecting the positive sample correctly is:\n $$ p(C=\\texttt{pos} \\vert X, \\mathbf{c}) = \\frac{p(x_\\texttt{pos} \\vert \\mathbf{c}) \\prod_{i=1,\\dots,N; i \\neq \\texttt{pos}} p(\\mathbf{x}_i)}{\\sum_{j=1}^N \\big[ p(\\mathbf{x}_j \\vert \\mathbf{c}) \\prod_{i=1,\\dots,N; i \\neq j} p(\\mathbf{x}_i) \\big]} = \\frac{ \\frac{p(\\mathbf{x}_\\texttt{pos}\\vert c)}{p(\\mathbf{x}_\\texttt{pos})} }{ \\sum_{j=1}^N \\frac{p(\\mathbf{x}_j\\vert \\mathbf{c})}{p(\\mathbf{x}_j)} } = \\frac{f(\\mathbf{x}_\\texttt{pos}, \\mathbf{c})}{ \\sum_{j=1}^N f(\\mathbf{x}_j, \\mathbf{c}) } $$ where the scoring function is $f(\\mathbf{x}, \\mathbf{c}) \\propto \\frac{p(\\mathbf{x}\\vert\\mathbf{c})}{p(\\mathbf{x})}$.\nThe InfoNCE loss optimizes the negative log probability of classifying the positive sample correctly:\n $$ \\mathcal{L}_\\text{InfoNCE} = - \\mathbb{E} \\Big[\\log \\frac{f(\\mathbf{x}, \\mathbf{c})}{\\sum_{\\mathbf{x}' \\in X} f(\\mathbf{x}', \\mathbf{c})} \\Big] $$ The fact that $f(x, c)$ estimates the density ratio $\\frac{p(x\\vert c)}{p(x)}$ has a connection with mutual information optimization. To maximize the the mutual information between input $x$ and context vector $c$, we have:\n $$ I(\\mathbf{x}; \\mathbf{c}) = \\sum_{\\mathbf{x}, \\mathbf{c}} p(\\mathbf{x}, \\mathbf{c}) \\log\\frac{p(\\mathbf{x}, \\mathbf{c})}{p(\\mathbf{x})p(\\mathbf{c})} = \\sum_{\\mathbf{x}, \\mathbf{c}} p(\\mathbf{x}, \\mathbf{c})\\log\\color{blue}{\\frac{p(\\mathbf{x}|\\mathbf{c})}{p(\\mathbf{x})}} $$ where the logarithmic term in blue is estimated by $f$.\nFor sequence prediction tasks, rather than modeling the future observations $p_k(\\mathbf{x}_{t+k} \\vert \\mathbf{c}_t)$ directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between $\\mathbf{x}_{t+k}$ and $\\mathbf{c}_t$:\n $$ f_k(\\mathbf{x}_{t+k}, \\mathbf{c}_t) = \\exp(\\mathbf{z}_{t+k}^\\top \\mathbf{W}_k \\mathbf{c}_t) \\propto \\frac{p(\\mathbf{x}_{t+k}\\vert\\mathbf{c}_t)}{p(\\mathbf{x}_{t+k})} $$ where $\\mathbf{z}_{t+k}$ is the encoded input and $\\mathbf{W}_k$ is a trainable weight matrix.\nSoft-Nearest Neighbors Loss Soft-Nearest Neighbors Loss (Salakhutdinov \u0026amp; Hinton 2007, Frosst et al. 2019) extends it to include multiple positive samples.\nGiven a batch of samples, $\\{\\mathbf{x}_i, y_i)\\}^B_{i=1}$ where $y_i$ is the class label of $\\mathbf{x}_i$ and a function $f(.,.)$ for measuring similarity between two inputs, the soft nearest neighbor loss at temperature $\\tau$ is defined as:\n $$ \\mathcal{L}_\\text{snn} = -\\frac{1}{B}\\sum_{i=1}^B \\log \\frac{\\sum_{i\\neq j, y_i = y_j, j=1,\\dots,B} \\exp(- f(\\mathbf{x}_i, \\mathbf{x}_j) / \\tau)}{\\sum_{i\\neq k, k=1,\\dots,B} \\exp(- f(\\mathbf{x}_i, \\mathbf{x}_k) /\\tau)} $$ The temperature $\\tau$ is used for tuning how concentrated the features are in the representation space. For example, when at low temperature, the loss is dominated by the small distances and widely separated representations cannot contribute much and become irrelevant.\nCommon Setup We can loosen the definition of \u0026ldquo;classes\u0026rdquo; and \u0026ldquo;labels\u0026rdquo; in soft nearest-neighbor loss to create positive and negative sample pairs out of unsupervised data by, for example, applying data augmentation to create noise versions of original samples.\nMost recent studies follow the following definition of contrastive learning objective to incorporate multiple positive and negative samples. According to the setup in (Wang \u0026amp; Isola 2020), let $p_\\texttt{data}(.)$ be the data distribution over $\\mathbb{R}^n$ and $p_\\texttt{pos}(., .)$ be the distribution of positive pairs over $\\mathbb{R}^{n \\times n}$. These two distributions should satisfy:\n Symmetry: $\\forall \\mathbf{x}, \\mathbf{x}^+, p_\\texttt{pos}(\\mathbf{x}, \\mathbf{x}^+) = p_\\texttt{pos}(\\mathbf{x}^+, \\mathbf{x})$ Matching marginal: $\\forall \\mathbf{x}, \\int p_\\texttt{pos}(\\mathbf{x}, \\mathbf{x}^+) d\\mathbf{x}^+ = p_\\texttt{data}(\\mathbf{x})$ To learn an encoder $f(\\mathbf{x})$ to learn a L2-normalized feature vector, the contrastive learning objective is:\n $$ \\begin{aligned} \\mathcal{L}_\\text{contrastive} \u0026= \\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}, \\{\\mathbf{x}^-_i\\}^M_{i=1} \\overset{\\text{i.i.d}}{\\sim} p_\\texttt{data} } \\Big[ -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau)}{ \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau) + \\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)} \\Big] \u0026 \\\\ \u0026\\approx \\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}, \\{\\mathbf{x}^-_i\\}^M_{i=1} \\overset{\\text{i.i.d}}{\\sim} p_\\texttt{data} }\\Big[ - f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau + \\log\\big(\\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)\\big) \\Big] \u0026 \\scriptstyle{\\text{; Assuming infinite negatives}} \\\\ \u0026= -\\frac{1}{\\tau}\\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}}f(\\mathbf{x})^\\top f(\\mathbf{x}^+) + \\mathbb{E}_{ \\mathbf{x} \\sim p_\\texttt{data}} \\Big[ \\log \\mathbb{E}_{\\mathbf{x}^- \\sim p_\\texttt{data}} \\big[ \\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)\\big] \\Big] \u0026 \\end{aligned} $$ Key Ingredients Heavy Data Augmentation Given a training sample, data augmentation techniques are needed for creating noise versions of itself to feed into the loss as positive samples. Proper data augmentation setup is critical for learning good and generalizable embedding features. It introduces the non-essential variations into examples without modifying semantic meanings and thus encourages the model to learn the essential part of the representation. For example, experiments in SimCLR showed that the composition of random cropping and random color distortion is crucial for good performance on learning visual representation of images.\nLarge Batch Size Using a large batch size during training is another key ingredient in the success of many contrastive learning methods (e.g. SimCLR, CLIP), especially when it relies on in-batch negatives. Only when the batch size is big enough, the loss function can cover a diverse enough collection of negative samples, challenging enough for the model to learn meaningful representation to distinguish different examples.\nHard Negative Mining Hard negative samples should have different labels from the anchor sample, but have embedding features very close to the anchor embedding. With access to ground truth labels in supervised datasets, it is easy to identify task-specific hard negatives. For example when learning sentence embedding, we can treat sentence pairs labelled as \u0026ldquo;contradiction\u0026rdquo; in NLI datasets as hard negative pairs (e.g. SimCSE, or use top incorrect candidates returned by BM25 with most keywords matched as hard negative samples (DPR; Karpukhin et al., 2020).\nHowever, it becomes tricky to do hard negative mining when we want to remain unsupervised. Increasing training batch size or memory bank size implicitly introduces more hard negative samples, but it leads to a heavy burden of large memory usage as a side effect.\nChuang et al. (2020) studied the sampling bias in contrastive learning and proposed debiased loss. In the unsupervised setting, since we do not know the ground truth labels, we may accidentally sample false negative samples. Sampling bias can lead to significant performance drop.\nFig. 3. Sampling bias which refers to false negative samples in contrastive learning can lead to a big performance drop. (Image source: Chuang et al., 2020) Let us assume the probability of anchor class $c$ is uniform $\\rho(c)=\\eta^+$ and the probability of observing a different class is $\\eta^- = 1-\\eta^+$.\n The probability of observing a positive example for $\\mathbf{x}$ is $p^+_x(\\mathbf{x}')=p(\\mathbf{x}'\\vert \\mathbf{h}_{x'}=\\mathbf{h}_x)$; The probability of getting a negative sample for $\\mathbf{x}$ is $p^-_x(\\mathbf{x}')=p(\\mathbf{x}'\\vert \\mathbf{h}_{x'}\\neq\\mathbf{h}_x)$. When we are sampling $\\mathbf{x}^-$ , we cannot access the true $p^-_x(\\mathbf{x}^-)$ and thus $\\mathbf{x}^-$ may be sampled from the (undesired) anchor class $c$ with probability $\\eta^+$. The actual sampling data distribution becomes:\n $$ p(\\mathbf{x}') = \\eta^+ p^+_x(\\mathbf{x}') + \\eta^- p_x^-(\\mathbf{x}') $$ Thus we can use $p^-_x(\\mathbf{x}') = (p(\\mathbf{x}') - \\eta^+ p^+_x(\\mathbf{x}'))/\\eta^-$ for sampling $\\mathbf{x}^-$ to debias the loss. With $N$ samples $\\{\\mathbf{u}_i\\}^N_{i=1}$ from $p$ and $M$ samples $\\{ \\mathbf{v}_i \\}_{i=1}^M$ from $p^+_x$ , we can estimate the expectation of the second term $\\mathbb{E}_{\\mathbf{x}^-\\sim p^-_x}[\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$ in the denominator of contrastive learning loss:\n $$ g(\\mathbf{x}, \\{\\mathbf{u}_i\\}^N_{i=1}, \\{\\mathbf{v}_i\\}_{i=1}^M) = \\max\\Big\\{ \\frac{1}{\\eta^-}\\Big( \\frac{1}{N}\\sum_{i=1}^N \\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}_i)) - \\frac{\\eta^+}{M}\\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}_i)) \\Big), \\exp(-1/\\tau) \\Big\\} $$ where $\\tau$ is the temperature and $\\exp(-1/\\tau)$ is the theoretical lower bound of $\\mathbb{E}_{\\mathbf{x}^-\\sim p^-_x}[\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$.\nThe final debiased contrastive loss looks like:\n $$ \\mathcal{L}^{N,M}_\\text{debias}(f) = \\mathbb{E}_{\\mathbf{x},\\{\\mathbf{u}_i\\}^N_{i=1}\\sim p;\\;\\mathbf{x}^+, \\{\\mathbf{v}_i\\}_{i=1}^M\\sim p^+} \\Big[ -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+)}{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) + N g(x,\\{\\mathbf{u}_i\\}^N_{i=1}, \\{\\mathbf{v}_i\\}_{i=1}^M)} \\Big] $$ Fig. 4. t-SNE visualization of learned representation with debiased contrastive learning. (Image source: Chuang et al., 2020) Following the above annotation, Robinson et al. (2021) modified the sampling probabilities to target at hard negatives by up-weighting the probability $p^-_x(x')$ to be proportional to its similarity to the anchor sample. The new sampling probability $q_\\beta(x^-)$ is:\n $$ q_\\beta(\\mathbf{x}^-) \\propto \\exp(\\beta f(\\mathbf{x})^\\top f(\\mathbf{x}^-)) \\cdot p(\\mathbf{x}^-) $$ where $\\beta$ is a hyperparameter to tune.\nWe can estimate the second term in the denominator $\\mathbb{E}_{\\mathbf{x}^- \\sim q_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$ using importance sampling where both the partition functions $Z_\\beta, Z^+_\\beta$ can be estimated empirically.\n $$ \\begin{aligned} \\mathbb{E}_{\\mathbf{u} \\sim q_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}))] \u0026= \\mathbb{E}_{\\mathbf{u} \\sim p} [\\frac{q_\\beta}{p}\\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}))] = \\mathbb{E}_{\\mathbf{u} \\sim p} [\\frac{1}{Z_\\beta}\\exp((\\beta + 1)f(\\mathbf{x})^\\top f(\\mathbf{u}))] \\\\ \\mathbb{E}_{\\mathbf{v} \\sim q^+_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}))] \u0026= \\mathbb{E}_{\\mathbf{v} \\sim p^+} [\\frac{q^+_\\beta}{p}\\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}))] = \\mathbb{E}_{\\mathbf{v} \\sim p} [\\frac{1}{Z^+_\\beta}\\exp((\\beta + 1)f(\\mathbf{x})^\\top f(\\mathbf{v}))] \\end{aligned} $$ Fig. 5. Pseudo code for computing NCE loss, debiased contrastive loss, and hard negative sample objective when setting $M=1$. (Image source: Robinson et al., 2021 ) Vision: Image Embedding Image Augmentations Most approaches for contrastive representation learning in the vision domain rely on creating a noise version of a sample by applying a sequence of data augmentation techniques. The augmentation should significantly change its visual appearance but keep the semantic meaning unchanged.\nBasic Image Augmentation There are many ways to modify an image while retaining its semantic meaning. We can use any one of the following augmentation or a composition of multiple operations.\n Random cropping and then resize back to the original size. Random color distortions Random Gaussian blur Random color jittering Random horizontal flip Random grayscale conversion Multi-crop augmentation: Use two standard resolution crops and sample a set of additional low resolution crops that cover only small parts of the image. Using low resolution crops reduces the compute cost. (SwAV) And many more \u0026hellip; Augmentation Strategies Many frameworks are designed for learning good data augmentation strategies (i.e. a composition of multiple transforms). Here are a few common ones.\n AutoAugment (Cubuk, et al. 2018): Inspired by NAS, AutoAugment frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem and looks for the combination that leads to the highest accuracy on the evaluation set. RandAugment (Cubuk et al., 2019): RandAugment greatly reduces the search space of AutoAugment by controlling the magnitudes of different transformation operations with a single magnitude parameter. PBA (Population based augmentation; Ho et al., 2019): PBA combined PBT (Jaderberg et al, 2017) with AutoAugment, using the evolutionary algorithm to train a population of children models in parallel to evolve the best augmentation strategies. UDA (Unsupervised Data Augmentation; Xie et al., 2019): Among a set of possible augmentation strategies, UDA selects those to minimize the KL divergence between the predicted distribution over an unlabelled example and its unlabelled augmented version. Image Mixture Image mixture methods can construct new training examples from existing data points.\n Mixup (Zhang et al., 2018): It runs global-level mixture by creating a weighted pixel-wise combination of two existing images $I_1$ and $I_2$: $I_\\text{mixup} \\gets \\alpha I_1 + (1-\\alpha) I_2$ and $\\alpha \\in [0, 1]$. Cutmix (Yun et al., 2019): Cutmix does region-level mixture by generating a new example by combining a local region of one image with the rest of the other image. $I_\\text{cutmix} \\gets \\mathbf{M}_b \\odot I_1 + (1-\\mathbf{M}_b) \\odot I_2$, where $\\mathbf{M}_b \\in \\{0, 1\\}^I$ is a binary mask and $\\odot$ is element-wise multiplication. It is equivalent to filling the cutout (DeVries \u0026amp; Taylor 2017) region with the same region from another image. MoCHi (\u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo;; Kalantidis et al. 2020): Given a query $\\mathbf{q}$, MoCHi maintains a queue of $K$ negative features $Q=\\{\\mathbf{n}_1, \\dots, \\mathbf{n}_K \\}$ and sorts these negative features by similarity to the query, $\\mathbf{q}^\\top \\mathbf{n}$, in descending order. The first $N$ items in the queue are considered as the hardest negatives, $Q^N$. Then synthetic hard examples can be generated by $\\mathbf{h} = \\tilde{\\mathbf{h}} / |\\tilde{\\mathbf{h}}|$ where $\\tilde{\\mathbf{h}} = \\alpha\\mathbf{n}_i + (1-\\alpha) \\mathbf{n}_j$ and $\\alpha \\in (0, 1)$. Even harder examples can be created by mixing with the query feature, $\\mathbf{h}' = \\tilde{\\mathbf{h}'} / |\\tilde{\\mathbf{h}'}|_2$ where $\\tilde{\\mathbf{h}'} = \\beta\\mathbf{q} + (1-\\beta) \\mathbf{n}_j$ and $\\beta \\in (0, 0.5)$. Parallel Augmentation This category of approaches produce two noise versions of one anchor image and aim to learn representation such that these two augmented samples share the same embedding.\nSimCLR SimCLR (Chen et al, 2020) proposed a simple framework for contrastive learning of visual representations. It learns representations for visual inputs by maximizing agreement between differently augmented views of the same sample via a contrastive loss in the latent space.\nFig. 6. A simple framework for contrastive learning of visual representations. (Image source: Chen et al, 2020) Randomly sample a minibatch of $N$ samples and each sample is applied with two different data augmentation operations, resulting in $2N$ augmented samples in total. $$ \\tilde{\\mathbf{x}}_i = t(\\mathbf{x}),\\quad\\tilde{\\mathbf{x}}_j = t'(\\mathbf{x}),\\quad t, t' \\sim \\mathcal{T} $$ where two separate data augmentation operators, $t$ and $t'$, are sampled from the same family of augmentations $\\mathcal{T}$. Data augmentation includes random crop, resize with random flip, color distortions, and Gaussian blur.\nGiven one positive pair, other $2(N-1)$ data points are treated as negative samples. The representation is produced by a base encoder $f(.)$: $$ \\mathbf{h}_i = f(\\tilde{\\mathbf{x}}_i),\\quad \\mathbf{h}_j = f(\\tilde{\\mathbf{x}}_j) $$ The contrastive learning loss is defined using cosine similarity $\\text{sim}(.,.)$. Note that the loss operates on an extra projection layer of the representation $g(.)$ rather than on the representation space directly. But only the representation $\\mathbf{h}$ is used for downstream tasks. $$ \\begin{aligned} \\mathbf{z}_i \u0026= g(\\mathbf{h}_i),\\quad \\mathbf{z}_j = g(\\mathbf{h}_j) \\\\ \\mathcal{L}_\\text{SimCLR}^{(i,j)} \u0026= - \\log\\frac{\\exp(\\text{sim}(\\mathbf{z}_i, \\mathbf{z}_j) / \\tau)}{\\sum_{k=1}^{2N} \\mathbb{1}_{[k \\neq i]} \\exp(\\text{sim}(\\mathbf{z}_i, \\mathbf{z}_k) / \\tau)} \\end{aligned} $$ where $\\mathbb{1}_{[k \\neq i]}$ is an indicator function: 1 if $k\\neq i$ 0 otherwise.\nSimCLR needs a large batch size to incorporate enough negative samples to achieve good performance.\nFig. 7. The algorithm for SimCLR. (Image source: Chen et al, 2020). Barlow Twins Barlow Twins (Zbontar et al. 2021) feeds two distorted versions of samples into the same network to extract features and learns to make the cross-correlation matrix between these two groups of output features close to the identity. The goal is to keep the representation vectors of different distorted versions of one sample similar, while minimizing the redundancy between these vectors.\nFig. 8. Illustration of Barlow Twins learning pipeline. (Image source: Zbontar et al. 2021). Let $\\mathcal{C}$ be a cross-correlation matrix computed between outputs from two identical networks along the batch dimension. $\\mathcal{C}$ is a square matrix with the size same as the feature network\u0026rsquo;s output dimensionality. Each entry in the matrix $\\mathcal{C}_{ij}$ is the cosine similarity between network output vector dimension at index $i, j$ and batch index $b$, $\\mathbf{z}_{b,i}^A$ and $\\mathbf{z}_{b,j}^B$, with a value between -1 (i.e. perfect anti-correlation) and 1 (i.e. perfect correlation).\n $$ \\begin{aligned} \\mathcal{L}_\\text{BT} \u0026= \\underbrace{\\sum_i (1-\\mathcal{C}_{ii})^2}_\\text{invariance term} + \\lambda \\underbrace{\\sum_i\\sum_{i\\neq j} \\mathcal{C}_{ij}^2}_\\text{redundancy reduction term} \\\\ \\text{where } \\mathcal{C}_{ij} \u0026= \\frac{\\sum_b \\mathbf{z}^A_{b,i} \\mathbf{z}^B_{b,j}}{\\sqrt{\\sum_b (\\mathbf{z}^A_{b,i})^2}\\sqrt{\\sum_b (\\mathbf{z}^B_{b,j})^2}} \\end{aligned} $$ Barlow Twins is competitive with SOTA methods for self-supervised learning. It naturally avoids trivial constants (i.e. collapsed representations), and is robust to different training batch sizes.\nFig. 9. Algorithm of Barlow Twins in Pytorch style pseudo code. (Image source: Zbontar et al. 2021). BYOL Different from the above approaches, interestingly, BYOL (Bootstrap Your Own Latent; Grill, et al 2020) claims to achieve a new state-of-the-art results without using egative samples. It relies on two neural networks, referred to as online and target networks that interact and learn from each other. The target network (parameterized by $\\xi$) has the same architecture as the online one (parameterized by $\\theta$), but with polyak averaged weights, $\\xi \\leftarrow \\tau \\xi + (1-\\tau) \\theta$.\nThe goal is to learn a presentation $y$ that can be used in downstream tasks. The online network parameterized by $\\theta$ contains:\n An encoder $f_\\theta$; A projector $g_\\theta$; A predictor $q_\\theta$. The target network has the same network architecture, but with different parameter $\\xi$, updated by polyak averaging $\\theta$: $\\xi \\leftarrow \\tau \\xi + (1-\\tau) \\theta$.\nFig. 10. The model architecture of BYOL. After training, we only care about $f\\_\\theta$ for producing representation, $y=f\\_\\theta(x)$, and everything else is discarded. $\\text{sg}$ means stop gradient. (Image source: Grill, et al 2020) Given an image $\\mathbf{x}$, the BYOL loss is constructed as follows:\n Create two augmented views: $\\mathbf{v}=t(\\mathbf{x}); \\mathbf{v}'=t'(\\mathbf{x})$ with augmentations sampled $t \\sim \\mathcal{T}, t' \\sim \\mathcal{T}'$; Then they are encoded into representations, $\\mathbf{y}_\\theta=f_\\theta(\\mathbf{v}), \\mathbf{y}'=f_\\xi(\\mathbf{v}')$; Then they are projected into latent variables, $\\mathbf{z}_\\theta=g_\\theta(\\mathbf{y}_\\theta), \\mathbf{z}'=g_\\xi(\\mathbf{y}')$; The online network outputs a prediction $q_\\theta(\\mathbf{z}_\\theta)$; Both $q_\\theta(\\mathbf{z}_\\theta)$ and $\\mathbf{z}'$ are L2-normalized, giving us $\\bar{q}_\\theta(\\mathbf{z}_\\theta) = q_\\theta(\\mathbf{z}_\\theta) / | q_\\theta(\\mathbf{z}_\\theta) |$ and $\\bar{\\mathbf{z}'} = \\mathbf{z}' / |\\mathbf{z}'|$; The loss $\\mathcal{L}^\\text{BYOL}_\\theta$ is MSE between L2-normalized prediction $\\bar{q}_\\theta(\\mathbf{z})$ and $\\bar{\\mathbf{z}'}$; The other symmetric loss $\\tilde{\\mathcal{L}}^\\text{BYOL}_\\theta$ can be generated by switching $\\mathbf{v}'$ and $\\mathbf{v}$; that is, feeding $\\mathbf{v}'$ to online network and $\\mathbf{v}$ to target network. The final loss is $\\mathcal{L}^\\text{BYOL}_\\theta + \\tilde{\\mathcal{L}}^\\text{BYOL}_\\theta$ and only parameters $\\theta$ are optimized. Unlike most popular contrastive learning based approaches, BYOL does not use negative pairs. Most bootstrapping approaches rely on pseudo-labels or cluster indices, but BYOL directly boostrapps the latent representation.\nIt is quite interesting and surprising that without negative samples, BYOL still works well. Later I ran into this post by Abe Fetterman \u0026amp; Josh Albrecht, they highlighted two surprising findings while they were trying to reproduce BYOL:\n BYOL generally performs no better than random when batch normalization is removed. The presence of batch normalization implicitly causes a form of contrastive learning. They believe that using negative samples is important for avoiding model collapse (i.e. what if you use all-zeros representation for every data point?). Batch normalization injects dependency on negative samples inexplicitly because no matter how similar a batch of inputs are, the values are re-distributed (spread out $\\sim \\mathcal{N}(0, 1$) and therefore batch normalization prevents model collapse. Strongly recommend you to read the full article if you are working in this area. Memory Bank Computing embeddings for a large number of negative samples in every batch is extremely expensive. One common approach is to store the representation in memory to trade off data staleness for cheaper compute.\nInstance Discrimination with Memoy Bank Instance contrastive learning (Wu et al, 2018) pushes the class-wise supervision to the extreme by considering each instance as a distinct class of its own. It implies that the number of \u0026ldquo;classes\u0026rdquo; will be the same as the number of samples in the training dataset. Hence, it is unfeasible to train a softmax layer with these many heads, but instead it can be approximated by NCE.\nFig. 11. The training pipeline of instance-level contrastive learning. The learned embedding is L2-normalized. (Image source: Wu et al, 2018) Let $\\mathbf{v} = f_\\theta(x)$ be an embedding function to learn and the vector is normalized to have $|\\mathbf{v}|=1$. A non-parametric classifier predicts the probability of a sample $\\mathbf{v}$ belonging to class $i$ with a temperature parameter $\\tau$:\n $$ P(C=i\\vert \\mathbf{v}) = \\frac{\\exp(\\mathbf{v}_i^\\top \\mathbf{v} / \\tau)}{\\sum_{j=1}^n \\exp(\\mathbf{v}_j^\\top \\mathbf{v} / \\tau)} $$ Instead of computing the representations for all the samples every time, they implement an Memory Bank for storing sample representation in the database from past iterations. Let $V=\\{ \\mathbf{v}_i \\}$ be the memory bank and $\\mathbf{f}_i = f_\\theta(\\mathbf{x}_i)$ be the feature generated by forwarding the network. We can use the representation from the memory bank $\\mathbf{v}_i$ instead of the feature forwarded from the network $\\mathbf{f}_i$ when comparing pairwise similarity.\nThe denominator theoretically requires access to the representations of all the samples, but that is too expensive in practice. Instead we can estimate it via Monte Carlo approximation using a random subset of $M$ indices $\\{j_k\\}_{k=1}^M$.\n $$ P(i\\vert \\mathbf{v}) = \\frac{\\exp(\\mathbf{v}^\\top \\mathbf{f}_i / \\tau)}{\\sum_{j=1}^N \\exp(\\mathbf{v}_j^\\top \\mathbf{f}_i / \\tau)} \\simeq \\frac{\\exp(\\mathbf{v}^\\top \\mathbf{f}_i / \\tau)}{\\frac{N}{M} \\sum_{k=1}^M \\exp(\\mathbf{v}_{j_k}^\\top \\mathbf{f}_i / \\tau)} $$ Because there is only one instance per class, the training is unstable and fluctuates a lot. To improve the training smoothness, they introduced an extra term for positive samples in the loss function based on the proximal optimization method. The final NCE loss objective looks like:\n $$ \\begin{aligned} \\mathcal{L}_\\text{instance} \u0026= - \\mathbb{E}_{P_d}\\big[\\log h(i, \\mathbf{v}^{(t-1)}_i) - \\lambda \\|\\mathbf{v}^{(t)}_i - \\mathbf{v}^{(t-1)}_i\\|^2_2\\big] - M\\mathbb{E}_{P_n}\\big[\\log(1 - h(i, \\mathbf{v}'^{(t-1)})\\big] \\\\ h(i, \\mathbf{v}) \u0026= \\frac{P(i\\vert\\mathbf{v})}{P(i\\vert\\mathbf{v}) + MP_n(i)} \\text{ where the noise distribution is uniform }P_n = 1/N \\end{aligned} $$ where $\\{ \\mathbf{v}^{(t-1)} \\}$ are embeddings stored in the memory bank from the previous iteration. The difference between iterations $|\\mathbf{v}^{(t)}_i - \\mathbf{v}^{(t-1)}_i|^2_2$ will gradually vanish as the learned embedding converges.\nMoCo \u0026amp; MoCo-V2 Momentum Contrast (MoCo; He et al, 2019) provides a framework of unsupervised learning visual representation as a dynamic dictionary look-up. The dictionary is structured as a large FIFO queue of encoded representations of data samples.\nGiven a query sample $\\mathbf{x}_q$, we get a query representation through an encoder $\\mathbf{q} = f_q(\\mathbf{x}_q)$. A list of key representations $\\{\\mathbf{k}_1, \\mathbf{k}_2, \\dots \\}$ in the dictionary are encoded by a momentum encoder $\\mathbf{k}_i = f_k (\\mathbf{x}^k_i)$. Let\u0026rsquo;s assume among them there is a single positive key $\\mathbf{k}^+$ in the dictionary that matches $\\mathbf{q}$. In the paper, they create $\\mathbf{k}^+$ using a noise copy of $\\mathbf{x}_q$ with different augmentation. Then the InfoNCE contrastive loss with temperature $\\tau$ is used over one positive and $N-1$ negative samples:\n $$ \\mathcal{L}_\\text{MoCo} = - \\log \\frac{\\exp(\\mathbf{q} \\cdot \\mathbf{k}^+ / \\tau)}{\\sum_{i=1}^N \\exp(\\mathbf{q} \\cdot \\mathbf{k}_i / \\tau)} $$ Compared to the memory bank, a queue-based dictionary in MoCo enables us to reuse representations of immediately preceding mini-batches of data.\nThe MoCo dictionary is not differentiable as a queue, so we cannot rely on back-propagation to update the key encoder $f_k$. One naive way might be to use the same encoder for both $f_q$ and $f_k$. Differently, MoCo proposed to use a momentum-based update with a momentum coefficient $m \\in [0, 1)$. Say, the parameters of $f_q$ and $f_k$ are labeled as $\\theta_q$ and $\\theta_k$, respectively.\n $$ \\theta_k \\leftarrow m \\theta_k + (1-m) \\theta_q $$ Fig. 12. Illustration of how Momentum Contrast (MoCo) learns visual representations. (Image source: He et al, 2019) The advantage of MoCo compared to SimCLR is that MoCo decouples the batch size from the number of negatives, but SimCLR requires a large batch size in order to have enough negative samples and suffers performance drops when their batch size is reduced.\nTwo designs in SimCLR, namely, (1) an MLP projection head and (2) stronger data augmentation, are proved to be very efficient. MoCo V2 (Chen et al, 2020) combined these two designs, achieving even better transfer performance with no dependency on a very large batch size.\nCURL CURL (Srinivas, et al. 2020) applies the above ideas in Reinforcement Learning. It learns a visual representation for RL tasks by matching embeddings of two data-augmented versions, $o_q$ and $o_k$, of the raw observation $o$ via contrastive loss. CURL primarily relies on random crop data augmentation. The key encoder is implemented as a momentum encoder with weights as EMA of the query encoder weights, same as in MoCo.\nOne significant difference between RL and supervised visual tasks is that RL depends on temporal consistency between consecutive frames. Therefore, CURL applies augmentation consistently on each stack of frames to retain information about the temporal structure of the observation.\nFig. 13. The architecture of CURL. (Image source: Srinivas, et al. 2020) Feature Clustering DeepCluster DeepCluster (Caron et al. 2018) iteratively clusters features via k-means and uses cluster assignments as pseudo labels to provide supervised signals.\nFig. 14. Illustration of DeepCluster method which iteratively clusters deep features and uses the cluster assignments as pseudo-labels. (Image source: Caron et al. 2018) In each iteration, DeepCluster clusters data points using the prior representation and then produces the new cluster assignments as the classification targets for the new representation. However this iterative process is prone to trivial solutions. While avoiding the use of negative pairs, it requires a costly clustering phase and specific precautions to avoid collapsing to trivial solutions.\nSwAV SwAV (Swapping Assignments between multiple Views; Caron et al. 2020) is an online contrastive learning algorithm. It computes a code from an augmented version of the image and tries to predict this code using another augmented version of the same image.\nFig. 15. Comparison of SwAV and [contrastive instance learning](#instance-discrimination-with-memoy-bank). (Image source: Caron et al. 2020) Given features of images with two different augmentations, $\\mathbf{z}_t$ and $\\mathbf{z}_s$, SwAV computes corresponding codes $\\mathbf{q}_t$ and $\\mathbf{q}_s$ and the loss quantifies the fit by swapping two codes using $\\ell(.)$ to measure the fit between a feature and a code.\n $$ \\mathcal{L}_\\text{SwAV}(\\mathbf{z}_t, \\mathbf{z}_s) = \\ell(\\mathbf{z}_t, \\mathbf{q}_s) + \\ell(\\mathbf{z}_s, \\mathbf{q}_t) $$ The swapped fit prediction depends on the cross entropy between the predicted code and a set of $K$ trainable prototype vectors $\\mathbf{C} = \\{\\mathbf{c}_1, \\dots, \\mathbf{c}_K\\}$. The prototype vector matrix is shared across different batches and represents anchor clusters that each instance should be clustered to.\n $$ \\ell(\\mathbf{z}_t, \\mathbf{q}_s) = - \\sum_k \\mathbf{q}^{(k)}_s\\log\\mathbf{p}^{(k)}_t \\text{ where } \\mathbf{p}^{(k)}_t = \\frac{\\exp(\\mathbf{z}_t^\\top\\mathbf{c}_k / \\tau)}{\\sum_{k'}\\exp(\\mathbf{z}_t^\\top \\mathbf{c}_{k'} / \\tau)} $$ In a mini-batch containing $B$ feature vectors $\\mathbf{Z} = [\\mathbf{z}_1, \\dots, \\mathbf{z}_B]$, the mapping matrix between features and prototype vectors is defined as $\\mathbf{Q} = [\\mathbf{q}_1, \\dots, \\mathbf{q}_B] \\in \\mathbb{R}_+^{K\\times B}$. We would like to maximize the similarity between the features and the prototypes:\n $$ \\begin{aligned} \\max_{\\mathbf{Q}\\in\\mathcal{Q}} \u0026\\text{Tr}(\\mathbf{Q}^\\top \\mathbf{C}^\\top \\mathbf{Z}) + \\varepsilon \\mathcal{H}(\\mathbf{Q}) \\\\ \\text{where }\\mathcal{Q} \u0026= \\big\\{ \\mathbf{Q} \\in \\mathbb{R}_{+}^{K \\times B} \\mid \\mathbf{Q}\\mathbf{1}_B = \\frac{1}{K}\\mathbf{1}_K, \\mathbf{Q}^\\top\\mathbf{1}_K = \\frac{1}{B}\\mathbf{1}_B \\big\\} \\end{aligned} $$ where $\\mathcal{H}$ is the entropy, $\\mathcal{H}(\\mathbf{Q}) = - \\sum_{ij} \\mathbf{Q}_{ij} \\log \\mathbf{Q}_{ij}$, controlling the smoothness of the code. The coefficient $\\epsilon$ should not be too large; otherwise, all the samples will be assigned uniformly to all the clusters. The candidate set of solutions for $\\mathbf{Q}$ requires every mapping matrix to have each row sum up to $1/K$ and each column to sum up to $1/B$, enforcing that each prototype gets selected at least $B/K$ times on average.\nSwAV relies on the iterative Sinkhorn-Knopp algorithm (Cuturi 2013) to find the solution for $\\mathbf{Q}$.\nWorking with Supervised Datasets CLIP CLIP (Contrastive Language-Image Pre-training; Radford et al. 2021) jointly trains a text encoder and an image feature extractor over the pretraining task that predicts which caption goes with which image.\nFig. 16. Illustration of CLIP contrastive pre-training over text-image pairs. (Image source: Radford et al. 2021) Given a batch of $N$ (image, text) pairs, CLIP computes the dense cosine similarity matrix between all $N\\times N$ possible (image, text) candidates within this batch. The text and image encoders are jointly trained to maximize the similarity between $N$ correct pairs of (image, text) associations while minimizing the similarity for $N(N-1)$ incorrect pairs via a symmetric cross entropy loss over the dense matrix.\nSee the numy-like pseudo code for CLIP in Fig. 17.\nFig. 17. CLIP algorithm in Numpy style pseudo code. (Image source: Radford et al. 2021) Compared to other methods above for learning good visual representation, what makes CLIP really special is \u0026ldquo;the appreciation of using natural language as a training signal\u0026rdquo;. It does demand access to supervised dataset in which we know which text matches which image. It is trained on 400 million (text, image) pairs, collected from the Internet. The query list contains all the words occurring at least 100 times in the English version of Wikipedia. Interestingly, they found that Transformer-based language models are 3x slower than a bag-of-words (BoW) text encoder at zero-shot ImageNet classification. Using contrastive objective instead of trying to predict the exact words associated with images (i.e. a method commonly adopted by image caption prediction tasks) can further improve the data efficiency another 4x.\nFig. 18. Using bag-of-words text encoding and contrastive training objectives can bring in multiple folds of data efficiency improvement. (Image source: Radford et al. 2021) CLIP produces good visual representation that can non-trivially transfer to many CV benchmark datasets, achieving results competitive with supervised baseline. Among tested transfer tasks, CLIP struggles with very fine-grained classification, as well as abstract or systematic tasks such as counting the number of objects. The transfer performance of CLIP models is smoothly correlated with the amount of model compute.\nSupervised Contrastive Learning There are several known issues with cross entropy loss, such as the lack of robustness to noisy labels and the possibility of poor margins. Existing improvement for cross entropy loss involves the curation of better training data, such as label smoothing and data augmentation. Supervised Contrastive Loss (Khosla et al. 2021) aims to leverage label information more effectively than cross entropy, imposing that normalized embeddings from the same class are closer together than embeddings from different classes.\nFig. 19. Supervised vs self-supervised contrastive losses. Supervised contrastive learning considers different samples from the same class as positive examples, in addition to augmented versions. (Image source: Khosla et al. 2021) Given a set of randomly sampled $n$ (image, label) pairs, $\\{\\mathbf{x}_i, y_i\\}_{i=1}^n$, $2n$ training pairs can be created by applying two random augmentations of every sample, $\\{\\tilde{\\mathbf{x}}_i, \\tilde{y}_i\\}_{i=1}^{2n}$.\nSupervised contrastive loss $\\mathcal{L}_\\text{supcon}$ utilizes multiple positive and negative samples, very similar to soft nearest-neighbor loss:\n $$ \\mathcal{L}_\\text{supcon} = - \\sum_{i=1}^{2n} \\frac{1}{2 \\vert N_i \\vert - 1} \\sum_{j \\in N(y_i), j \\neq i} \\log \\frac{\\exp(\\mathbf{z}_i \\cdot \\mathbf{z}_j / \\tau)}{\\sum_{k \\in I, k \\neq i}\\exp({\\mathbf{z}_i \\cdot \\mathbf{z}_k / \\tau})} $$ where $\\mathbf{z}_k=P(E(\\tilde{\\mathbf{x}_k}))$, in which $E(.)$ is an encoder network (augmented image mapped to vector) $P(.)$ is a projection network (one vector mapped to another). $N_i= \\{j \\in I: \\tilde{y}_j = \\tilde{y}_i \\}$ contains a set of indices of samples with label $y_i$. Including more positive samples into the set $N_i$ leads to improved results.\nAccording to their experiments, supervised contrastive loss:\n does outperform the base cross entropy, but only by a small amount. outperforms the cross entropy on robustness benchmark (ImageNet-C, which applies common naturally occuring perturbations such as noise, blur and contrast changes to the ImageNet dataset). is less sensitive to hyperparameter changes. Language: Sentence Embedding In this section, we focus on how to learn sentence embedding.\nText Augmentation Most contrastive methods in vision applications depend on creating an augmented version of each image. However, it is more challenging to construct text augmentation which does not alter the semantics of a sentence. In this section we look into three approaches for augmenting text sequences, including lexical edits, back-translation and applying cutoff or dropout.\nLexical Edits EDA (Easy Data Augmentation; Wei \u0026amp; Zou 2019) defines a set of simple but powerful operations for text augmentation. Given a sentence, EDA randomly chooses and applies one of four simple operations:\n Synonym replacement (SR): Replace $n$ random non-stop words with their synonyms. Random insertion (RI): Place a random synonym of a randomly selected non-stop word in the sentence at a random position. Random swap (RS): Randomly swap two words and repeat $n$ times. Random deletion (RD): Randomly delete each word in the sentence with probability $p$. where $p=\\alpha$ and $n=\\alpha \\times \\text{sentence_length}$, with the intuition that longer sentences can absorb more noise while maintaining the original label. The hyperparameter $\\alpha$ roughly indicates the percent of words in one sentence that may be changed by one augmentation.\nEDA is shown to improve the classification accuracy on several classification benchmark datasets compared to baseline without EDA. The performance lift is more significant on a smaller training set. All the four operations in EDA help improve the classification accuracy, but get to optimal at different $\\alpha$\u0026rsquo;s.\nFig. 20. EDA leads to performance improvement on several classification benchmarks. (Image source: Wei \u0026 Zou 2019) In Contextual Augmentation (Sosuke Kobayashi, 2018), new substitutes for word $w_i$ at position $i$ can be smoothly sampled from a given probability distribution, $p(.\\mid S\\setminus\\{w_i\\})$, which is predicted by a bidirectional LM like BERT.\nBack-translation CERT (Contrastive self-supervised Encoder Representations from Transformers; Fang et al. (2020); code) generates augmented sentences via back-translation. Various translation models for different languages can be employed for creating different versions of augmentations. Once we have a noise version of text samples, many contrastive learning frameworks introduced above, such as MoCo, can be used to learn sentence embedding.\nDropout and Cutoff Shen et al. (2020) proposed to apply Cutoff to text augmentation, inspired by cross-view training. They proposed three cutoff augmentation strategies:\n Token cutoff removes the information of a few selected tokens. To make sure there is no data leakage, corresponding tokens in the input, positional and other relevant embedding matrices should all be zeroed out., Feature cutoff removes a few feature columns. Span cutoff removes a continuous chunk of texts. Fig. 21. Schematic illustration of token, feature and span cutoff augmentation strategies. (Image source: Shen et al. 2020) Multiple augmented versions of one sample can be created. When training, Shen et al. (2020) applied an additional KL-divergence term to measure the consensus between predictions from different augmented samples.\nSimCSE (Gao et al. 2021; code) learns from unsupervised data by predicting a sentence from itself with only dropout noise. In other words, they treat dropout as data augmentation for text sequences. A sample is simply fed into the encoder twice with different dropout masks and these two versions are the positive pair where the other in-batch samples are considered as negative pairs. It feels quite similar to the cutoff augmentation, but dropout is more flexible with less well-defined semantic meaning of what content can be masked off.\nFig. 22. SimCSE creates augmented samples by applying different dropout masks. The supervised version leverages NLI datasets to predict positive (entailment) or negative (contradiction) given a pair of sentences. (Image source: Gao et al. 2021) They ran experiments on 7 STS (Semantic Text Similarity) datasets and computed cosine similarity between sentence embeddings. They also tried out an optional MLM auxiliary objective loss to help avoid catastrophic forgetting of token-level knowledge. This aux loss was found to help improve performance on transfer tasks, but a consistent drop on the main STS tasks.\nFig. 23. Experiment numbers on a collection of STS benchmarks with SimCES. (Image source: Gao et al. 2021) Supervision from NLI The pre-trained BERT sentence embedding without any fine-tuning has been found to have poor performance for semantic similarity tasks. Instead of using the raw embeddings directly, we need to refine the embedding with further fine-tuning.\nNatural Language Inference (NLI) tasks are the main data sources to provide supervised signals for learning sentence embedding; such as SNLI, MNLI, and QQP.\nSentence-BERT SBERT (Sentence-BERT) (Reimers \u0026amp; Gurevych, 2019) relies on siamese and triplet network architectures to learn sentence embeddings such that the sentence similarity can be estimated by cosine similarity between pairs of embeddings. Note that learning SBERT depends on supervised data, as it is fine-tuned on several NLI datasets.\nThey experimented with a few different prediction heads on top of BERT model:\n Softmax classification objective: The classification head of the siamese network is built on the concatenation of two embeddings $f(\\mathbf{x}), f(\\mathbf{x}')$ and $\\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert$. The predicted output is $\\hat{y}=\\text{softmax}(\\mathbf{W}_t [f(\\mathbf{x}); f(\\mathbf{x}'); \\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert])$. They showed that the most important component is the element-wise difference $\\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert$. Regression objective: This is the regression loss on $\\cos(f(\\mathbf{x}), f(\\mathbf{x}'))$, in which the pooling strategy has a big impact. In the experiments, they observed that max performs much worse than mean and CLS-token. Triplet objective: $\\max(0, |f(\\mathbf{x}) - f(\\mathbf{x}^+)|- |f(\\mathbf{x}) - f(\\mathbf{x}^-)| + \\epsilon)$, where $\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-$ are embeddings of the anchor, positive and negative sentences. In the experiments, which objective function works the best depends on the datasets, so there is no universal winner.\nFig. 24. Illustration of Sentence-BERT training framework with softmax classification head and regression head. (Image source: Reimers \u0026 Gurevych, 2019) The SentEval library (Conneau and Kiela, 2018) is commonly used for evaluating the quality of learned sentence embedding. SBERT outperformed other baselines at that time (Aug 2019) on 5 out of 7 tasks.\nFig. 25. The performance of Sentence-BERT on the SentEval benchmark. (Image source: Reimers \u0026 Gurevych, 2019) BERT-flow The embedding representation space is deemed isotropic if embeddings are uniformly distributed on each dimension; otherwise, it is anisotropic. Li et al, (2020) showed that a pre-trained BERT learns a non-smooth anisotropic semantic space of sentence embeddings and thus leads to poor performance for text similarity tasks without fine-tuning. Empirically, they observed two issues with BERT sentence embedding: Word frequency biases the embedding space. High-frequency words are close to the origin, but low-frequency ones are far away from the origin. Low-frequency words scatter sparsely. The embeddings of low-frequency words tend to be farther to their $k$-NN neighbors, while the embeddings of high-frequency words concentrate more densely.\nBERT-flow (Li et al, 2020; code) was proposed to transform the embedding to a smooth and isotropic Gaussian distribution via normalizing flows.\nFig. 26. Illustration of the flow-based calibration over the original sentence embedding space in BERT-flow. (Image source: Li et al, 2020) Let $\\mathcal{U}$ be the observed BERT sentence embedding space and $\\mathcal{Z}$ be the desired latent space which is a standard Gaussian. Thus, $p_\\mathcal{Z}$ is a Gaussian density function and $f_\\phi: \\mathcal{Z}\\to\\mathcal{U}$ is an invertible transformation:\n $$ \\mathbf{z}\\sim p_\\mathcal{Z}(\\mathbf{z}) \\quad \\mathbf{u}=f_\\phi(\\mathbf{z}) \\quad \\mathbf{z}=f^{-1}_\\phi(\\mathbf{u}) $$ A flow-based generative model learns the invertible mapping function by maximizing the likelihood of $\\mathcal{U}$\u0026rsquo;s marginal:\n $$ \\max_\\phi\\mathbb{E}_{\\mathbf{u}=\\text{BERT}(s), s\\sim\\mathcal{D}} \\Big[ \\log p_\\mathcal{Z}(f^{-1}_\\phi(\\mathbf{u})) + \\log\\big\\vert\\det\\frac{\\partial f^{-1}_\\phi(\\mathbf{u})}{\\partial\\mathbf{u}}\\big\\vert \\Big] $$ where $s$ is a sentence sampled from the text corpus $\\mathcal{D}$. Only the flow parameters $\\phi$ are optimized while parameters in the pretrained BERT stay unchanged.\nBERT-flow was shown to improve the performance on most STS tasks either with or without supervision from NLI datasets. Because learning normalizing flows for calibration does not require labels, it can utilize the entire dataset including validation and test sets.\nWhitening Operation Su et al. (2021) applied whitening operation to improve the isotropy of the learned representation and also to reduce the dimensionality of sentence embedding.\nThey transform the mean value of the sentence vectors to 0 and the covariance matrix to the identity matrix. Given a set of samples $\\{\\mathbf{x}_i\\}_{i=1}^N$, let $\\tilde{\\mathbf{x}}_i$ and $\\tilde{\\Sigma}$ be the transformed samples and corresponding covariance matrix:\n $$ \\begin{aligned} \\mu \u0026= \\frac{1}{N}\\sum_{i=1}^N \\mathbf{x}_i \\quad \\Sigma = \\frac{1}{N}\\sum_{i=1}^N (\\mathbf{x}_i - \\mu)^\\top (\\mathbf{x}_i - \\mu) \\\\ \\tilde{\\mathbf{x}}_i \u0026= (\\mathbf{x}_i - \\mu)W \\quad \\tilde{\\Sigma} = W^\\top\\Sigma W = I \\text{ thus } \\Sigma = (W^{-1})^\\top W^{-1} \\end{aligned} $$ If we get SVD decomposition of $\\Sigma = U\\Lambda U^\\top$, we will have $W^{-1}=\\sqrt{\\Lambda} U^\\top$ and $W=U\\sqrt{\\Lambda^{-1}}$. Note that within SVD, $U$ is an orthogonal matrix with column vectors as eigenvectors and $\\Lambda$ is a diagonal matrix with all positive elements as sorted eigenvalues.\nA dimensionality reduction strategy can be applied by only taking the first $k$ columns of $W$, named Whitening-$k$.\nFig. 27. Pseudo code of the whitening-$k$ operation. (Image source: Su et al. 2021) Whitening operations were shown to outperform BERT-flow and achieve SOTA with 256 sentence dimensionality on many STS benchmarks, either with or without NLI supervision.\nUnsupervised Sentence Embedding Learning Context Prediction Quick-Thought (QT) vectors (Logeswaran \u0026amp; Lee, 2018) formulate sentence representation learning as a classification problem: Given a sentence and its context, a classifier distinguishes context sentences from other contrastive sentences based on their vector representations (\u0026ldquo;cloze test\u0026rdquo;). Such a formulation removes the softmax output layer which causes training slowdown.\nFig. 28. Illustration of how Quick-Thought sentence embedding vectors are learned. (Image source: Logeswaran \u0026 Lee, 2018) Let $f(.)$ and $g(.)$ be two functions that encode a sentence $s$ into a fixed-length vector. Let $C(s)$ be the set of sentences in the context of $s$ and $S(s)$ be the set of candidate sentences including only one sentence $s_c \\in C(s)$ and many other non-context negative sentences. Quick Thoughts model learns to optimize the probability of predicting the only true context sentence $s_c \\in S(s)$. It is essentially NCE loss when considering the sentence $(s, s_c)$ as the positive pairs while other pairs $(s, s')$ where $s' \\in S(s), s'\\neq s_c$ as negatives.\n $$ \\mathcal{L}_\\text{QT} = - \\sum_{s \\in \\mathcal{D}} \\sum_{s_c \\in C(s)} \\log p(s_c \\vert s, S(s)) = - \\sum_{s \\in \\mathcal{D}} \\sum_{s_c \\in C(s)}\\frac{\\exp(f(s)^\\top g(s_c))}{\\sum_{s'\\in S(s)} \\exp(f(s)^\\top g(s'))} $$ Mutual Information Maximization IS-BERT (Info-Sentence BERT) (Zhang et al. 2020; code) adopts a self-supervised learning objective based on mutual information maximization to learn good sentence embeddings in the unsupervised manners.\nFig. 29. Illustration of Info-Sentence BERT. (Image source: Zhang et al. 2020) IS-BERT works as follows:\n Use BERT to encode an input sentence $s$ to a token embedding of length $l$, $\\mathbf{h}_{1:l}$.\n Then apply 1-D conv net with different kernel sizes (e.g. 1, 3, 5) to process the token embedding sequence to capture the n-gram local contextual dependencies: $\\mathbf{c}_i = \\text{ReLU}(\\mathbf{w} \\cdot \\mathbf{h}_{i:i+k-1} + \\mathbf{b})$. The output sequences are padded to stay the same sizes of the inputs.\n The final local representation of the $i$-th token $\\mathcal{F}_\\theta^{(i)} (\\mathbf{x})$ is the concatenation of representations of different kernel sizes.\n The global sentence representation $\\mathcal{E}_\\theta(\\mathbf{x})$ is computed by applying a mean-over-time pooling layer on the token representations $\\mathcal{F}_\\theta(\\mathbf{x}) = \\{\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}) \\in \\mathbb{R}^d\\}_{i=1}^l$.\n Since the mutual information estimation is generally intractable for continuous and high-dimensional random variables, IS-BERT relies on the Jensen-Shannon estimator (Nowozin et al., 2016, Hjelm et al., 2019) to maximize the mutual information between $\\mathcal{E}_\\theta(\\mathbf{x})$ and $\\mathcal{F}_\\theta^{(i)} (\\mathbf{x})$.\n $$ I^\\text{JSD}_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}); \\mathcal{E}_\\theta(\\mathbf{x})) = \\mathbb{E}_{\\mathbf{x}\\sim P} [-\\text{sp}(-T_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}); \\mathcal{E}_\\theta(\\mathbf{x})))] \\\\ - \\mathbb{E}_{\\mathbf{x}\\sim P, \\mathbf{x}' \\sim\\tilde{P}} [\\text{sp}(T_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}'); \\mathcal{E}_\\theta(\\mathbf{x})))] $$ where $T_\\omega: \\mathcal{F}\\times\\mathcal{E} \\to \\mathbb{R}$ is a learnable network with parameters $\\omega$, generating discriminator scores. The negative sample $\\mathbf{x}'$ is sampled from the distribution $\\tilde{P}=P$. And $\\text{sp}(x)=\\log(1+e^x)$ is the softplus activation function.\nThe unsupervised numbers on SentEval with IS-BERT outperforms most of the unsupervised baselines (Sep 2020), but unsurprisingly weaker than supervised runs. When using labelled NLI datasets, IS-BERT produces results comparable with SBERT (See Fig. 25 \u0026amp; 30).\nFig. 30. The performance of IS-BERT on the SentEval benchmark. (Image source: Zhang et al. 2020) Citation Cited as:\n Weng, Lilian. (May 2021). Contrastive representation learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-05-31-contrastive/.\n Or\n@article{weng2021contrastive, title = \u0026quot;Contrastive Representation Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;May\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-05-31-contrastive/\u0026quot; } References [1] Sumit Chopra, Raia Hadsell and Yann LeCun. \u0026ldquo;Learning a similarity metric discriminatively, with application to face verification.\u0026quot; CVPR 2005.\n[2] Florian Schroff, Dmitry Kalenichenko and James Philbin. \u0026ldquo;FaceNet: A Unified Embedding for Face Recognition and Clustering.\u0026quot; CVPR 2015.\n[3] Hyun Oh Song et al. \u0026ldquo;Deep Metric Learning via Lifted Structured Feature Embedding.\u0026quot; CVPR 2016. [code]\n[4] Ruslan Salakhutdinov and Geoff Hinton. \u0026ldquo;Learning a Nonlinear Embedding by Preserving Class Neighbourhood Structure\u0026rdquo; AISTATS 2007.\n[5] Michael Gutmann and Aapo Hyvärinen. \u0026ldquo;Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.\u0026quot; AISTATS 2010.\n[6] Kihyuk Sohn et al. \u0026ldquo;Improved Deep Metric Learning with Multi-class N-pair Loss Objective\u0026rdquo; NIPS 2016.\n[7] Nicholas Frosst, Nicolas Papernot and Geoffrey Hinton. \u0026ldquo;Analyzing and Improving Representations with the Soft Nearest Neighbor Loss.\u0026quot; ICML 2019\n[8] Tongzhou Wang and Phillip Isola. \u0026ldquo;Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.\u0026quot; ICML 2020. [code]\n[9] Zhirong Wu et al. \u0026ldquo;Unsupervised feature learning via non-parametric instance-level discrimination.\u0026quot; CVPR 2018.\n[10] Ekin D. Cubuk et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv preprint arXiv:1805.09501 (2018).\n[11] Daniel Ho et al. \u0026ldquo;Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules.\u0026quot; ICML 2019.\n[12] Ekin D. Cubuk \u0026amp; Barret Zoph et al. \u0026ldquo;RandAugment: Practical automated data augmentation with a reduced search space.\u0026quot; arXiv preprint arXiv:1909.13719 (2019).\n[13] Hongyi Zhang et al. \u0026ldquo;mixup: Beyond Empirical Risk Minimization.\u0026quot; ICLR 2017.\n[14] Sangdoo Yun et al. \u0026ldquo;CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features.\u0026quot; ICCV 2019.\n[15] Yannis Kalantidis et al. \u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo; NeuriPS 2020.\n[16] Ashish Jaiswal et al. \u0026ldquo;A Survey on Contrastive Self-Supervised Learning.\u0026quot; arXiv preprint arXiv:2011.00362 (2021)\n[17] Jure Zbontar et al. \u0026ldquo;Barlow Twins: Self-Supervised Learning via Redundancy Reduction.\u0026quot; arXiv preprint arXiv:2103.03230 (2021) [code]\n[18] Alec Radford, et al. \u0026ldquo;Learning Transferable Visual Models From Natural Language Supervision\u0026rdquo; arXiv preprint arXiv:2103.00020 (2021)\n[19] Mathilde Caron et al. \u0026ldquo;Unsupervised Learning of Visual Features by Contrasting Cluster Assignments (SwAV).\u0026quot; NeuriPS 2020.\n[20] Mathilde Caron et al. \u0026ldquo;Deep Clustering for Unsupervised Learning of Visual Features.\u0026quot; ECCV 2018.\n[21] Prannay Khosla et al. \u0026ldquo;Supervised Contrastive Learning.\u0026quot; NeurIPS 2020.\n[22] Aaron van den Oord, Yazhe Li \u0026amp; Oriol Vinyals. \u0026ldquo;Representation Learning with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1807.03748 (2018).\n[23] Jason Wei and Kai Zou. \u0026ldquo;EDA: Easy data augmentation techniques for boosting performance on text classification tasks.\u0026quot; EMNLP-IJCNLP 2019.\n[24] Sosuke Kobayashi. \u0026ldquo;Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations.\u0026quot; NAACL 2018\n[25] Hongchao Fang et al. \u0026ldquo;CERT: Contrastive self-supervised learning for language understanding.\u0026quot; arXiv preprint arXiv:2005.12766 (2020).\n[26] Dinghan Shen et al. \u0026ldquo;A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation.\u0026quot; arXiv preprint arXiv:2009.13818 (2020) [code]\n[27] Tianyu Gao et al. \u0026ldquo;SimCSE: Simple Contrastive Learning of Sentence Embeddings.\u0026quot; arXiv preprint arXiv:2104.08821 (2020). [code]\n[28] Nils Reimers and Iryna Gurevych. \u0026ldquo;Sentence-BERT: Sentence embeddings using Siamese BERT-networks.\u0026quot; EMNLP 2019.\n[29] Jianlin Su et al. \u0026ldquo;Whitening sentence representations for better semantics and faster retrieval.\u0026quot; arXiv preprint arXiv:2103.15316 (2021). [code]\n[30] Yan Zhang et al. \u0026ldquo;An unsupervised sentence embedding method by mutual information maximization.\u0026quot; EMNLP 2020. [code]\n[31] Bohan Li et al. \u0026ldquo;On the sentence embeddings from pre-trained language models.\u0026quot; EMNLP 2020.\n[32] Lajanugen Logeswaran and Honglak Lee. \u0026ldquo;An efficient framework for learning sentence representations.\u0026quot; ICLR 2018.\n[33] Joshua Robinson, et al. \u0026ldquo;Contrastive Learning with Hard Negative Samples.\u0026quot; ICLR 2021.\n[34] Ching-Yao Chuang et al. \u0026ldquo;Debiased Contrastive Learning.\u0026quot; NeuriPS 2020.\n","permalink":"https://lilianweng.github.io/posts/2021-05-31-contrastive/","summary":"The goal of contrastive representation learning is to learn such an embedding space in which similar sample pairs stay close to each other while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised settings. When working with unsupervised data, contrastive learning is one of the most powerful approaches in self-supervised learning.\nContrastive Training Objectives In early versions of loss functions for contrastive learning, only one positive and one negative sample are involved.","title":"Contrastive Representation Learning"},{"content":"Large pretrained language models are trained over a sizable collection of online data. They unavoidably acquire certain toxic behavior and biases from the Internet. Pretrained language models are very powerful and have shown great success in many NLP tasks. However, to safely deploy them for practical real-world applications demands a strong safety control over the model generation process.\nMany challenges are associated with the effort to diminish various types of unsafe content:\n First, there are a variety of unsafe content types, such as toxicity, abusiveness, hate speech, biases, stereotypes, cyberbullying, identity attacks and more, which may or may not demand different treatment. Second, there is no clearly and widely agreed-upon categorization and definition of unsafe behavior in pretrained language models. Individual perceptions could vary a lot due to different social backgrounds. In this post, we delve into the issue of toxicity in language models. As I\u0026rsquo;m still struggling to find a concrete definition of toxic content, I list a couple in the literature below.\n [Perspective API] A rude, disrespectful, or unreasonable comment; likely to make people leave a discussion.\n [Kurita et al. 2019] Content that can offend or harm its recipients, including hate speech, racism, and offensive language.\n [Pavlopoulos et al. 2020] We use the term \u0026lsquo;toxic\u0026rsquo; as an umbrella term, but we note that the literature uses several terms for different kinds of toxic language or related phenomena: \u0026lsquo;offensive\u0026rsquo;, \u0026lsquo;abusive\u0026rsquo;, \u0026lsquo;hateful\u0026rsquo;, etc.\n Overall, toxicity is a broad term to describe several types of unsafe content. Methodologies in this post can be applied given some form of definition of toxicity; e.g. presented in the instruction for annotators. How to properly define the concept of toxicity and thus collect accurate annotation labels is out of the scope of this post.\nCategorization of Toxic Content How to categorize toxic content is not a straightforward task. Which content should be considered toxic and what types of toxic content exist can be very subjective. Language that does not look offensive to one group might seem inappropriate to another.\nOne popular categorization of offensive language is proposed by Zampieri et al. (2019), a three-level hierarchical taxonomy considering both the type and the target of offense. The Offensive Language Identification Dataset (OLID) dataset is collected based on this taxonomy.\nFig. 1. The three-level hierarchical taxonomy for categorizing offensive language, proposed by Zampieri et al. (2019). Level A: \u0026ldquo;Is it offensive?\u0026rdquo; [OFF] Offensive: Inappropriate language, insults, or threats. [NOT] Not offensive: No offense or profanity. Level B: \u0026ldquo;Is the offensive text targeted?\u0026rdquo; [TIN] Targeted Insult: Targeted insult or threat towards an individual, a group or other. [UNT] Untargeted: Non-targeted profanity and swearing. Level C: What is the target? [IND] The offense targets an individual, often defined as \u0026ldquo;cyberbullying\u0026rdquo;. [GRP] The offense targets a group of people based on ethnicity, gender, sexual orientation, religion, or other common characteristic, often defined as \u0026ldquo;hate speech\u0026rdquo;. [OTH] The target can belong to other categories, such as an organization, an event, an issue, etc. Data Collection Preparing a dataset of samples labelled as \u0026ldquo;safe\u0026rdquo; vs \u0026ldquo;unsafe\u0026rdquo; is the foundation for training a toxic language classifier and further providing signals for model detoxification.\nHuman Annotations Vidgen \u0026amp; Derczynski (2020) summarized that training data annotations for toxicity detection on the high level can be collected by:\n Expert coding: An expert has enough knowledge or training to complete the annotation tasks with good quality, such as a researcher who studies prejudice, a student with moderate level of training, or a NLP practitioner. It is more expensive but produces high-quality data. Crowdsourcing: Crowdsourcing platform pairs a large number of non-expert annotators with tasks. It is easier to scale up but demands more attention on quality control. Professional moderators: Professional moderators are experienced, well-trained on the tasks, but their goals are likely to optimize for the output specific to the platform. Synthetic data: Training dataset can also be manually created by relevant content creators to cover a broad range of toxic content types. Crowdsourcing is the most common approach among them (Davidson et al. 2017, Zampieri et al. 2019) and there are several good practices to improve the data quality:\n Test data: A small set of annotations collected from a few experts can be used as test questions (Zampieri et al. 2019) to filter out human annotators on the crowdsourcing platform who cannot achieve a certain threshold. Clear guidelines: Detailed instructions are useful to guide annotators to produce aligned and consistent labels. Without any guideline, annotators are encouraged to apply their personal perceptions, which could be problematic because (1) subjective interpretation of toxic content varies across individuals greatly and (2) it is tricky to mark certain types of noise like sarcasm and irony without any guideline. Majority vote: It is very common that we need labels from multiple annotators per sample and take the majority vote. Understanding annotators' identities: Demographic background has a big impact on the annotator\u0026rsquo;s understanding of the task. We should aim to recruit diverse and qualified annotators. Semi-supervised Dataset Khatri et al. (2018) proposed a simple approach to bootstrap a large amount of semi-supervised dataset for learning toxic content classifiers. Their approach relies on a small annotated dataset and a large unlabelled dataset.\n First, they gather a blacklist of 800+ words covering topics of profanity, hate, sexual content and insults. A black list of profanities may have high precision and low recall, but it can provide weak supervised signals. Subreddits are sorted by the percentage of blacklisted words. Then sensitive examples are sampled from the top subreddits and non-sensitive ones from the bottom, respectively. Train a weak binary classifier to further select more samples from the sorted subreddits, Sensitive: contain blacklisted words or toxic classifier confidence \u0026gt; 0.8; Non-sensitive: not contain blacklisted words and toxic classifier confidence \u0026lt; 0.3 Given this large expanded dataset, train a new classifier named \u0026ldquo;Two-stage bootstrap\u0026rdquo; (TS bootstrap). Their experiments showed that the TS bootstrap classifier achieved pretty good numbers on F1 score, accuracy and recall and it could also transfer to out-of-domain test data.\nFig. 2. The two-stage bootstrap classifier is trained on a dataset bootstrapped by a weak toxic binary classifier on Reddit data. (Image source: Khatri et al. 2018) SOLID (Semi-Supervised Offensive Language Identification Dataset; Rosenthal et al. 2020) contains 9+ M tweets annotated with the same taxonomy system as for OLID. SOLID treats OLID as a seed and extends it via a semi-supervised technique called democratic co-training. Democratic co-training (Zhou \u0026amp; Goldman, 2004) creates a large dataset from noisy labels provided by a collection of diverse models trained on a small supervised dataset. SOLID is constructed by:\n First, train a diverse set of supervised models on the labeled dataset OLID. The paper experimented with PMI (n-gram-based similarity), FastText (shallow neural model similar to BoW model), LSTM and BERT. For each sample in the unannotated dataset, each model predicts a confidence score for the target class. The scores are aggregated by taking avg() or min(). Samples with high confidence are added into the dataset. BERT model performance does not improve when the supervised dataset is large enough for a simple task, but can benefit from a big semi-supervised dataset if the original supervised dataset is too small for the task.\nToxicity Detection Given a supervised dataset, we can train a text classifier from scratch or fine-tune a pretrained language model to perform the classification task. But what if training samples are not good or sufficient enough? What if we don’t have access to such a supervised dataset?\nAdversarial Attacks To create a toxicity detection model that is robust to adversarial attacks, Dinan et al. (2019) proposed an iterative \u0026ldquo;build it, break it, fix it\u0026rdquo; strategy to improve the dialogue system safety with humans in the loop.\n Build it: A BERT model is trained to classify toxic comments on the Jigsaw dataset. Break it: Crowdsourced workers are asked to write toxic messages that are mistakenly labelled as \u0026ldquo;safe\u0026rdquo; by the model. Fix it: The model is re-trained on the combination of the original dataset and newly collected adversarial samples. Repeat: Redeploy the robustified model and repeat a new round from step 1. Fig. 3. The illustration of iteratively improving a toxic content detection model via the \"build it, break it, fix it\" process. (Image source: Dinan et al. 2019) One baseline in their experiments is to replace the adversarial collection in the \u0026ldquo;break it\u0026rdquo; step with the standard collection where workers are asked to submit \u0026ldquo;offensive\u0026rdquo; messages directly . Compared to the standard collection, the adversarial collection has less explicit profanity and more negations to trick the model. The tasks become more challenging in the later rounds.\nAdversarial models are more robust against adversarial attacks than baseline models trained on the standard collection. The third round adversarial model has worse performance on the standard task than the standard model, likely due to overfitting. I’m curious about how the model performance would be like if it is trained on both adversarial and standard collection, but I didn’t find it in the paper.\nFig. 4. The comparison of performance on standard and adversarial tasks of models trained on standard ($S\\_i$) and adversarial data collection ($A\\_i$). The subscript $i$ indicates the number of training rounds. (Image source: Dinan et al. 2019) Another type of adversarial attack is to trick the detection model to mistakenly classify a toxic sentence as safe by replacing or scrambling a subset of characters. Kurita et al. (2019) developed a method of generating such model-agnostic adversarial attacks, incorporating several types of character-level perturbations:\n Character scrambling: randomly permute character positions. Homoglyph substitution: replace one or multiple letters with similar looking international letters. Dictionary-based near-neighbor replacement: find closest but distinct token in terms of Levenshtein distance. Distractor injection: inject distractor tokens by repeating random selected sequences of non-toxic tokens. Adversarial noise combining token obfuscation and distractor tokens leads to substantial performance degradation of a toxic classifier. Character-level perturbation degrades performance more than distractors.\nThe paper proposed two ways to resolve adversarial attacks:\n Adversarial training refers to training the model on a dataset with noise. However, you need to know the details of the incoming attacks in advance. And there is no guarantee that training samples with arbitrary noise would generalize to the test set. CDAE (contextual denoising autoencoder) uses character-level and contextual information to denoise obfuscated tokens. CDAE takes a noise sample to predict the denoised version. Still, you need to know what types of character-level perturbation can be applied to create noise samples. CDAE performs comparable to BERT, but not substantially better. Perspective API perspective API (www.perspectiveapi.com) is the most widely used commercial API for toxic content detection. Perspective trains machine learning models to provide scores for several different attributes: toxicity, severe toxicity, insult, profanity, identity attack, threat, and sexually explicit. Each score is a number between [0, 1], indicating how likely the message contains a given attribute (i.e. confidence of a binary classifier) and it does not signify the severity of the attribute.\nFig. 5. The overview of Perspective API scores. (Image source: About Perspective API) Gehman et al. (2020) measured the Perspective API toxicity scores of unprompted generations sampled from several pretrained language models. \u0026ldquo;Unprompted\u0026rdquo; means that the generation is only conditioned on the start-of-sentence tokens, without injecting any additional context. Noticeably, all the tested models get to the expected maximum toxicity \u0026gt; 0.5 after 100 generations. They also pointed out that training datasets for large LMs contain an non-negligible amount of toxic content.\nFig. 6. Perspective API toxicity scores of unprompted generations. Each model generates a pool of 10k samples and the expected maximum toxicity score is estimated via bootstrapping. (Image source: Gehman et al. 2020) They collected the RealToxicityPrompt dataset for studying toxicity in conditional language model generation. It contains 100k naturally occurring prompts with associated toxicity scores from Perspective API. Some prompts that do not contain any toxic language still can trigger very offensive completion.\nDespite of its popularity, Perspective API contains known biases, as summarized by Gehman et al. (2020):\n \u0026hellip; exhibit biases against minorities and suffer from low agreement in annotations, partially due to annotator identity influencing their perception of hate speech and differences in annotation task setup.\n Notably, recent work has found that systems are overestimating the prevalence of toxicity in text that contains a minority identity mention (e.g., \u0026ldquo;I’m a gay man\u0026rdquo;) or text by racial minorities (e.g., text in African American English). This is partially due to detectors’ over-reliance on lexical cues of toxicity (including swearwords, slurs, and other \u0026ldquo;bad\u0026rdquo; words).\n Prompt-based Detection Self-diagnosis (Schick, et al. 2021) is a process of exploiting the capacity of a pretrained language model to detect socially undesired attributes in its own outputs. The diagnosis depends on a predefined prompt template where the attributes are described in short text and measures the normalized probability of the model outputting \u0026ldquo;yes\u0026rdquo; versus \u0026ldquo;no\u0026rdquo;. Note that self-diagnosis does not need to access a labelled dataset for training.\n# Self-diagnosis template \u0026quot;{sample text}\u0026quot; Question: Does the above text contain {undesired attribute}? Answer: {yes/no} They use RealToxicityPrompts dataset and Perspective API for evaluation in the experiments. The self-diagnosis performance is positively correlated with the model size.\nFig. 7. Self-diagnosis abilities for identifying undesired attributes. The ground truth is provided by Perspective API. (Image source: Schick, et al. 2021) Detoxification Blacklisting Bad word filtering is a pretty intuitive and effective way to avoid explicit profane words in the language model generation. At decoding time, we can manually reduce the probabilities of blocked words to avoid sampling them. However, it is not perfect, as it is still possible to have unsafe content composed of safe tokens.\nVocabulary shifting (Gehman et al. 2020) learns a 2-dimensional representation of toxicity versus non-toxicity for every token in the vocabulary of the pretrained model. Then the representation that encodes the non-toxicity is used to boost the likelihood of non-toxic tokens at decoding time.\nPrompt-based Detox Self-debiasing (Schick et al. 2021) follows the similar idea as in self-diagnosis. It is a process for using the internal knowledge of a pretrained language model to reduce the probability of undesired attributes in the model generation.\n# Self-debiasing template, denoted as sdb(.) The following text contains {undesired attribute s}: {sample text x} Given an input prompt $\\mathbf{x}$, a textual description of undesired attributes $s$, and the language model $M$, self-debiasing computes the difference between the probability of next words without and with the self-debiasing template $\\text{sdb}(.)$:\n $$ \\Delta(w, \\mathbf{x}, s) = p_M(w\\vert\\mathbf{x}) - p_M(w\\vert\\text{sdb}(\\mathbf{x}, s)) $$ Because $\\text{sdb}(.)$ is expected to boost the probabilities of undesired words, $\\Delta(w, \\mathbf{x}, s)$ should be negative for undesirable words.\nIn self-diasing decoding, a scaling function of the probability difference $\\alpha(\\Delta(w, \\mathbf{x}, s)): \\mathbb{R}\\to[0,1]$ is used to alter the true sampling distribution,\n $$ \\tilde{p}_M(w\\vert\\mathbf{x}) \\propto \\alpha(\\Delta(w, \\mathbf{x}, s)) p_M(w\\vert\\mathbf{x}) $$ In the paper, they used a soft variant where the probabilities of the words with negative $\\Delta$ are reduced w.r.t. the magnitude of $\\Delta(w, \\mathbf{x}, s)$:\n $$ \\alpha(x)=\\begin{cases} 1 \u0026 \\text{ if } x\\geq 0 \\\\ e^{\\lambda\\cdot x} \u0026 \\text{ otherwise} \\end{cases} $$ Fig. 8. Self-diasing decoding can reduce the probabilities of undesirable attributes. The scores are provided by Perspective API. (Image source: Schick et al. 2021) There are a couple of major limitations in self-debiasing detoxification:\n The evaluation solely relies on Perspective API, so it cannot capture bias \u0026amp; toxicity attributes that are not covered by Perspective API, such as gender biases. Using human evaluation is another alternative but the scale is limited. Self-debiasing sometimes acts too aggressively and filters out harmless words and it does not maintain the same level of perplexity as the original model. The approach is constrained by the internal capacity of the model. For example, if the model is not aware of certain biases, it would not be able to correct them. Text Style Transfer Unsupervised style transfer can be used to translate offensive sentences into innocuous ones (Santos et al. 2018). The approach should work for non-parallel datasets, meaning that we only have access to two separate datasets of offensive and non-offensive samples, but not paired versions. To preserve the content when transferring the text into another style, a cycle consistency loss (Zhu et al. 2017) is adopted.\nFig. 9. The training process of a neural text style transfer algorithm using non-parallel data. (Image source: Santos et al. 2018) Let $s_i$ be the desired style ($i=0$ for offensive and $i=1$ for non-offensive), and $\\mathbf{x}^i_k$ be the $k$-th sample of style $s_i$, $k = 1, \\dots, n$. Both the encoder $E$ and decoder $G$ take a sample (or hidden state) along with a style label. The classifier $C$ predicts a probability distribution over the style labels given an input sample.\nFollowing the illustration in Fig. 9:\n The top branch of forward transfer is auto encoder: ​$E(\\mathbf{x}^i_k, s_i) \\to H^i_k \\to G(H^i_k, s_i) \\to \\hat{\\mathbf{x}}^{i\\to i}_k$. Two losses are computed: Reconstruction loss measures how well the decoder can reconstruct the sample back: $$ \\mathcal{L}_\\text{self} = \\mathbb{E}_{\\mathbf{x}^i_k \\sim \\mathcal{X}} [-\\log p_G(\\mathbf{x}_k^i \\mid E(\\mathbf{x}^i_k, s_i), s_i)] $$ The bottom branch of forward transfer: $E(\\mathbf{x}^i_k, s_i) \\to H^i_k \\to G(H^i_k, s_j) \\to \\hat{\\mathbf{x}}^{i\\to j}_k$ Classification loss measures the effectiveness of style transfer: $$ \\mathcal{L}_\\text{style_fwd} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_j \\mid \\hat{\\mathbf{x}}^{i\\to j}_k)] $$ The back transfer uses cycle consistency loss: $E(\\hat{\\mathbf{x}}^{i\\to j}_k, s_j) \\to H^{i\\to j}_k \\to G(H^{i\\to j}_k, s_i) \\to \\hat{\\mathbf{x}}^{i\\to j \\to i}_k$ The cycle consistency loss controls how well the transferred sample can be converted back to the original form to encourage content preservation: $$ \\mathcal{L}_\\text{cycle} = \\mathbb{E}_{\\mathbf{x}^i_k \\sim \\mathcal{X}} [-\\log p_G(\\mathbf{x}_k^i \\mid E(\\hat{\\mathbf{x}}^{i \\to j}_k, s_j), s_i)] $$ - The classification loss ensures that the back-transferred sample has the correct label: $$ \\mathcal{L}_\\text{style_back} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_i \\mid G(E(\\hat{\\mathbf{x}}^{i\\to j}_k, s_j), s_i))] $$ There is an additional supervised classification loss for training an accurate classifier: $$ \\mathcal{L}_\\text{class} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_i \\mid \\hat{\\mathbf{x}}^i_k)] $$ The final training objective is as follows and the encoder, decoder and classifier are jointly trained:\n $$ \\mathcal{L}(\\theta_E, \\theta_G, \\theta_C) = \\min_{E, G, C} \\mathcal{L}_\\text{self} + \\mathcal{L}_\\text{style_fwd} + \\mathcal{L}_\\text{cycle} + \\mathcal{L}_\\text{style_back}+ \\mathcal{L}_\\text{class} $$ Style Transformer (Dai et al. 2019) also aims to learn unsupervised text style transfer. Different from the encoder-decoder model in Santos et al. 2018, it learns a Transformer-based style transfer function $f_\\theta(\\mathbf{x}, s)$ for a given input sample $\\mathbf{x}$ and a desired style control variable $s$.\nFig. 10. The comparison of style transformer and previous models that depend on disentangled latent representation. (Image source: Dai et al. 2019) Without access to the parallel corpus, the style transformer adopts a discriminator to create supervision from non-parallel dataset.\nLet $s$ and $\\hat{s}$ be two mutually exclusive style variables and $\\mathbf{x}$ is a sample of style $s$, style transformer computes several losses:\n Self reconstruction loss: $\\mathcal{L}_\\text{self} = - p_\\theta (\\mathbf{x} \\vert \\mathbf{x}, s)$ Cycle-consistency loss: $\\mathcal{L}_\\text{cycle} = - p_\\theta (\\mathbf{x} \\vert f_\\theta(\\mathbf{x}, \\hat{s}), s)$ Style controlling loss: This is necessary because otherwise the model would simply learn to copy the input over. $$ \\mathcal{L}_\\text{style} = - p_\\phi(\\text{class} = 1 \\vert f_\\theta(\\mathbf{x}, \\hat{s}), \\hat{s}) $$ , where the discriminator is a simple binary classifier trained to optimize the negative log-likelihood of the correct style. The discriminator is trained by labelling\n $\\{(\\mathbf{x}, s), (f_\\theta(\\mathbf{x}, s), s), (f_\\theta(\\mathbf{x}, \\hat{s}), \\hat{s})\\}$ as positive class 1 $\\{(\\mathbf{x}, \\hat{s}), (f_\\theta(\\mathbf{x}, s), \\hat{s}), (f_\\theta(\\mathbf{x}, \\hat{s}), s)\\}$ as negative class 0. Fig. 11. The training process of Style Transformer. (Image source: Dai et al. 2019) Driven by the research question \u0026ldquo;Can we fine-tune a pre-trained language model to suggest civil rephrasings of rude comments using a dataset solely annotated in toxicity?\u0026rdquo;, Laugier et al. (2021) fine-tuned a pretrained text-to-text transformer with a denoising and cyclic auto-encoder loss.\nLet $s$ be the attribute of $\\mathbf{x}$ (e.g. \u0026ldquo;civil\u0026rdquo;) and $\\bar{s}$ be the other opposite attribute (e.g. \u0026ldquo;toxic\u0026rdquo;). These two attributes are mutually exclusive. The goal is to learn a mapping function $f_\\theta$ such that it translates $x$ to a new fluent sequence $y$ with target attribute $a$ while preserving $x$\u0026rsquo;s content.\nThe encoder-decoder model is trained with the loss:\n $$ \\mathcal{L} = \\lambda_\\text{DAE} \\mathcal{L}_\\text{DAE} + \\lambda_\\text{cycle} \\mathcal{L}_\\text{cycle} $$ The denoising auto-encoder loss is the loss for denoising auto-encoders, where $\\eta$ is a masking function same as in BERT training: $$ \\mathcal{L}_\\text{DAE} = \\mathbb{E}_{\\mathbf{x} \\sim \\mathcal{X}} [−\\log p_\\theta(\\mathbf{x} \\mid \\eta(\\mathbf{x}), s)] $$ The cycle consistency loss (Zhu et al. 2017) has $\\tilde{\\theta}$ to produce a non-differentiable pseudo-prediction $\\hat{\\mathbf{y}}$ and it does not take gradient backpropagation. $$ \\mathcal{L}_\\text{cycle} = \\mathbb{E}_{\\mathbf{x} \\sim \\mathcal{X}} [−\\log p_\\theta(\\mathbf{x} \\mid f_{\\tilde{\\theta}}(\\mathbf{x}, \\bar{s}), s)] $$ They used the above loss to fine-tune a T5 model, resulting in a model named CAE-T5. The conditioning is implemented like CTRL via control code (\u0026ldquo;civil\u0026rdquo; or \u0026ldquo;toxic\u0026rdquo;) prepended to the start of a sequence.\nAutomatic evaluation of the text style transferred results relies on three metrics:\n Accuracy: Classification accuracy measures how successful the style transfer is. Fluency: Fluency is commonly measured by perplexity by another separately trained LM on non-toxic samples. Content preservation: It is the content similarity between transferred and original sentences, measured by BLEU or embedding based content similarity. Human evaluation is also necessary but more costly.\nCompared to the baseline (Shen et al. 2017), the style transfer method by Santos et al. 2018 achieves better classification accuracy, better content preservation, but worse perplexity. CAE-T5 has worse classification accuracy, competitive content preservation, and better perplexity compared to a set of baselines including Style Transformer.\nControllable Generation We can try to avoid toxic outputs via controllable text generation. There are several popular approaches for steering a pretrained language model toward desired styles, topics or safety criteria:\n Apply guided decoding strategies and select desired outputs at test time. Optimize for the most desired outcomes via good prompt design. Fine-tune the base model or steerable layers to do conditioned content generation. Read more in my last post on controllable neural text generation, introducing methods like AutoPrompt, CTRL, PPLM, GeDi and many more.\nGehman et al. (2020) experimented with both data-based (supervised fine-tuning, CTRL training) and decoding-based (vocabulary shifting, blocked word filtering, PPLM) methods for language model detoxification. They found that toxicity control tokens (CTRL) and swear word filters are less successful than more computationally or data-intensive methods like fine-tuning on non-toxic corpora and PPLM.\nFig. 12. Table list expected maximum toxicity score over 25 generations (left) and the empirical probability of generating toxic text over 25 generations (right) for several detoxification methods. Scores are provided by Perspective API. (Image source: Gehman et al., 2020) System-level Safety Solution Xu et al. (2020) presented a thorough system-level design for building safe chatbots.\nFig. 13. Illustration of a safe chat bot system. (Image source: Xu et al. 2020) They consider four general strategies in the recipes for making the bot safer:\n Detect unsafe content: Adopt a classifier for detecting unsafe language on both the input and output side, as an extra safety layer on top of the language model. The classifier is trained on an enhanced version of the Jigsaw toxic comment dataset (safe vs unsafe binary labels), extended with adversarial human attacks (Dinan et al. 2019) and semi-supervision (Khatri et al. 2018). The safety classifier can be used on both the user input and the model output. If it detects unsafe content, the system is configured to return a canned, predefined response (e.g \u0026ldquo;I\u0026rsquo;m sorry I\u0026rsquo;m not sure what to say.\u0026quot;), or decide to change topics. It is worthy noting that this approach relies on a high-quality classifier. The conversation experience would be drastically disrupted with too many false positives. Bot adversarial dialogue (BAD) safety: The idea is to collect data on humans adversarially probing the system to make mistakes and then use the data for further training. During annotation, human labellers can tag the bot\u0026rsquo;s response with an unsafe-safe rating based on the percentage of population who may consider it as unsafe. This probing data collection is used to train a multi-turn safety classifier, predicting whether a response is offensive given the dialogue context. Safe generation: Train a model that is less likely to output unsafe responses. A predefined list of unsafe words/n-grams can be blocked at decoding time. The pretraining data is filtered by the above safety classifier, or filtered based on known authors. The problem with pre-training only with safe datasets is that if the model has never seen toxic language during training, it would not know how to respond at test time (OOD; e.g. may just copy the offensive content). They instead prepare a collection of training samples where the last utterance is labelled as \u0026ldquo;unsafe\u0026rdquo; and then attach a safe response following that unsafe attack. Then the model is fine-tuned on the \u0026ldquo;baked-in\u0026rdquo; safety data. Do CTRL style training by assigning \u0026ldquo;safe\u0026rdquo; vs \u0026ldquo;unsafe\u0026rdquo; label using the safety classifier. Avoid sensitive topics: In order to avoid sensitive topics (politics, religion, drug use, medical advice, and NSFW and relationships/dating), they trained a multi-class classifier to detect those topics using crowdsourced lists of subreddits. The classifier can be periodically re-trained to capture the changes within topics over time. A small validation set is collected by recruiting crowdsourced workers to discuss one of the target topics. Gender bias mitigation: They used CTRL style training to mitigate gender biases. Precisely, given a gendered word list, tag the training samples with $F^0 M^0$, $F^0 M^+$, $F^+ M^+$, and $F^+ M^0$ labels, indicating whether the response contains female / male words ($+$ contains, $-$ does not contain). At test time, the system runs with a control label $F^0 M^0$ to avoid outputting gender specific words. Appendix: Datasets (*Only datasets in English are listed here.)\nHate Speech and Offensive Language Dataset (2017): contains about 25k tweets, each labelled manually as one of three categories: hate speech, offensive but not hate speech, or neither offensive nor hate speech. [Download]\nJigsaw Toxic Comments Classification Dataset (2018): contains about 160k examples extracted from Wikipedia discussion pages, each annotated for 7 classes: toxic, severe toxic, obscene, threat, insult, identity hate and non-toxic. The labelling process involved 5000 crowdsourced annotators. [Download]\nJigsaw Unintended Bias in Toxicity Classification Dataset (2019): contains about 2 Millions comments from the Civil Comments platform, which shut down in 2017. This data is annotated for toxicity, toxicity sub-types, and mentions of identities, which enables evaluation of unintended bias with respect to identity mentions. [Download]\nOLID (Offensive Language Identification Dataset; 2019): contains 14,100 English tweets, annotated according to the three-level taxonomy as described here. [Download]\nSOLID (Semi-Supervised Offensive Language Identification Dataset; 2020): contains 9+ Millions tweets annotated following OLID\u0026rsquo;s three level taxonomy. [Download]\nRealToxicityPrompts dataset (2020): contains 100k sentence snippets from the web with Perspective API toxicity scores for studying the risk of neural toxic degeneration in language models. [Download]\nCitation Cited as:\n Weng, Lilian. (Mar 2021). Reducing toxicity in language models. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/.\n Or\n@article{weng2021toxic, title = \u0026quot;Reducing Toxicity in Language Models.\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Mar\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/\u0026quot; } References [1] Vidgen, et al. \u0026ldquo;Challenges and frontiers in abusive content detection.\u0026quot; Workshop on Abusive Language Online 2019.\n[2] Zampieri et al. \u0026ldquo;Predicting the type and target of offensive posts in social media.\u0026quot; NAACL 2019.\n[3] Vidgen \u0026amp; Deczynski. \u0026ldquo;Directions in abusive language training data, a systematic review: Garbage in, garbage out.\u0026quot; PLoS ONE 15(12): e0243300 (2020).\n[4] Davidson et al. \u0026ldquo;Automated hate speech detection and the problem of offensive language.\u0026quot; ICWSM 2017.\n[5] Khatri et al. \u0026ldquo;Detecting offensive content in open-domain conversations using two stage semi-supervision.\u0026quot; NeuriIPS CONVAI Workshop 2018.\n[6] Rosenthal et al. \u0026ldquo;A Large-Scale Semi-Supervised Dataset for Offensive Language Identification\u0026rdquo; arXiv:2004.14454 (2020).\n[7] Pavlopoulos et al. \u0026ldquo;Toxicity Detection: Does Context Really Matter?\u0026quot; arXiv:2006.00998 (2020).\n[8] Dinan et al. \u0026ldquo;Build it, break it, fix it for dialogue safety: Robustness from adversarial human attack.\u0026quot; arXiv:1908.06083 (2019).\n[9] Kurita et al. \u0026ldquo;Towards Robust Toxic Content Classification\u0026rdquo; arXiv:1912.06872 (2019)\n[10] Santos et al. \u0026ldquo;Fighting offensive language on social media with unsupervised text style transfer.\u0026quot; arXiv:1805.07685 (2018)\n[11] Dai et al. \u0026ldquo;Style Transformer: Unpaired Text Style Transfer without Disentangled Latent Representation\u0026rdquo; ACL 2019.\n[12] Laugier et al. \u0026ldquo;Civil Rephrases Of Toxic Texts With Self-Supervised Transformers\u0026rdquo; arXiv:2102.05456 (2021). code\n[13] Schick et al. \u0026ldquo;Self-Diagnosis and Self-Debiasing: A Proposal for Reducing Corpus-Based Bias in NLP\u0026rdquo; arXiv:2103.00453 (2021).\n[14] Gehman et al. \u0026ldquo;RealToxicityPrompts: Evaluating Neural Toxic Degeneration in Language Models\u0026rdquo; EMNLP 2020.\n[15] Xu et al. \u0026ldquo;Recipes for Safety in Open-domain Chatbots\u0026rdquo; arXiv:2010.07079 (2020).\n","permalink":"https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/","summary":"Large pretrained language models are trained over a sizable collection of online data. They unavoidably acquire certain toxic behavior and biases from the Internet. Pretrained language models are very powerful and have shown great success in many NLP tasks. However, to safely deploy them for practical real-world applications demands a strong safety control over the model generation process.\nMany challenges are associated with the effort to diminish various types of unsafe content:","title":"Reducing Toxicity in Language Models"},{"content":"[Updated on 2021-02-01: Updated to version 2.0 with several work added and many typos fixed.] [Updated on 2021-05-26: Add P-tuning and Prompt Tuning in the \u0026ldquo;prompt design\u0026rdquo; section.] [Updated on 2021-09-19: Add \u0026ldquo;unlikelihood training\u0026rdquo;.]\nThere is a gigantic amount of free text on the Web, several magnitude more than labelled benchmark datasets. The state-of-the-art language models (LM) are trained with unsupervised Web data in large scale. When generating samples from LM by iteratively sampling the next token, we do not have much control over attributes of the output text, such as the topic, the style, the sentiment, etc. Many applications would demand a good control over the model output. For example, if we plan to use LM to generate reading materials for kids, we would like to guide the output stories to be safe, educational and easily understood by children.\nHow to steer a powerful unconditioned language model? In this post, we will delve into several approaches for controlled content generation with an unconditioned langage model. Note that model steerability is still an open research question. Each introduced method has certain pros \u0026amp; cons.\n Apply guided decoding strategies and select desired outputs at test time. Optimize for the most desired outcomes via good prompt design. Fine-tune the base model or steerable layers to do conditioned content generation. In the following discussion, we assume we have access to a pretrained generative language model $p_\\theta$. The model has learned the distribution over token sequences by optimizing for the next token prediction: $ \\mathcal{L}_\\text{ML} = - \\sum_t \\log p_\\theta(x_t \\vert x_{\u0026lt;t}) $.\nDecoding Strategies By adopting different decoding methods, we can place restrictions or preferences on the sampling process to alter the generated samples without modifying any model weights. Even though decoding strategies do not change the values of any trainable parameter, it is a quite important component.\nCommon Decoding Methods Since the final layer of the model predicts logits $o$ over the vocabulary space, the next token can be sampled by applying softmax with temperature $T$. The probability of sampling the $i$-th token is\n $$ p_i \\propto \\frac{\\exp(o_i / T)}{\\sum_j \\exp(o_j/T)} $$ A low temperature would make the distribution sharper and a high value makes it softer.\nGreedy search: Always pick the next token with the highest probability, equivalent to setting temperature $T=0$. However, it tends to create repetitions of phrases, even for well-trained models.\nBeam search: It essentially does breadth-first search, one token per tree level, but with a limited bandwidth. At each level of the search tree, beam search keeps track of $n$ (named \u0026ldquo;beam width\u0026rdquo;) best candidates and expands all the successors of these candidates in the next level. Beam search could stop expanding a node if it hits the EOS (end-of-sentence) token.\nHowever, maximization-based decoding does not guarantee high-quality generation.\n Fig. 1. The probability assigned to the next token by beam search versus by humans. The human selected tokens have much higher variance in predicted probability and thus more surprising. (Image source: Holtzman et al. 2019) Top-k sampling (Fan et al., 2018): At each sampling step, only the top $k$ most likely tokens are selected and the probability mass is redistributed among them. In Fan et al., 2018, the authors proposed to use top-k random sampling where the next token is randomly selected among the top $k$ most likely candidates and they argued that this approach can generate more novel and less repetitive content than beam search.\nNucleus sampling (Holtzman et al. 2019): Also known as \u0026ldquo;Top-p sampling\u0026rdquo;. One drawback of top-k sampling is that the predefined number $k$ does not take into consideration how skewed the probability distribution might be. The nucleus sampling selects the smallest set of top candidates with the cumulative probability exceeding a threshold (e.g. 0.95) and then the distribution is rescaled among selected candidates.\nBoth top-k and nucleus sampling have less repetitions with a proper set of hyperparameters.\nPenalized sampling (Keskar et al. 2019): To avoid the common failure case of generating duplicate substrings, the CTRL paper proposed a new sampling method to penalize repetitions by discounting the scores of previously generated tokens. The probability distribution for the next token with repetition penalty is defined as:\n $$ p_i = \\frac{\\exp(o_i / (T \\cdot \\mathbb{1}(i \\in g)))}{\\sum_j \\exp(o_j / (T \\cdot \\mathbb{1}(j \\in g)))} \\quad \\mathbb{1}(c) = \\theta \\text{ if the condition }c\\text{ is True else }1 $$ where $g$ contains a set of previously generated tokens, $\\mathbb{1}(.)$ is an identity function. $\\theta=1.2$ is found to yield a good balance between less repetition and truthful generation.\nGuided Decoding All the above standard decoding strategies sample tokens according to the predicted probability, with no additional information. Our preferences on topic or sentiment can be baked into the candidate ranking function to guide the sample generation by altering the candidate ranking score. The ranking score for token selection at each decoding step can be set as a combination of LM log-likelihood and a set of desired feature discriminators. The features are designed to quantify human preferences by heuristics (Ghazvininejad et al., 2017), supervised learning (Holtzman et al., 2018) or RL (Li et al., 2017).\nGhazvininejad et al. (2017) built a system called \u0026ldquo;Hafez\u0026rdquo; for generating poetry in desired style by adjusting sampling weights in beam search at decoding steps. The likelihood of sampling for the next token $x_{t+1}$ at step $t$ is augmented by a scoring function:\n $$ \\text{score}(x_{t+1}, b_t) = \\text{score}(b_t) + \\log p(x_{t+1}) + \\color{green}{\\sum_i \\alpha_i f_i(x_{t+1})} $$ where $\\log p(x_{t+1})$ is the log-likelihood predicted by LM. $\\text{score}(b_t)$ is the accumulated score of the already-generated words in the current beam state $b_t$. The green part can incorporate many different features for steering the style of the output. A set of feature functions $f_i(.)$ define the preferences and the associated weights $alpha_i$ work like \u0026ldquo;control knobs\u0026rdquo; that can be easily customized at decoding time. Features can measure a variety of attributes and can be easily combined; for example,\n whether $x_{t+1}$ exists in a bag of desired or banned topical words. whether $x_{t+1}$ indicates certain sentiments. whether $x_{t+1}$ is a repeated token (and thus $f_i$ needs to take the history as input too). the length of $x_{t+1}$ if longer or shorter words are in particular preferred. Similar to Hafez, Baheti et al. (2018) manually designed features for ranking and altered the sampling distribution by appending similarity scores between topic distribution or embeddings of the context and the completion.\nHoltzman et al. (2018) adopted a set of learned discriminators, each specializing in a different principle of communication guided by Grice’s maxims: quality, quantity, relation and manner. The discriminators learn to encode these desired principles by measuring repetition, entailment, relevance, and lexical diversity, respectively. Given some ground truth completion, all the discriminator models are trained to minimize the ranking log-likelihood, $\\log\\sigma(f_i(y_g) - f_i(y))$, because the gold continuation $y_g$ is expected to obtain a higher score than the generated one $y$. Here the weight coefficients $\\alpha_i$ are also learned to minimize the score difference between the golden standard and the generated completion. Discriminative Adversarial Search (DAS; Scialom et al., 2020) is inspired by GAN and trains the discriminator to tell apart human created text from machine generated text. The discriminator predicts a label for each token instead of for the entire sequence. The discriminator logprob is added to the score to guide sampling towards the human-written style.\nMeister et al. (2020) studied beam search in a regularized decoding framework:\n $$ \\mathbf{y}^* = \\arg\\max_{\\mathbf{y}\\in\\mathcal{Y}} \\big( \\underbrace{\\log p_\\theta(\\mathbf{y}\\vert\\mathbf{x})}_\\text{MAP} - \\underbrace{\\lambda\\mathcal{R}(\\mathbf{y})}_\\text{regularizer} \\big) $$ Since we expect maximum probability to have minimum surprise, the surprisal of a LM at time step $t$ can be defined as follows:\n $$ \\begin{aligned} u_0(\\texttt{BOS}) \u0026= 0 \\text{ ; BOS is a placeholder token for the beginning of a sentence.}\\\\ u_t(y) \u0026= -\\log P_\\theta(y \\vert \\mathbf{x}, \\mathbf{y}_{The MAP (maximum a posteriori) part demands for sequences with maximum probability given context, while the regularizer introduces other constraints. It is possible a global optimal strategy may need to have a high-surprisal step occasionally so that it can shorten the output length or produce more low-surprisal steps afterwards.\nBeam search has gone through the test of time in the field of NLP. The question is: If we want to model beam search as exact search in a regularized decoding framework, how should $\\mathcal{R}(\\mathbf{y})$ be modeled? The paper proposed a connection between beam search and the uniform information density (UID) hypothesis.\n \u0026ldquo;The uniform information density hypothesis (UID; Levy and Jaeger, 2007) states that—subject to the constraints of the grammar—humans prefer sentences that distribute information (in the sense of information theory) equally across the linguistic signal, e.g., a sentence.\u0026rdquo;\n In other words, it hypothesizes that humans prefer text with evenly distributed surprisal. Popular decoding methods like top-k sampling or nuclear sampling actually filter out high-surprisal options, thus implicitly encouraging the UID property in output sequences.\nThe paper experimented with several forms of regularizers:\n Greedy: $\\mathcal{R}_\\text{greedy}(\\mathbf{y}) = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - \\min_{y' \\in \\mathcal{V}} u_t(y') \\big)^2$; if set $\\lambda \\to \\infty$, we have greedy search. Note that being greedy at each individual step does not guarantee global optimality. Variance regularizer: $\\mathcal{R}_\\text{var}(\\mathbf{y}) = \\frac{1}{\\vert\\mathbf{y}\\vert}\\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - \\bar{u} \\big)^2$ , where $\\bar{u}$ is the average surprisal over all timesteps. It directly encodes the UID hypothesis. Local consistency: $\\mathcal{R}_\\text{local}(\\mathbf{y}) = \\frac{1}{\\vert\\mathbf{y}\\vert}\\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - u_{t-1}(y_{t-1}) \\big)^2$; this decoding regularizer encourages adjacent tokens to have similar surprisal. Max regularizer: $\\mathcal{R}_\\text{max}(\\mathbf{y}) = \\max_t u_t(y_t)$ penalizes the maximum compensation of surprisal. Squared regularizer: $\\mathcal{R}_\\text{square}(\\mathbf{y}) = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} u_t(y_t)^2$ encourages all the tokens to have surprisal close to 0. An experiment with greedy regularizers showed that larger $\\lambda$ results in better performance (e.g. measured by BLEU for NMT task) and lower std dev of surprisal.\nFig. 2. The plot of BLEU and std. dev of surprisals as functions of the strength of the regularizer $\\lambda$. The subgraph in grey shows the relationship between BLEU and surprisal std. dev. (Image source: Meister et al. 2020) A default beam search would have text generation of decreased quality when beam size increases. Regularized beam search greatly helps alleviate this issue. A combined regularizer further improves the performance. In their experiments for NMT, they found $\\lambda=5$ for greedy and $\\lambda=2$ for squared work out as the optimal combined regularizer.\nFig. 3. The plot of BLEU of a function of beam size (left) and BLEU scores for translations created by different regularized decoding strategies. (Image source: Meister et al. 2020) Guided decoding essentially runs a more expensive beam search where the sampling probability distribution is altered by side information about human preferences.\nTrainable Decoding Given a trained language model, Gu et al (2017) proposed a trainable greedy decoding algorithm to maximize an arbitrary objective for sampling sequences. The idea is based on the noisy, parallel approximate decoding (NPAD). NPAD injects unstructured noise into the model hidden states and runs noisy decoding multiple times in parallel to avoid potential degradation. To take a step further, trainable greedy decoding replaces the unstructured noise with a learnable random variable, predicted by a RL agent that takes the previous hidden state, the previous decoded token and the context as input. In other words, the decoding algorithm learns a RL actor to manipulate the model hidden states for better outcomes.\nGrover et al. (2019) trained a binary classifier to distinguish samples from data distribution and samples from the generative model. This classifier is used to estimate importance weights for constructing a new unnormalized distribution. The proposed strategy is called likelihood-free importance weighting (LFIW).\nLet $p$ be the real data distribution and $p_\\theta$ be a learned generative model. A classical approach for evaluating the expectation of a given function $f$ under $p$ using samples from $p_\\theta$ is to use importance sampling.\n $$ \\mathbb{E}_{\\mathbf{x}\\sim p} [f(\\mathbf{x})] = \\mathbb{E}_{\\mathbf{x}\\sim p_\\theta} \\Big[\\frac{p(\\mathbf{x})}{p_\\theta(\\mathbf{x})} f(\\mathbf{x})\\Big] \\approx \\frac{1}{N} \\sum_{i=1}^N w(\\mathbf{x}_i)f(\\mathbf{x}_i) $$ However, $p(\\mathbf{x})$ can only be estimated via finite datasets. Let $c_\\phi: \\mathcal{X} \\to [0,1]$ be a probabilistic binary classifier for predicting whether a sample $\\mathbf{x}$ is from the true data distribution ($y=1$). The joint distribution over $\\mathcal{X}\\times\\mathcal{Y}$ is denoted as $q(\\mathbf{x}, y)$.\n $$ q(\\mathbf{x}\\vert y) = \\begin{cases} p_\\theta(\\mathbf{x}) \u0026 \\text{ if }y=0\\text{; predicted to be generated data} \\\\ p(\\mathbf{x}) \u0026 \\text{ otherwise; from the true data distribution} \\end{cases} $$ Then if $c_\\phi$ is Bayes optimal, the importance weight can be estimated by:\n $$ w_\\phi(\\mathbf{x}) = \\frac{p(\\mathbf{x})}{p_\\theta(\\mathbf{x})} = \\frac{q(\\mathbf{x} \\vert y=1)}{q(\\mathbf{x} \\vert y=0)} = \\frac{q(y=0)}{q(y=1)} \\frac{q(y=1 \\vert \\mathbf{x})}{q(y=0 \\vert \\mathbf{x})} = \\gamma \\frac{c_\\phi(\\mathbf{x})}{1 - c_\\phi(\\mathbf{x})} $$ where $\\gamma = \\frac{q(y=0)}{q(y=1)} \u0026gt; 0$ is a fixed odd ratio.\nSince we cannot learn a perfect optimal classifier, the importance weight would be an estimation $\\hat{w}_\\phi$. A couple of practical tricks can be applied to offset cases when the classifier exploits artifacts in the generated samples to make very confident predictions (i.e. very small importance weights):\n Self-normalization: normalize the weight by the sum $\\hat{w}_\\phi(\\mathbf{x}_i) / \\sum_{j=1}^N \\hat{w}_\\phi(\\mathbf{x}_j)$. Flattening: add a power scaling parameter $\\alpha \u0026gt; 0$, $\\hat{w}_\\phi(\\mathbf{x}_i)^\\alpha$. Clipping: specify a lower bound $\\max(\\hat{w}_\\phi(\\mathbf{x}_i), \\beta)$. To sample from an importance resampled generative model, $\\mathbf{x}\\sim p_{\\theta, \\phi}(\\mathbf{x}) \\propto p_\\theta(\\mathbf{x})\\hat{w}_\\phi(\\mathbf{x})$, they adopt SIR (Sampling-Importance-Resampling),\nFig. 4. The algorithm for sampling from a generative model according to importance weights $\\hat{w}(\\mathbf{x}\\_i)$ using SIR. (Image source: Grover et al., 2019)) Deng et al., 2020 proposed to learn a EBM to steer a LM in the residual space, $P_\\theta(x) \\propto P_\\text{LM}(x)\\exp(-E_\\theta(x))$, where $P_\\theta$ is the joint model; $E_\\theta$ is the residual energy function to be learned. If we know the partition function $Z$, we can model the generative model for generative a sequence $x_{p+1}, \\dots, x_T$ as:\n $$ P_\\theta(x_{p+1:T}\\vert x_{1:p}) = \\frac{P_\\text{LM}(x_{p+1:T}\\vert x_{1:p}) \\exp(-E_\\theta(x_{1:T}))}{Z_\\theta(x_{1:p})} $$ The goal is to learn the parameters of the energy function $E_\\theta$ such that the joint model $P_\\theta$ gets closer to the desired data distribution. The residual energy function is trained by noise contrastive estimation (NCE), considering $P_\\theta$ as the model distribution and $P_\\text{LM}$ as the noise distribution:\n $$ \\theta = \\arg\\max_{\\theta} \\mathbb{E}_{x^+ \\sim P_\\text{data}} \\log\\frac{1}{1+\\exp(E_\\theta(x^+))} + \\mathbb{E}_{x^- \\sim P_\\text{LM}} \\log\\frac{1}{1+\\exp(-E_\\theta(x^-))} $$ However, the partition function is intractable in practice. The paper proposed a simple way to first sample from the original LM and then to resample from them according to the energy function. This is unfortunately quite expensive.\nFig. 5. Top k samples from the base LM are resampled according to the residual energy function. (Image source: Deng et al., 2020) Smart Prompt Design Large language models have been shown to be very powerful on many NLP tasks, even with only prompting and no task-specific fine-tuning (GPT2, GPT3. The prompt design has a big impact on the performance on downstream tasks and often requires time-consuming manual crafting. For example, factual questions can gain a big boost with smart prompt design in \u0026ldquo;closed-book exam\u0026rdquo; (Shin et al., 2020, Jiang et al., 2020)). I’m expecting to see an increasing amount of literature on automatic smart prompt design.\nGradient-based Search AutoPrompt (Shin et al., 2020; code) is a method to automatically create prompts for various tasks via gradient-based search. AutoPrompt constructs a prompt by combining the original task inputs $x$ with a collection of trigger tokens $x_\\text{trig}$ according to a template $\\lambda$. The trigger tokens are shared across all inputs and thus universally effective.\nFig. 6. The overview of AutoPrompt. The trigger tokens are retrieved to optimize for the target outputs across all inputs. (Image source: Shin et al., 2020) The universal trigger tokens are identified using a gradient-guided search strategy same as in Wallace et al., 2019. The universal setting means that the trigger tokens $x_\\text{trig}$ can optimize for the target output $\\tilde{y}$ for all inputs from a dataset:\n $$ x_\\text{trig} = \\arg\\min_{x’_\\text{trig}} \\mathbb{E}_{x\\sim\\mathcal{X}} [\\mathcal{L}(\\tilde{y}, f(x’_\\text{trig}; x))] $$ The search operates in the embedding space. The embedding of every trigger token $e_{\\text{trig}_i}$ is first initialized to some default value and then gets updated to minimize the first-order Taylor expansion of the task-specific loss around the current token embedding:\n $$ e^{(t+1)}_\\text{trig} = \\arg\\min_{e\\in\\mathcal{V}} [e - e^{(t)}_{\\text{trig}_i}]^\\top \\nabla_{e^{(t)}_{\\text{trig}_i}} \\mathcal{L} $$ where $\\mathcal{V}$ refers to the embedding matrix of all the tokens. $\\nabla_{e^{(t)}_{\\text{trig}_i}} \\mathcal{L}$ is the average gradient of the task loss over a batch at iteration $t$. We can brute-force the optimal $e$ by a $\\vert \\mathcal{V} \\vert d$-dimensional dot product, which is cheap and can be computed in parallel.\nFig. 7. We search for trigger tokens by updating their embeddings with the gradient of the task loss per batch. (Image source: Wallace et al., 2019) The above token replacement method can be augmented with beam search. When looking for the optimal token embedding $e$, we can pick top-$k$ candidates instead of a single one, searching from left to right and score each beam by $\\mathcal{L}$ on the current data batch.\nFig. 8. Example prompts discovered by AutoPrompt for different tasks. (Image source: Shin et al., 2020) Smart prompt design essentially produces efficient context that can lead to desired completion. Motivated by this observation, Li \u0026amp; Liang (2021) proposed Prefix-Tuning which assigns a small number of trainable parameters at the beginning of an input sequence (named \u0026ldquo;prefix\u0026rdquo;) to steer a LM, $[\\text{PREFIX}; x; y]$. Let $\\mathcal{P}_\\text{idx}$ be a set of prefix indices and $\\text{dim}(h_i)$ be the embedding size. The prefix parameters $P_\\theta$ has the dimension $\\vert\\mathcal{P}_\\text{idx}\\vert \\times \\text{dim}(h_i) $ and the hidden state takes the form:\n $$ h_i = \\begin{cases} P_\\theta[i,:], \u0026 \\text{if }i \\in \\mathcal{P}_\\text{idx}\\\\ \\text{LM}_\\phi(z_i, h_{Note that only $P_\\theta$ is trainable and the LM parameters $\\phi$ is frozen during training.\nFig. 9. Illustrations of fine-tuning versus prefix-tuning. (Image source: Li \u0026 Liang 2021) The prefix parameters do not tie to any embeddings associated with the real words and thus they are more expressive for steering the context. Direct optimizing $P_\\theta$ unfortunately results in poor performance. To reduce the difficulty associated with high dimensionality training, the matrix $P_\\theta$ is reparameterized by a smaller matrix $P'_\\theta \\in \\mathbb{R}^{\\vert\\mathcal{P}_\\text{idx}\\vert \\times c}$ and a large feed forward network $\\text{MLP}_\\theta \\in \\mathbb{R}^{c\\times \\text{dim}(h_i)}$.\nThe performance increases with the prefix length $\\vert\\mathcal{P}_\\text{idx}\\vert$ up to some value. And this value varies with tasks.\nFig. 10. Task performance, summarization (left) and table-to-text (right), as a function of prefix length. (Image source: Li \u0026 Liang 2021) A few other interesting learnings from their ablation studies include:\n Tuning only the embedding layer (without prefix) is not sufficiently expressive. Placing the trainable parameter between $x$ and $y$, $[x; \\text{INFIX}; y]$, slightly underperforms prefix-tuning, likely because it only affects the context for $y$ while prefix affects both. Random initialization of $P_\\theta$ leads to low performance with high variance. In contrast, initializing $P_\\theta$ with activations of real words improves generation, even the words are irrelevant to the task. Fine-tuned models achieve better task performance but they can fail in the low data regime. Both AutoPrompt and Prefix-Tuning were found to outperform fine-tuning in the regime where the training dataset is small (i.e. $10^2-10^3$ samples). As an alternative to fine-tuning, prompt design or learning the context embedding is much cheaper. AutoPrompt improves the accuracy for sentiment classification a lot more than manual prompts and achieves similar performance as linear probing. For the NLI task, AutoPrompt obtains higher accuracy than linear probing. It is able to retrieve facts more accurately than manual prompts too. In low data regime, Prefix-Tuning achieves performance comparable with fine-tuning on table-to-text generation and summarization.\nTwo successive works, P-tuning (Liu et al. 2021; code) and Prompt Tuning (Lester et al. 2021), follow the similar idea of explicit training continuous prompt embeddings but with a few different choices over the trainable parameters and architecture. Different from Prefix-Tuning which concatenates continuous prompt tokens in every hidden state layer of the transformer, both P-tuning and Prompt Tuning non-invasively add continuous prompts only in the input to work well.\nLet $[P_i]$ be the $i$-th token in the prompt template of P-tuning (Liu et al. 2021), we can denote a prompt as a sequence $T=\\{[P_{0:i}], \\mathbf{x}, [P_{i+1:m}], \\mathbf{y}\\}$. Each token $[P_i]$ does not have to be a real token in the model vocabulary (\u0026ldquo;pseudo-token\u0026rdquo;), and thus the encoded template $T^e$ looks like the following and the pseudo-token hidden state can be optimized with gradient descent.\n $$ T^e = \\{ h_0, \\dots, h_i, \\text{embed}(\\mathbf{x}), h_{i+1}, \\dots, h_m, \\text{embed}(\\mathbf{y})\\} $$ Fig. 11. The illustration of P-tuning. Sometimes, adding a few task-related anchor tokens, such as “capital” in the figure, can bring further improvement. (Image source: Liu et al. 2021) There are two major optimization challenges in P-tuning:\n Discreteness: The word embedding of a pretrained language model are highly discrete. It is hard to optimize $h_i$ if they are intialized at random. Association: $h_i$ should be dependent on each other. Thus they develop a mechanism to model this dependency by training a light-weighted LSTM-based prompt encoder: $$ h_i = \\text{MLP}([\\text{LSTM}(h_{0:i}): \\text{LSTM}(h_{i:m})]) $$ P-tuning is more flexible than prefix-tuning, as it inserts trainable tokens in the middle of a prompt not just at the beginning. The usage of task-specific anchor tokens is like combining manual prompt engineering with trainable prompts.\nPrompt Tuning (Lester et al. 2021) largely simplifies the idea of prefix tuning by only allowing an additional $k$ tunable tokens per downstream task to be prepended to the input text. The conditional generation is $p_{\\theta, \\theta_P}(Y \\vert [P; X])$, where $P$ is the \u0026ldquo;pseudo prompt\u0026rdquo; with parameters $\\theta_P$ trainable via back-propagation. Both $X$ and $P$ are embedding vectors and we have $X \\in \\mathbb{R}^{n \\times d^e}, P \\in \\mathbb{R}^{k \\times d^e}$ and $[P;X] \\in \\mathbb{R}^{(n+k) \\times d^e}$, where $d^e$ is the embedding space dimensionality.\n Prompt tuning produces competitive results as model fine-tuning when the model gets large (billions of parameters and up). This result is especially interesting given that large models are expensive to fine-tune and execute at inference time. With learned task-specific parameters, prompt tuning achieves better transfer learning when adapting to new domains. It outperforms fine-tuning on domain shift problems. They also showed that prompt ensembling of multiple prompts for the same task introduces further improvement. Fig. 12. The illustration of how Prompt Tuning works. (Image source: Lester et al. 2021) The experiments investigated several prompt initialization schemes:\n Random initialization by uniformly sampling from [-0.5, 0.5]; Sample embeddings of top 5000 common tokens; Use the embedding values of the class label strings. If we don\u0026rsquo;t have enough class labels to initialize the soft-prompt, we fall back to scheme 2. Random initialization performs noticeably worse than the other two options. Fig. 13. The effect of (a) different prompt initialization schemes and (b) different prompt lengths. (Image source: Lester et al. 2021) The pre-training objectives also have a big impact on the quality of prompt tuning. T5’s “span corruption” is not a good option here.\nPrompt tuning is found to be less likely to overfit to a specific dataset. To evaluate the robustness to data shifting problem, they trained the model on one dataset of one task and evaluated it on the test dataset but in a different domain. Prompt tuning is more resilient and can generalize to different domains better.\nFig. 14. Prompt tuning is more resilient to domain shift between train and test sets. (Image source: Lester et al. 2021) Heuristic-based Search Paraphrasing is a quick way to explore more prompts similar to the known version, which can be done via back-translation. Using back-translation, the initial prompt is translated into $B$ candidates in another language and then each is translated back into $B$ candidates in the original language. The resulting total $B^2$ candidates are scored and ranked by their round-trip probabilities.\nRibeiro et al (2018) identified semantically equivalent adversaries (SEA) by generating a variety of paraphrases $\\{x'\\}$ of input $x$ until it triggers a different prediction of target function $f$:\n $$ \\begin{aligned} SEA(x, x') \u0026= \\mathbb{1}[\\text{SemEq}(x, x') \\land f(x) \\neq f(x')] \\\\ \\text{where SemEq}(x, x') \u0026= \\mathbb{1}[\\min\\Big(1, \\frac{p(x'\\vert x)}{p(x\\vert x)} \\Big) \\geq \\tau] \\end{aligned} $$ The rules extracted from SEA are considered as \u0026ldquo;bugs\u0026rdquo; in the model. Applying those rules as data augmentation in model training helps robustify the model and fix bugs.\nJiang et al (2020) attempts to validate whether a trained language model knows certain knowledge by automatically discovering better prompts to query. Within the scope of knowledge retrieval where factual knowledge is represented in the form of a triple $\\langle x, r, y \\rangle$ (subject, relation, object). The prompts can be mined from training sentences (e.g. Wikipedia description) or expanded by paraphrase.\nInterestingly some small modifications in the prompts may lead to big gain, as shown in Fig. X.\nFig. 15. Small modifications in prompt templates can lead to big performance gains: replacement in blue, insertion in green, deletion in red. (Image source: Jiang et al., 2020) Fine-tuning Fine-tuning is an intuitive way to guide a LM to output desired content, commonly by training on supervised datasets or by RL. We can fine-tune all the weights in the model or restrict the fine-tuning to only top or additional layers.\nConditional Training Conditional training aims to learn a generative model conditioned on a control variable $z$, $p(y \\vert x, z)$.\nFan et al (2018) trained a conditional language model for 2-step story generation. First, a model outputs the story sketch and then a story writing model creates a story following that sketch. The mechanism of conditioning on the sketch is implemented by a fusion model architecture. The fusion model enforces a form of residual learning that allows the story writing model to focus on learning what the first sketch generation model is missing. Also for story generation, Peng et al (2018) experimented with an ending valence-conditioned story generator LM, $p(x_t \\vert x_{\u0026lt;t}, z)$ where $z$ is the label of the story ending (sad, happy or neutral). Their language model is a bidirectional LSTM and the label is mapped into a learned embedding which then blends into the LSTM cell.\nCTRL (Keskar et al., 2019; code) aims to train a language model conditioned control code $z$ using controllable datasets. CTRL learns the conditioned distribution $p(x \\vert z)$ by training on raw text sequences with control code prefixes, such as [horror], [legal], etc. Then the learned model is able to generate text with respect to the prompt prefix. The training data contains Wikipedia, OpenWebText, books, Amazon reviews, reddit corpus and many more, where each dataset is assigned with a control code and subreddit in the reddit corpus has its own topic as control code.\nFig. 16. Datasets used for training CTRL and associated control codes. (Image source: Edited from Table 7 in Keskar et al., 2019) The control code also can be used for domain annotation given tokens, because $p(z \\vert x) \\propto p(x \\vert z) p(z)$, assuming the prior over domains is uniform. One limitation of CTRL is the lack of control for what not to generate (e.g. avoid toxicity).\nFig. 17. The examples of conditioned sample generation by CTRL. (Image source: Keskar et al., 2019) Note that CTRL trains a transformer model from scratch. However, labelling all the text within the same dataset with the same control code (e.g. All the wikipedia articles have \u0026ldquo;wikipedia\u0026rdquo; as control code) feels quite constrained. Considering that often we need highly customized control codes but only have a limited amount of labelled data, I would expect fine-tuning an unconditional LM with a small labelled dataset in the same way as CTRL to work out well too. Although how much data is needed and how good the sample quality might be are subject to experimentation.\nRL Fine-tuning Fine-tuning a sequential model with RL regarding any arbitrary and possibly non-differentiable reward function has been proved to work well years ago (Ranzato et al., 2015). RL fine-tuning can resolve several problems with teacher forcing method. With teacher forcing, the model only minimizes a maximum-likelihood loss at each individual decoding step during training but it is asked to predict the entire sequence from scratch at test time. Such a discrepancy between train and test could lead to exposure bias and accumulated error. In contrast, RL fine-tuning is able to directly optimize task-specific metrics on the sequence level, such as BLEU for translation (Ranzato et al., 2015, Wu et al., 2016, Nguyen et al., 2017), ROUGE for summarization (Ranzato et al., 2015, Paulus et al., 2017, Wu and Hu, 2018) and customized metric for story generation (Tambwekar et al., 2018).\nRanzato et al (2015) applied REINFORCE to train RNN models for sequence generation tasks. The model is first trained to predict the next token using cross-entropy loss (ML loss) and then fine-tuned alternatively by both ML loss and REINFORCE (RL loss). At the second fine-tuning stage, the number of training steps for next-token prediction is gradually decreasing until none and eventually only RL loss is used. This sequence-level RL fine-tuning was shown by experiments to lead to great improvements over several supervised learning baselines back then.\nGoogle implemented the similar approach in their neural machine translation system (Wu et al., 2016) and Paulus et al (2017) adopted such approach for summarization task. The training objective contains two parts, ML loss for next token prediction, $\\mathcal{L}_\\text{ML} = \\sum_{(x, y^*)\\sim\\mathcal{D}} \\log p_\\theta(y^* \\vert x)$, and RL loss $\\mathcal{L}_\\text{RL}$ for maximizing the expected reward where the reward per sequence is measured by BLEU or ROUGE. The model is first trained with $\\mathcal{L}_\\text{ML}$ until convergence and then fine-tuned with a linear combination of two losses, $\\mathcal{L}_\\text{mix} = \\alpha \\mathcal{L}_\\text{ML} + (1 - \\alpha)\\mathcal{L}_\\text{RL}$.\nThe RL loss of Google NMT is to maximize the expected BLEU score:\n $$ \\mathcal{L}_\\text{RL} = - \\sum_{(x, y^*)\\sim\\mathcal{D}} \\mathbb{E}_{y\\sim p_\\theta(.\\vert x)} [R(y, y^*)] $$ where $y$ is the predicted sequence and $y^*$ is the ground truth.\nPaulus et al (2017) added an extra weighting term based on the reward difference between two output sequences, $y$ by sampling the next token according to the predicted probability and $\\hat{y}$ by greedily taking the most likely token. This RL loss maximizes the conditional likelihood of the sampled sequence $y$ if it obtains a higher reward than the greedy baseline $\\hat{y}$:\n $$ \\mathcal{L}_\\text{RL} = \\sum_{(x, y^*)\\sim\\mathcal{D}} (R(\\hat{y}, y^*) - R(y, y^*)) \\sum_{t=1}^{n'} \\log p(y_t \\vert y_{RL Fine-tuning with Human Preferences Reward learning is critical for defining human preferences. Quantitative measurement like BLEU or ROUGE computes the overlap of words and n-gram phrases between sequences and does not always correlate with better quality by human judges. Reward learning from human feedback (Christiano et al., 2017) is a better way to align what we measure with what we actually care about. Human feedback has been applied to learn a reward function for applications like story generation (Yi et al., 2019) and summarization (Böhm et al., 2019, Ziegler et al., 2019, Stiennon et al., 2020).\nIn order to generate more coherent conversation, Yi et al (2019) collected 4 types of binary human feedback given a conversation pair (user utterance, system response), whether the system response is (1) comprehensive, (2) on topic, (3) interesting and (4) leading to continuation of the conversation. An evaluator is trained to predict human feedback and then is used to rerank the beam search samples, to finetune the model or to do both. (Actually they didn’t use RL fine-tuning but rather use the evaluator to provide a discriminator loss in supervised fine-tuning.)\nLet\u0026rsquo;s define a learned reward function $R_\\psi(x, y)$ parameterized by $\\psi$ as a measurement for the quality of output $y$ given the input $x$.\nTo learn the ground truth reward $R^*$ defined by human judgements, Böhm et al (2019) compared two loss functions:\n(1) Regression loss: simply minimizing the mean squared error.\n $$ \\mathcal{L}^\\text{MSE}_\\text{rm} = [R^*(x, y) - R_\\psi(x, y)]^2 $$ (2) Preference loss: learning to agree with the ground truth reward,\n $$ \\begin{aligned} \\mathcal{L}^\\text{pref}_\\text{rm} =\u0026 - \\sum_{i,j} \\big(\\mathbb{1}[R^*(x, y_i) R^*(x, y_j)] \\log P(y_i \\succ y_j) + \\\\ \u0026\\mathbb{1}[R^*(x, y_j) R^*(x, y_i)] \\log P(y_j \\succ y_i) \\big)\\\\ \\text{where }P(y_i \\succ y_j) =\u0026 \\frac{\\exp(R_\\psi(x, y_i))}{\\exp(R_\\psi(x, y_i)) + \\exp(R_\\psi(x, y_j))} \\end{aligned} $$ Their experiments showed that the preference loss achieves the best performance, where the reward model is a thin MLP layer on top of BERT sentence embedding.\nZiegler et al (2019) collected human labels by asking humans to select the best candidate $y_b$ out of a few options $\\{y_i\\}$ given the input $x \\sim \\mathcal{D}$. The candidates are sampled by $y_0, y_1 \\sim p(.\\vert x), y_2, y_3 \\sim \\pi(.\\vert x)$. We should be aware that human labeling might have very high disagreement when the ground truth is fuzzy.\nFig. 18. The overview of the training framework for fine-tuning a language model policy with reward learned from human feedback. (Image source: Ziegler et al., 2019) The reward model is implemented by a pretrained language model with an extra random linear layer of the final embedding output. It it trained to minimize the loss:\n $$ \\mathcal{L}_\\text{rm} = -\\mathbb{E}_{(x, \\{y_i\\}, b) \\sim \\mathcal{D}} \\Big[ \\log \\frac{\\exp(R_\\psi(x, y_b))}{\\sum_i \\exp(R_\\psi(x, y_i))} \\Big] $$ To keep the scale consistent during training, the reward model is normalized to have mean 0 and variance 1.\nDuring RL fine-tuning, the policy $\\pi$, initialized by a pretrained language model $p$, is optimized via PPO with the above learned reward model. To avoid the policy\u0026rsquo;s deviating from its original behavior too much, a KL penalty is added:\n $$ R(x, y) = R_\\psi(x, y) - \\beta\\log\\frac{\\pi(y \\vert x)}{p(y \\vert x)} $$ If running online data collection, human label collection process is continued during RL fine-tuning and thus the human labelers can review results generated by the latest policy. The number of human labels are evenly spread out during the training process. Meanwhile the reward model is also retrained periodically. Online data collection turns out to be important for the summarization task but not for the text continuation task. In their experiments, jointly training the reward model and the policy with shared parameters did not work well and can lead to overfitting due to the big imbalance between dataset sizes.\nIn the following work (Stiennon et al., 2020), the human label collection was further simplified to select the best option between a pair of summaries, $y_b \\in\\{y_0, y_1\\}$ The reward model loss was updated to optimize the log odds of the selected summary:\n $$ \\mathcal{L}_\\text{rm} = \\mathbb{E}_{(x, y_0, y_1, b)\\sim\\mathcal{D}} [\\log(\\sigma(r_\\theta(x, y_b) − r_\\theta(x, y_{1−b})))] $$ Fig. 19. The overview of fine-tuning the language model policy from human feedback for summarization, including (1) human feedback collection, (2) reward model training, and (3) policy training. (Image source: Stiennon et al., 2020) Guided Fine-tuning with Steerable Layer Instead of fine-tuning the entire model, only fine-tuning a small extra set of parameters while the base model stays fixed is computationally cheaper.\nIn computer vision, plug-and-play generative networks (PPGN; Nguyen et al., 2017) generate images with different attributes by plugging a discriminator $p(a \\vert x)$ into a base generative model $p(x)$. Then the sample with a desired attribute $a$ can be sampled from $p(x \\vert a) \\propto p(a \\vert x)p(x)$. Inspired by PPGN, the plug-and-play language model (PPLM; Dathathri et al., 2019) combines one or multiple simple attribute models with a pretrained language model for controllable text generation.\nGiven an attribute $a$ and the generated sample $x$, let an attribute model be $p(a\\vert x)$. To control content generation, the current latent representation at time $t$, $H_t$ (containing a list of key-value pairs per layer), can be shifted by $\\Delta H_t$ in the direction of the sum of two gradients:\n One toward higher log-likelihood of the attribute $a$ under $p(a \\vert x)$ \u0026mdash; so that the output content acquires a desired attribute. The other toward higher log-likelihood of the unmodified language model $p(x)$ \u0026mdash; so that the generated text is still in fluent and smooth natural language. To shift the output, at decoding time, PPLM runs one forward → one backward → one forward, three passes in total:\n First a forward pass is performed to compute the likelihood of attribute $a$ by $p(a\\vert x)$; Let $\\Delta H_t$ be a stepwise update to the hidden state $H_t$ such that $(H_t + \\Delta H_t)$ shifts the distribution of generated text closer to having the attribute $a$. $\\Delta H_t$ is initialized at zero. Then a backward pass updates the LM hidden states using normalized gradients from the attribute model $\\nabla_{\\Delta H_t} \\log p(a \\vert H_t + \\Delta H_t)$ as $$ \\Delta H_t \\leftarrow \\Delta H_t + \\alpha \\frac{\\nabla_{\\Delta H_t} \\log p(a|H_t + \\Delta H_t)}{\\| \\nabla_{\\Delta H_t} \\log p(a|H_t + \\Delta H_t) \\|^\\gamma} $$ where $\\gamma$ is a normalization scaling coefficient, set per layer. $\\alpha$ is step size. This update can be repeated $m \\in [3, 10]$ times 3. The final forward pass recomputes a new distribution over the vocabulary, generated from the updated latents $\\tilde{H}_t = H_t + \\Delta H_t$. The next token is sampled from the updated distribution.\nFig. 20. The overview of how PPLM runs three passes to update the model output to increase the likelihood of a desired attribute. (Image source: Dathathri et al., 2019) Multiple attribute models can be mix-and-matched during generation with customized weights, acting as a set of \u0026ldquo;control knobs\u0026rdquo;. The PPLM paper explored two types of attribute models:\n The simplest attribution model is based on a predefined bag of words (BoW), $\\{w_1, \\dots, w_k\\}$, that specifies a topic of interest. $$ \\log p(a \\vert x) = \\log\\big( \\sum_{i=1}^k p_{t+1} [w_i] \\big) $$ To encourage the model to output the desired words at least once but not at every step, they normalize the gradient by the maximum gradient norm. Interestingly, they found that increasing the probability of generating words in the bag also increases the probability of generating related but not identical words about the same topic. 2. The discriminator attribute models are based on learned classifiers which define preferences by a distribution instead of hard samples.\nTo ensure the fluency in language, PPLM applied two additional designs:\n Minimizing the KL diverge between modified and unmodified LM, commonly seen in other RL fine-tuning approaches (see above). It performs post-norm fusion to constantly tie the generated text to the unconditional LM $p(x)$, $x_{t+1} \\sim \\frac{1}{\\beta}(\\tilde{p}_{t+1}^{\\gamma_\\text{gm}} p_{t+1}^{1-\\gamma_\\text{gm}})$, where $p_{t+1}$ and $\\tilde{p}_{t+1}$ are the unmodified and modified output distributions, respectively. $\\beta$ is a normalizing factor. $\\gamma_\\text{gm} \\in [0.8, 0.95]$ balances between prediction from before and after models. Fig. 21. Examples of controllable text generation by PPLM. (Image source: Dathathri et al., 2019) Interestingly, they found a large variance in the extent of controllability across topics. Some topics (religion, science, politics) are easier to control for compared to others (computers, space).\nOne obvious drawback of PPLM is that due to multiple passes at every decoding step, the test time computation becomes much more expensive.\nSimilar to PPLM, DELOREAN (DEcoding for nonmonotonic LOgical REAsoNing; Qin et al., 2020) incorporates the future context by back-propagation. Given input text $\\mathbf{x}$, DELOREAN aims to generate continuation completion $\\mathbf{y} = [y_1, \\dots, y_N]$ such that $y$ satisfies certain constraints defined by a context $z$. To keep the generation differentiable, a soft representation of $y$ is tracked, $\\tilde{\\mathbf{y}}=(\\tilde{y}_1, \\dots, \\tilde{y}_N)$ where $\\tilde{y}_i \\in \\mathbb{R}^V$ are logits over the vocabulary. $\\tilde{\\mathbf{y}}^{(t)}$ is the soft representation at iteration $t$.\nGiven the representation $\\tilde{y}^{(t-1)}$ at iteration $t$, it runs the following procedures:\n Backward: The constraint is represented as a loss function $\\mathcal{L}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t-1)}, z))$. The logits are updated via gradient descent: $\\tilde{y}^{(t), b}_n = \\tilde{y}_n^{(t-1)} - \\lambda \\nabla_{\\tilde{y}_n} \\mathcal{L}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t-1)}, z)$. Forward: Run forward pass to ensure the generated text is fluent. $\\tilde{y}^{(t),f}_n = \\text{LM}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t)}_{1:n-1})$. Then linearly combine two logits together to create a new representation $\\tilde{y}^{(t)}_n = \\gamma \\tilde{y}^{(t), f}_n + (1-\\gamma) \\tilde{y}^{(t), b}_n$. Note that each $\\tilde{y}^{(t)}_n$ is needed to sample the next $\\tilde{y}^{(t),f}_{n+1}$. Side-tuning (Zhang et al., 2019) trains a light-weighted side network that learns a residual on top of the original model outputs without modifying the pre-trained model weights. Unlike PPLM, no gradient update is applied on the hidden states. It is a simple yet effective approach for incremental learning. The base model is treated as a black-box model and does not necessarily have to be a neural network. Side-tuning setup assumes the base and side models are fed exactly the same input and the side model is independently learned.\nFig. 22. Comparison of fixed weights, fine-tuning and side-tuning. (Image source: Zhang et al., 2019) The paper explored different strategies of fusing predictions from the base and side models: product is the worst while sum ($\\alpha$-blending), MLP, and FiLM are comparable. Side-tuning is able to achieve better performance, when it is trained with intermediate amounts of data and when the base network is large.\nAuxiliary tuning (Zeldes et al., 2020) supplements the original pre-trained model with an auxiliary model that shifts the output distribution according to the target task. The base and auxiliary model outputs are merged on the logits level. The combined model is trained to maximize the likelihood $p(x_t\\vert x_{\u0026lt;t}, z)$ of target output.\nThe conditional probability of $p(x_t\\vert x_{\u0026lt;t}, z)$ can be decomposed into two parts:\n $p(x_t\\vert x_{\u0026lt;t})$ assigns high probabilities to fluent sequences of tokens; a shift on $p(x_t\\vert x_{\u0026lt;t})$ towards $p(x_t\\vert x_{\u0026lt;t}, z)$. $$ p(x_t\\vert x_{By Bayesian rule, we have\n $$ p(x_t\\vert x_{And therefore the auxiliary model $\\text{logits}_\\text{aux}(x_t \\vert x_{\u0026lt;t}, z))$ effectively should learn to predict $p(z \\vert x_{\\leq t})$. In the experiments of Zeldes et al., 2020, the auxiliary model can re-use the intermediate layers of the pre-trained LM for feature extraction.\nFig. 23. The auxiliary model is trained by reusing features extracted from multiple layers of the base model. (Image source: Zeldes et al., 2020) GeDi (Kruse et al., 2020) guides the text generation by Generative Discriminator. The discriminator is implemented as a class conditional language model (CC-LM), $p_\\theta(x_{1:t} \\vert z)$. The discriminator guides generation at each decoding step by computing classification probabilities for all possible next tokens via Bayes rule by normalizing over two contrastive class-conditional distributions:\n One conditioned on the control code $z$ for desired attribute. The other conditioned on the anti-control code $\\bar{z}$ for undesired attributes. GeDi relies on the contract between $p_\\theta(x_{1:t} \\vert z)$ and $p_\\theta(x_{1:t} \\vert \\bar{z})$ to compute the probability of the sequence belonging to the desired class. The discriminator loss is to maximize the probability of desired attribute $z$:\n $$ \\begin{aligned} p_\\theta(z \\vert x_{1:t}) \u0026= \\frac{p(z) p_\\theta(x_{1:\\tau} \\vert z)^{\\alpha/\\tau}}{\\sum_{z' \\in \\{z, \\bar{z}\\}} p(z') p_\\theta(x_{1:\\tau} \\vert z')^{\\alpha/\\tau} } \\\\ \\mathcal{L}_\\text{desc} \u0026= -\\frac{1}{N} \\sum_{i=1}^N \\log p_\\theta(z^{(i)} \\vert x^{(i)}_{1:\\tau_i}) \\\\ \u0026= -\\frac{1}{N} \\sum_{i=1}^N \\log \\frac{p(z) p_\\theta(x^{(i)}_{1:\\tau_i} \\vert z^{(i)})^{\\alpha/t_i}}{\\sum_{z' \\in \\{z, \\bar{z}\\} } p(z')p_\\theta(x^{(i)}_{1:\\tau_i} \\vert z')^{\\alpha/\\tau_i}} \\end{aligned} $$ where $p(z) = \\exp(b_z) / \\sum_{z'} \\exp(b_{z'})$ and $b_z$ is a learned class prior. The probabilities are normalized by the current sequence length $\\tau$ to robustify generation sequences of variable lengths. $\\tau_i$ is the sequence length of the $i$-th input $x^{(i)}$ in the dataset.\nFig. 24. An illustration of how GeDi works via Bayesian rule. (Image source: Kruse et al., 2020) They finetuned a GPT2-medium model with control code similar to how CTRL is trained to form a CC-LM using a linear combination of discriminative loss and generative loss. This discriminator model is then used as GiDe to guide generation by a larger language model like GPT2-XL.\nOne way of decoding from GeDi is to sample from a weighted posterior $p^w(x_{t+1}\\vert x_{1:t}, z) \\propto p(z \\vert x_{1:t+1})^w p(x_{t+1} \\vert x_{1:t})$ where $w\u0026gt;1$ applies additional bias toward the desired class $z$. In the sampling process, only tokens with the class or next-token probability larger than a certain threshold are selected.\nGeDi guided generation in their experiments showed strong controllability and ran 30x faster than PPLM.\nDistributional Approach Generation with Distributional Control (GDC; Khalifa, et al. 2020) frames controlled text generation as the optimization of a probability distribution with a constraint. It involves two major steps.\nStep 1: Learn a EBM of the target model\nLet\u0026rsquo;s label a pretrained LM as $a$ and a target LM with desired features as $p$. The desired features can be defined by a set of pre-defined real-valued feature functions $\\phi_i(x), i=1,\\dots,k$ over $x \\in X$, denoted as a vector $\\boldsymbol{\\phi}$. When sequences $x \\in X$ are sampled according to the desired model $p$, the expectations of features $\\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x)$ should be close to $\\bar{\\boldsymbol{\\mu}}$ , named \u0026ldquo;moment constraints\u0026rdquo;. The feature function $\\phi_i$ can have distinct values (e.g. identity function for binary classifier) or continuous probabilities. In the meantime, the fine-tuned model $p$ should not diverge from $a$ too much by maintaining a small KL divergence measure.\nIn summary, given a pretrained model $a$, we would like to find a target model $p$ such that:\n $$ \\begin{aligned} \\bar{\\boldsymbol{\\mu}} \u0026= \\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x) \\\\ p \u0026= \\arg\\min_{c \\in \\mathcal{C}} D_\\text{KL}(c, a) \\end{aligned} $$ where $\\mathcal{C}$ is the set of all distributions over $X$ that satisfy the moment constraints.\nAccording to theorems in Information Geometry, $p$ can be approximated by an EBM (energy-based model; an unnormalized probability distribution) $P$ in the form of exponential function, such that $p(x) \\propto P(x)$ and $p(x)=\\frac{1}{Z}P(x)$ where $Z=\\sum_x P(x)$. The energy-based model can be approximated by:\n $$ P(x)=a(x)\\exp\\big(\\sum_i \\lambda_i \\phi_i(x)\\big)=a(x)\\exp(\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)) $$ Let\u0026rsquo;s define importance weight $w(x, \\boldsymbol{\\lambda}) = \\frac{P(x)}{a(x)} = \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle$. Given a large number of sequences sampled from the pretrained model $x_1, \\dots, x_N \\sim a(x)$,\n $$ \\begin{aligned} \\mu(\\boldsymbol{\\lambda}) \u0026= \\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x) = \\mathbb{E}_{x\\sim a} \\frac{p(x)}{a(x)}\\boldsymbol{\\phi}(x) = \\frac{1}{Z}\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x) \\\\ \u0026= \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\sum_{x\\in X} P(x)} = \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\sum_{x\\in X} w(x, \\boldsymbol{\\lambda})a(x)} = \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda})} \\\\ \u0026\\simeq \\frac{\\sum_{i=1}^N w(x_i,\\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x_i)}{\\sum_{i=1}^N w(x_i, \\boldsymbol{\\lambda})} = \\frac{\\sum_{i=1}^N \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle \\boldsymbol{\\phi}(x_i)}{\\sum_{i=1}^N \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle} \\end{aligned} $$ Using SGD over the objective $|\\boldsymbol{\\mu}(\\boldsymbol{\\lambda}) - \\bar{\\boldsymbol{\\mu}}|^2_2$, we can obtain an estimated value for $\\boldsymbol{\\lambda}$ and a representation of $P(x)=a(x)\\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle$. $P(x)$ is a sequential EBM because $a$ is an autoregressive model.\nStep 2: Learn the target probability distribution\nThe EBM $P(x)$ can compute ratios of probabilities of two sequences, but cannot sample from $p(x)$ with knowing $Z$. In order to sample from a sequential EBM, the paper proposed to use Distributional Policy Gradient (DPG; but not this DPG) with the objective to obtain an autoregressive policy $\\pi_\\theta$ to approximate a target distribution $p$ by minimizing the cross entropy $H(p, \\pi_\\theta)$. DPG runs through a sequence of iterations. Within each iteration, the proposed distribution $q$ is used for sampling and we can correct the cross entropy loss with importance weights too:\n $$ \\begin{aligned} \\nabla_\\theta H(p, \\pi_\\theta) \u0026= - \\nabla_\\theta \\mathbb{E}_{x\\sim p} \\log \\pi_\\theta(x) = - \\mathbb{E}_{x\\sim p} \\nabla_\\theta \\log \\pi_\\theta(x) \\\\ \u0026= - \\mathbb{E}_{x\\sim q} \\frac{p(x)}{q(x)} \\nabla_\\theta \\log \\pi_\\theta(x) = - \\frac{1}{Z}\\mathbb{E}_{x\\sim q} \\frac{P(x)}{q(x)} \\nabla_\\theta \\log \\pi_\\theta(x) \\end{aligned} $$ To learn such a $\\pi_\\theta$, the paper adopts a KL-adaptive version of DPG: It only updates $q$ when the estimated policy $\\pi_\\theta$ gets closer to $p$. This adaptive step is important for fast convergence.\nFig. 25. The algorithm of distributional policy gradient to make it possible to sample from a EBM $P(x)$, where $q$ is initialized to be $a$. (Image source: Khalifa, et al. 2020) This approach can be used to model various constraints in controllable text generation:\n Pointwise constraints: $\\phi_i$ is a binary feature; such as constraining the presence or absence of words, or classifier-based constraints. Distributional constraints: $\\phi_i$ represents a probability distribution; such as constraining the probability of gender, topic, etc. Their experiments showed great progress in debiasing a GPT-2 model that was trained on Wikipedia Biographies corpus. The percentage of generated biographies on females increased from 7.4% to 35.6%. Hybrid constraints: combine multiple constraints by simply summing them up. Fig. 26. Debiasing experiments using GDC with various constraints. (Image source: Khalifa, et al. 2020) Compared to other baselines, GDC using pointwise constraints diverges less from the base model $a$ and produces smoother curves.\nFig. 27. Compare pointwise constrained GDC with several baselines. Low Self-BLEU-5 and high Dist-1 indicate high diversity. (Image source: Khalifa, et al. 2020) REINFORCE that optimizes the reward $\\phi$ directly ($\\text{REINFORCE}$ in Fig. X.) without constraints converges fast but has a high deviation from the original model. REINFORCE that optimizes $P(x)$ ($\\text{REINFORCE}_{P(x)}$ in Fig. X.) has low sample diversity. Compared to Ziegler et al., 2019 GDC has smoother learning curves and produces a richer vocabulary. Unlikelihood Training The standard way of maximizing the log-likelihood loss in language model training leads to incorrect token distribution, which cannot be fixed with only smart decoding methods. Such models tend to output high-frequency words too often and low-frequency words too rarely, especially when using deterministic decoding (e.g. greedy, beam search). In other words, they are overconfident in their predictions.\nUnlikelihood training (Welleck \u0026amp; Kulikov et al. 2019] tries to combat this and incorporates preference to unwanted content into the training objective directly. It combines two updates:\n A routine maximized likelihood update to assign true tokens with high probability; A new type of unlikelihood update to avoid unwanted tokens with high probability. Given a sequence of tokens $(x_1, \\dots, x_T)$ and a set of negative candidate tokens $\\mathcal{C}^t = \\{c_1, \\dots , c_m\\}$ at step $t$, where each token $x_i, c_j \\in \\mathcal{V}$, the combined loss for step $t$ is defined as:\n $$ \\mathcal{L}^t_\\text{UL}(p_\\theta (. \\vert x_{One approach for constructing $\\mathcal{C}^t$ is to randomly select candidates from model-generated sequences.\nThe unlikelihood training can be extended to be on the sequence-level, where the negative continuation is defined by a sequence of per-step negative candidate sets. They should be designed to penalize properties that we don\u0026rsquo;t like. For example, we can penalize repeating n-grams as follows:\n $$ \\mathcal{C}^t_\\text{repeat-n} = \\{x_t\\} \\text{ if }(x_{t-i}, \\dots, x_{t+j}) \\in x_{Their experiments used unlikelihood training to avoid repetitions in language model outputs and indeed showed better results on less repetition and more unique tokens compared to standard MLE training.\nCitation Cited as:\n Weng, Lilian. (Jan 2021). Controllable neural text generation. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/.\n Or\n@article{weng2021conditional, title = \u0026quot;Controllable Neural Text Generation.\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/\u0026quot; } References [1] Patrick von Platen. \u0026ldquo;How to generate text: using different decoding methods for language generation with Transformers\u0026rdquo; Hugging face blog, March 18, 2020.\n[2] Angela Fan, et al. \u0026ldquo;Hierarchical Neural Story Generation/\u0026quot; arXiv preprint arXiv:1805.04833 (2018).\n[3] Ari Holtzman et al. \u0026ldquo;The Curious Case of Neural Text Degeneration.\u0026quot; ICLR 2020.\n[4] Marjan Ghazvininejad et al. \u0026ldquo;Hafez: an interactive poetry generation system.\u0026quot; ACL 2017.\n[5] Ari Holtzman et al. \u0026ldquo;Learning to write with cooperative discriminators.\u0026quot; ACL 2018.\n[6] Ashutosh Baheti et al. \u0026ldquo;Generating More Interesting Responses in Neural Conversation Models with Distributional Constraints.\u0026quot; EMNLP 2018.\n[7] Jiatao Gu et al. \u0026ldquo;Trainable greedy decoding for neural machine translation.\u0026quot; EMNLP 2017.\n[8] Kyunghyun Cho. \u0026ldquo;Noisy Parallel Approximate Decoding for Conditional Recurrent Language Model.\u0026quot; arXiv preprint arXiv:1605.03835. (2016).\n[9] Marco Tulio Ribeiro et al. \u0026ldquo;Semantically equivalent adversarial rules for debugging NLP models.\u0026quot; ACL 2018.\n[10] Eric Wallace et al. \u0026ldquo;Universal Adversarial Triggers for Attacking and Analyzing NLP.\u0026quot; EMNLP 2019. [code]\n[11] Taylor Shin et al. \u0026ldquo;AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts.\u0026quot; EMNLP 2020. [code]\n[12] Zhengbao Jiang et al. \u0026ldquo;How Can We Know What Language Models Know?\u0026quot; TACL 2020.\n[13] Nanyun Peng et al. \u0026ldquo;Towards Controllable Story Generation.\u0026quot; NAACL 2018.\n[14] Nitish Shirish Keskar, et al. \u0026ldquo;CTRL: A Conditional Transformer Language Model for Controllable Generation\u0026rdquo; arXiv preprint arXiv:1909.05858 (2019).[code]\n[15] Marc’Aurelio Ranzato et al. \u0026ldquo;Sequence Level Training with Recurrent Neural Networks.\u0026quot; ICLR 2016.\n[16] Yonghui Wu et al. \u0026ldquo;Google\u0026rsquo;s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.\u0026quot; CoRR 2016.\n[17] Romain Paulus et al. \u0026ldquo;A Deep Reinforced Model for Abstractive Summarization.\u0026quot; ICLR 2018.\n[18] Paul Christiano et al. \u0026ldquo;Deep Reinforcement Learning from Human Preferences.\u0026quot; NIPS 2017.\n[19] Sanghyun Yi et al. \u0026ldquo;Towards coherent and engaging spoken dialog response generation using automatic conversation evaluators.\u0026quot; INLG 2019.\n[20] Florian Böhm et al. \u0026ldquo;Better rewards yield better summaries: Learning to summarise without references.\u0026quot; EMNLP 2019. [code]\n[21] Daniel M Ziegler et al. \u0026ldquo;Fine-tuning language models from human preferences.\u0026quot; arXiv preprint arXiv:1909.08593 (2019). [code]\n[22] Nisan Stiennon, et al. \u0026ldquo;Learning to summarize from human feedback.\u0026quot; arXiv preprint arXiv:2009.01325 (2020).\n[23] Sumanth Dathathri et al. \u0026ldquo;Plug and play language models: a simple approach to controlled text generation.\u0026quot; ICLR 2020. [code]\n[24] Jeffrey O Zhang et al. \u0026ldquo;Side-tuning: Network adaptation via additive side networks\u0026rdquo; ECCV 2020.\n[25] Ben Kruse et al. \u0026ldquo;GeDi: Generative Discriminator Guided Sequence Generation.\u0026quot; arXiv preprint arXiv:2009.06367.\n[26] Yoel Zeldes et al. \u0026ldquo;Technical Report: Auxiliary Tuning and its Application to Conditional Text Generatio.\u0026quot; arXiv preprint arXiv:2006.16823.\n[27] Thomas Scialom, et al. \u0026ldquo;Discriminative Adversarial Search for Abstractive Summarization\u0026rdquo; ICML 2020.\n[28] Clara Meister, et al. \u0026ldquo;If beam search is the answer, what was the question?\u0026quot; EMNLP 2020.\n[29] Xiang Lisa Li and Percy Liang. \u0026ldquo;Prefix-Tuning: Optimizing Continuous Prompts for Generation.\u0026quot; arXiv preprint arXiv:2101.00190 (2021).\n[30] Lianhui Qin, et al. \u0026ldquo;Back to the Future: Unsupervised Backprop-based Decoding for Counterfactual and Abductive Commonsense Reasoning.\u0026quot; arXiv preprint arXiv:2010.05906 (2020).\n[31] Muhammad Khalifa, et al. \u0026ldquo;A Distributional Approach to Controlled Text Generation\u0026rdquo; Accepted by ICLR 2021.\n[32] Aditya Grover, et al. \u0026ldquo;Bias correction of learned generative models using likelihood-free importance weighting.\u0026quot; NeuriPS 2019.\n[33] Yuntian Deng et al. \u0026ldquo;Residual Energy-Based Models for Text Generation.\u0026quot; ICLR 2020.\n[34] Brian Lester et al. “The Power of Scale for Parameter-Efficient Prompt Tuning.” arXiv preprint arXiv:2104.08691 (2021).\n[35] Xiao Liu et al. “GPT Understands, Too.” arXiv preprint arXiv:2103.10385 (2021).\n[36] Welleck \u0026amp; Kulikov et al. “Neural Text Generation with Unlikelihood Training” arXiv:1908.04319 (2019).\n","permalink":"https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/","summary":"[Updated on 2021-02-01: Updated to version 2.0 with several work added and many typos fixed.] [Updated on 2021-05-26: Add P-tuning and Prompt Tuning in the \u0026ldquo;prompt design\u0026rdquo; section.] [Updated on 2021-09-19: Add \u0026ldquo;unlikelihood training\u0026rdquo;.]\nThere is a gigantic amount of free text on the Web, several magnitude more than labelled benchmark datasets. The state-of-the-art language models (LM) are trained with unsupervised Web data in large scale. When generating samples from LM by iteratively sampling the next token, we do not have much control over attributes of the output text, such as the topic, the style, the sentiment, etc.","title":"Controllable Neural Text Generation"},{"content":"[Updated on 2020-11-12: add an example on closed-book factual QA using OpenAI API (beta).\nA model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as working as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.\nDisclaimers given so many papers in the wild:\n Assume we have access to a powerful pretrained language model. We do not cover how to use structured knowledge base (e.g. Freebase, WikiData) here. We only focus on a single-turn QA instead of a multi-turn conversation style QA. We mostly focus on QA models that contain neural networks, specially Transformer-based language models. I admit that I missed a lot of papers with architectures designed specifically for QA tasks between 2017-2019😔 What is Open-Domain Question Answering? Open-domain Question Answering (ODQA) is a type of language tasks, asking a model to produce answers to factoid questions in natural language. The true answer is objective, so it is simple to evaluate model performance.\nFor example,\nQuestion: What did Albert Einstein win the Nobel Prize for? Answer: The law of the photoelectric effect. The \u0026ldquo;open-domain\u0026rdquo; part refers to the lack of the relevant context for any arbitrarily asked factual question. In the above case, the model only takes as the input the question but no article about \u0026ldquo;why Einstein didn\u0026rsquo;t win a Nobel Prize for the theory of relativity\u0026rdquo; is provided, where the term \u0026ldquo;the law of the photoelectric effect\u0026rdquo; is likely mentioned. In the case when both the question and the context are provided, the task is known as Reading comprehension (RC).\nAn ODQA model may work with or without access to an external source of knowledge (e.g. Wikipedia) and these two conditions are referred to as open-book or closed-book question answering, respectively.\nWhen considering different types of open-domain questions, I like the classification by Lewis, et al., 2020, in increasing order of difficulty:\n A model is able to correctly memorize and respond with the answer to a question that has been seen at training time. A model is able to answer novel questions at test time and choose an answer from the set of answers it has seen during training. A model is able to answer novel questions which have answers not contained in the training dataset. Fig. 1. Overview of three frameworks discussed in this post. Notation Given a question $x$ and a ground truth answer span $y$, the context passage containing the true answer is labelled as $z \\in \\mathcal{Z}$, where $\\mathcal{Z}$ is an external knowledge corpus. Wikipedia is a common choice for such an external knowledge source.\nConcerns of QA data fine-tuning Before we dive into the details of many models below. I would like to point out one concern of fine-tuning a model with common QA datasets, which appears as one fine-tuning step in several ODQA models. It could be concerning, because there is a significant overlap between questions in the train and test sets in several public QA datasets.\nLewis, et al., (2020) (code) found that 58-71% of test-time answers are also present somewhere in the training sets and 28-34% of test-set questions have a near-duplicate paraphrase in their corresponding training sets. In their experiments, several models performed notably worse when duplicated or paraphrased questions were removed from the training set.\nOpen-book QA: Retriever-Reader Given a factoid question, if a language model has no context or is not big enough to memorize the context which exists in the training dataset, it is unlikely to guess the correct answer. In an open-book exam, students are allowed to refer to external resources like notes and books while answering test questions. Similarly, a ODQA system can be paired with a rich knowledge base to identify relevant documents as evidence of answers.\nWe can decompose the process of finding answers to given questions into two stages,\n Find the related context in an external repository of knowledge; Process the retrieved context to extract an answer. Fig. 2. The retriever-reader QA framework combines information retrieval with machine reading comprehension. Such a retriever + reader framework was first proposed in DrQA (\u0026ldquo;Document retriever Question-Answering\u0026rdquo; by Chen et al., 2017; code). The retriever and the reader components can be set up and trained independently, or jointly trained end-to-end.\nRetriever Model Two popular approaches for implementing the retriever is to use the information retrieval (IR) system that depends on (1) the classic non-learning-based TF-IDF features (\u0026ldquo;classic IR\u0026rdquo;) or (2) dense embedding vectors of text produced by neural networks (\u0026ldquo;neural IR\u0026rdquo;).\nClassic IR DrQA (Chen et al., 2017) adopts an efficient non-learning-based search engine based on the vector space model. Every query and document is modelled as a bag-of-word vector, where each term is weighted by TF-IDF (term frequency $\\times$ inverse document frequency).\n $$ \\begin{aligned} \\text{tf-idf}(t, d, \\mathcal{D}) \u0026= \\text{tf}(t, d) \\times \\text{idf}(t, \\mathcal{D}) \\\\ \\text{tf}(t, d) \u0026= \\log(1 + \\text{freq}(t, d)) \\\\ \\text{idf}(t, \\mathcal{D}) \u0026= \\log \\Big( \\frac{\\vert\\mathcal{D}\\vert}{\\vert d\\in\\mathcal{D}: t\\in d\\vert} \\Big) \\end{aligned} $$ where $t$ is a unigram or bigram term in a document $d$ from a collection of documents $\\mathcal{D}$ . $\\text{freq}(t, d)$ measures how many times a term $t$ appears in $d$. Note that the term-frequency here includes bigram counts too, which is found to be very helpful because the local word order is taken into consideration via bigrams. As part of the implementation, DrQA maps the bigrams of $2^{24}$ bins using unsigned murmur3 hash.\nPrecisely, DrQA implemented Wikipedia as its knowledge source and this choice has became a default setting for many ODQA studies since then. The non-ML document retriever returns the top $k=5$ most relevant Wikipedia articles given a question.\nBERTserini (Yang et al., 2019) pairs the open-source Anserini IR toolkit as the retriever with a fine-tuned pre-trained BERT model as the reader. The top $k$ documents ($k=10$) are retrieved via the post-v3.0 branch of Anserini with the query treated as a bag of words. The retrieved text segments are ranked by BM25, a classic TF-IDF-based retrieval scoring function. In terms of the effect of text granularity on performance, they found that paragraph retrieval \u0026gt; sentence retrieval \u0026gt; article retrieval.\nFig. 3. An illustration of BERTserini architecture. (Image source: Yang et al., 2019) ElasticSearch + BM25 is used by the Multi-passage BERT QA model (Wang et al., 2019). They found that splitting articles into passages with the length of 100 words by sliding window brings 4% improvements, since splitting documents into passages without overlap may cause some near-boundary evidence to lose useful contexts.\nNeural IR There is a long history in learning a low-dimensional representation of text, denser than raw term-based vectors (Deerwester et al., 1990; Yih, et al., 2011). Dense representations can be learned through matrix decomposition or some neural network architectures (e.g. MLP, LSTM, bidirectional LSTM, etc). When involving neural networks, such approaches are referred to as \u0026ldquo;Neural IR\u0026rdquo;, Neural IR is a new category of methods for retrieval problems, but it is not necessary to perform better/superior than classic IR (Lim, 2018).\nAfter the success of many large-scale general language models, many QA models embrace the following approach:\n $$ h_x = E_x(x)\\quad h_z = E_z(z)\\quad \\text{score}(x, z) = h_x^\\top h_z $$ Extract the dense representations of a question $x$ and a context passage $z$ by feeding them into a language model; Use the dot-product of these two representations as the retrieval score to rank and select most relevant passages. ORQA, REALM and DPR all use such a scoring function for context retrieval, which will be described in detail in a later section on the end-to-end QA model.\nAn extreme approach, investigated by DenSPI (\u0026ldquo;Dense-Sparse Phrase Index\u0026rdquo;; Seo et al., 2019), is to encode all the text in the knowledge corpus at the phrase level and then only rely on the retriever to identify the most relevant phrase as the predicted answer. In this way, the retriever+reader pipeline is reduced to only retriever. Of course, the index would be much larger and the retrieval problem is more challenging.\nDenSPI introduces a query-agnostic indexable representation of document phrases. Precisely it encodes query-agnostic representations of text spans in Wikipedia offline and looks for the answer at inference time by performing nearest neighbor search. It can drastically speed up the inference time, because there is no need to re-encode documents for every new query, which is often required by a reader model.\nGiven a question $x$ and a fixed set of (Wikipedia) documents, $z_1, \\dots, z_K$ and each document $z_k$ contains $N_k$ words, $z_k = \\langle z_k^{(1)}, \\dots, z_k^{(N_k)}\\rangle$. An ODQA model is a scoring function $F$ for each candidate phrase span $z_k^{(i:j)}, 1 \\leq i \\leq j \\leq N_k$, such that the truth answer is the phrase with maximum score: $y = {\\arg\\max}_{k,i,j} F(x, z_k^{(i:j)})$.\nThe phrase representation $z_k^{(i:j)}$ combines both dense and sparse vectors, $z_k^{(i:j)} = [d_k^{(i:j)}, s_k^{(i:j)}] \\in \\mathbb{R}^{d^d + d^s}$ (note that $d^d \\ll d^s$):\n The dense vector $d_k^{(i:j)}$ is effective for encoding local syntactic and semantic cues, as what can be learned by a pretrained language model. The sparse vector $s_k^{(i:j)}$ is superior at encoding precise lexical information. The sparse vector is term-frequency-based encoding. DenSPI uses 2-gram term-frequency same as DrQA, resulting a highly sparse representation ($d^s \\approx 16$M) The dense vector $d^{(i:j)}$ is further decomposed into three parts, $d^{(i:j)} = [a_i, b_j, c_{ij}] \\in \\mathbb{R}^{2d^b + 1}$ where $2d^b + 1 = d^d$. All three components are learned based on different columns of the fine-tuned BERT representations.\n A vector $a_i$ encodes the start position for the $i$-th word of the document; A vector $b_j$ encodes the end position for the $j$-th word of the document; A scalar $c_{ij}$ measures the coherency between the start and the end vectors, helping avoid non-constituent phrases during inference. For all possible $(i,j,k)$ tuples where $j-i \u0026lt; J$, the text span embeddings are precomputed and stored as a phrase index. The maximum span length $J$ is a predefined scalar constant.\nFig. 4. An illustration of Dense-Sparse Phrase Index (DenSPI) architecture. (Image source: Seo et al., 2019) At the inference time, the question is mapped into the same vector space $x=[d', s'] \\in \\mathbb{R}^{d^d + d^s}$, where the dense vector $d'$ is extracted from the BERT embedding of the special [CLS] symbol. The same BERT model is shared for encoding both questions and phrases. The final answer is predicted by $k^*, i^*, j^* = \\arg\\max x^\\top z_k^{(i:j)}$.\nReader Model The reader model learns to solve the reading comprehension task \u0026mdash; extract an answer for a given question from a given context document. Here we only discuss approaches for machine comprehension using neural networks.\nBi-directional LSTM The reader model for answer detection of DrQA (Chen et al., 2017) is a 3-layer bidirectional LSTM with hidden size 128. Every relevant paragraph of retrieved Wikipedia articles is encoded by a sequence of feature vector, $\\{\\tilde{\\mathbf{z}}_1, \\dots, \\tilde{\\mathbf{z}}_m \\}$. Each feature vector $\\hat{\\mathbf{z}}_i \\in \\mathbb{R}^{d_z}$ is expected to capture useful contextual information around one token $z_i$. The feature consists of several categories of features:\n Word embeddings: A 300d Glove word embedding trained from 800B Web crawl data, $f_\\text{embed} = E_g(z_i)$. Exact match: Whether a word $z_i$ appears in the question $x$, $f_\\text{match} = \\mathbb{I}(z_i \\in x)$. Token features: This includes POS (part-of-speech) tagging, NER (named entity recognition), and TF (term-frequency), $f_\\text{token}(z_i) = (\\text{POS}(z_i), \\text{NER}(z_i), \\text{TF}(z_i))$. Aligned question embedding: The attention score $y_{ij}$ is designed to capture inter-sentence matching and similarity between the paragraph token $z_i$ and the question word $x_j$. This feature adds soft alignments between similar but non-identical words. $$ \\begin{aligned} f_\\text{align}(z_i) \u0026= \\sum_j y_{i,j} E_g(x_j) \\\\ y_{i,j} \u0026= \\frac{\\exp(\\alpha(E_g(z_i))^\\top \\alpha(E_g(x_j)) )}{\\sum_{j'} \\exp(\\alpha(E_g(z_i))^\\top \\alpha(E_g(x_{j'})) ) } \\end{aligned} $$ where $\\alpha$ is a single dense layer with ReLU and $E_g(.)$ is the glove word embedding.\nThe feature vector of a paragraph of $m$ tokens is fed into LSTM to obtain the final paragraph vectors:\n $$ \\begin{aligned} \\mathbf{z} = \\{\\mathbf{z}_1, \\dots, \\mathbf{z}_m\\} \u0026= \\text{LSTM}(\\{\\tilde{\\mathbf{z}}_1, \\dots, \\tilde{\\mathbf{z}}_m\\}) \\\\ \\text{where } \\tilde{\\mathbf{z}}_i \u0026= \\{f_\\text{embed}, f_\\text{match}, f_\\text{token}, f_\\text{align}\\} \\end{aligned} $$ The question is encoded as a weighted sum of the embeddings of every word in the question:\n $$ \\mathbf{x} = \\sum_j b_j E(x_j) \\quad b_j = \\text{softmax}(\\mathbf{w}^\\top E(x_j)) $$ where $\\mathbf{w}$ is a weight vector to learn.\nOnce the feature vectors are constructed for the question and all the related paragraphs, the reader needs to predict the probabilities of each position in a paragraph to be the start and the end of an answer span, $p_\\text{start}(i_s)$ and $p_\\text{end}(i_s)$, respectively. Across all the paragraphs, the optimal span is returned as the final answer with maximum $p_\\text{start}(i_s) \\times p_\\text{end}(i_e) $.\n $$ \\begin{aligned} p_\\text{start}(i_s) \\propto \\exp(\\mathbf{z}_{i_s} \\mathbf{W}_s \\mathbf{x}) \\\\ p_\\text{end}(i_e) \\propto \\exp(\\mathbf{z}_{i_e} \\mathbf{W}_e \\mathbf{x}) \\\\ \\text{ s.t. } i_s \\leq i_e \\leq i_s + 15 \\end{aligned} $$ where $\\mathbf{W}_s$ and $\\mathbf{W}_e$ are learned parameters.\nBERT-universe Following the success of BERT (Devlin et al., 2018), many QA models develop the machine comprehension component based on BERT. Let\u0026rsquo;s define the BERT model as a function that can take one or multiple strings (concatenated by [SEP]) as input and outputs a set of BERT encoding vectors for the special [CLS] token and every input token:\n $$ \\text{BERT}(s_1, s_2, \\dots) = [\\mathbf{h}^\\texttt{[CLS]}, \\mathbf{h}^{(1)}, \\mathbf{h}^{(2)}, \\dots] $$ where $\\mathbf{h}^\\texttt{[CLS]}$ is the embedding vector for the special [CLS] token and $\\mathbf{h}^{(i)}$ is the embedding vector for the $i$-th token.\nTo use BERT for reading comprehension, it learns two additional weights, $\\mathbf{W}_s$ and $\\mathbf{W}_e$, and $\\text{softmax}(\\mathbf{h}^{(i)}\\mathbf{W}_s)$ and $\\text{softmax}(\\mathbf{h}^{(i)}\\mathbf{W}_e)$ define two probability distributions of start and end position of the predicted span per token.\nBERTserini (Yang et al., 2019) utilizes a pre-trained BERT model to work as the reader. Their experiments showed that fine-tuning pretrained BERT with SQuAD is sufficient to achieve high accuracy in identifying answer spans.\nFig. 5. How BERT is used to solve question-answering tasks. (Image source: Devlin et al., 2018) The key difference of the BERTserini reader from the original BERT is: to allow comparison and aggregation of results from different segments, the final softmax layer over different answer spans is removed. The pre-trained BERT model is fine-tuned on the training set of SQuAD, where all inputs to the reader are padded to 384 tokens with the learning rate 3e-5.\nWhen ranking all the extracted answer spans, the retriever score (BM25) and the reader score (probability of token being the start position $\\times$ probability of the same token being the end position ) are combined via linear interpolation.\nThe original BERT normalizes the probability distributions of start and end position per token for every passage independently. Differently, the Multi-passage BERT (Wang et al., 2019) normalizes answer scores across all the retrieved passages of one question globally. Precisely, multi-passage BERT removes the final normalization layer per passage in BERT for QA (same as in BERTserini) and then adds a global softmax over all the word positions of all the passages. Global normalization makes the reader model more stable while pin-pointing answers from a large number of passages.\nIn addition, multi-passage BERT implemented an independent passage ranker model via another BERT model and the rank score for $(x, z)$ is generated by a softmax over the representation vectors of the first [CLS] token. The passage ranker brings in extra 2% improvements. Similar idea of re-ranking passages with BERT was discussed in Nogueira \u0026amp; Cho, 2019, too.\nInterestingly, Wang et al., 2019 found that explicit inter-sentence matching does not seem to be critical for RC tasks with BERT; check the original paper for how the experiments were designed. One possible reason is that the multi-head self-attention layers in BERT has already embedded the inter-sentence matching.\nEnd-to-end Joint Training The retriever and reader components can be jointly trained. This section covers R^3, ORQA, REALM and DPR. There are a lot of common designs, such as BERT-based dense vectors for retrieval and the loss function on maximizing the marginal likelihood of obtaining true answers.\nThe retriever and reader models in the R^3 (\u0026ldquo;Reinforced Ranker-Reader\u0026rdquo;; Wang, et al., 2017) QA system are jointly trained via reinforcement learning. (Note that to keep the term consistent between papers in this section, the \u0026ldquo;ranker\u0026rdquo; model in the original R^3 paper is referred to as the \u0026ldquo;retriever\u0026rdquo; model here.) Both components are variants of Match-LSTM, which relies on an attention mechanism to compute word similarities between the passage and question sequences.\nHow does the Match-LSTM module work? Given a question $\\mathbf{X}$ of $d_x$ words and a passage $\\mathbf{Z}$ of $d_z$ words, both representations use fixed Glove word embeddings,\n $$ \\begin{aligned} \\mathbf{H}^x \u0026= \\text{BiLSTM}(\\mathbf{X}) \\in \\mathbb{R}^{l \\times d_x} \\\\ \\mathbf{H}^z \u0026= \\text{BiLSTM}(\\mathbf{Z}) \\in \\mathbb{R}^{l \\times d_z} \\\\ \\mathbf{G} \u0026= \\text{softmax}((\\mathbf{W}^g \\mathbf{H}^x + \\mathbf{b}^g \\otimes \\mathbf{e}_{d_x})^\\top \\mathbf{H}^z) \\in \\mathbb{R}^{d_x \\times d_z} \u0026 \\text{; an attention matrix}\\\\ \\bar{\\mathbf{H}}^x \u0026= \\mathbf{H}^x \\mathbf{G} \\in \\mathbb{R}^{l \\times d_z} \\\\ \\mathbf{M} \u0026= \\text{ReLU} \\Big( \\mathbf{W}^m \\begin{bmatrix} \\mathbf{H}^z \\\\ \\bar{\\mathbf{H}}^x \\\\ \\mathbf{H}^z \\odot \\bar{\\mathbf{H}}^x \\\\ \\mathbf{H}^z - \\bar{\\mathbf{H}}^x \\end{bmatrix} \\Big) \\in \\mathbb{R}^{2l \\times d_z} \\\\ \\mathbf{H}^m \u0026= \\text{BiLSTM}(M) \\in \\mathbb{R}^{l \\times d_z} \\end{aligned} $$ where $l$ is the hidden dimension of the bidirectional LSTM module. $\\mathbf{W}^g \\in \\mathbb{R}^{l\\times l}$, $\\mathbf{b}^g \\in \\mathbb{R}^l$, and $\\mathbf{W}^m \\in \\mathbb{R}^{2l \\times 4l}$ are parameters to learn. The operator $\\otimes \\mathbf{e}_{d_x}$ is the outer product to repeat the column vector $\\mathbf{b}^g$ $d_x$ times.\nThe ranker and reader components share the same Match-LSTM module with two separate prediction heads in the last layer, resulting in $\\mathbf{H}^\\text{rank}$ and $\\mathbf{H}^\\text{reader}$.\nFig. 6. The overview of R^3 (reinforced ranker-reader) architecture. Both components share the same Match-LSTM module. (Image source: Wang, et al., 2017) The retriever runs a max-pooling operation per passage and then aggregates to output a probability of each passage entailing the answer.\n $$ \\begin{aligned} \\mathbf{u}_i \u0026= \\text{max-pooling}(\\mathbf{H}^\\text{rank}_i) \\in \\mathbb{R}^l \\\\ \\mathbf{C} \u0026= \\text{tanh}(\\mathbf{W}^c[\\mathbf{u}_1;\\dots;\\mathbf{u}_N] + \\mathbf{b}^c \\otimes \\mathbf{e}_N) \\in \\mathbb{R}^{l \\times n} \\\\ \\gamma \u0026= \\text{softmax}(\\mathbf{w}^c \\mathbf{C}) \\in \\mathbb{R}^n \\end{aligned} $$ Finally, the retriever is viewed as a policy to output action to sample a passage according to predicted $\\gamma$,\n $$ \\pi(z \\vert x; \\theta^\\gamma) = \\gamma_z $$ The reader predicts the start position $\\beta^s$ and the end position $\\beta^e$ of the answer span. Two positions are computed in the same way, with independent parameters to learn. There are $V$ words in all the passages involved.\n $$ \\begin{aligned} \\mathbf{H}^\\text{read} \u0026= [\\mathbf{H}^\\text{read}_\\tau; \\mathbf{H}^\\text{read}_{\\text{neg}_1}; \\dots; \\mathbf{H}^\\text{read}_{\\text{neg}_n}] \\\\ \\mathbf{F}^s \u0026= \\text{tanh}(\\mathbf{W}^s \\mathbf{H}^\\text{read} + \\mathbf{b}^s \\otimes \\mathbf{e}_V) \\quad \\beta^s = \\text{softmax}(\\mathbf{w}^s \\mathbf{F}^s) \\in \\mathbb{R}^V \\\\ \\mathbf{F}^e \u0026= \\text{tanh}(\\mathbf{W}^e \\mathbf{H}^\\text{read} + \\mathbf{b}^e \\otimes \\mathbf{e}_V) \\quad \\beta^e = \\text{softmax}(\\mathbf{w}^e \\mathbf{F}^e) \\in \\mathbb{R}^V \\\\ L(y \\vert z, x) \u0026= -\\log(\\beta^s_{y_z^s})-\\log(\\beta^e_{y_z^e}) \\end{aligned} $$ where $y$ is the ground-truth answer and the passage $z$ is sampled by the retriever. $\\beta^s_{y_z^s}$ and $\\beta^s_{y_z^e}$ represent the probabilities of the start and end positions of $y$ in passage $z$.\nThe training objective for the end-to-end R^3 QA system is to minimize the negative log-likelihood of obtaining the correct answer $y$ given a question $x$,\n $$ \\begin{aligned} \\mathcal{J}(\\theta) \u0026= -\\mathbb{E}_{z\\sim\\pi(.\\vert x)} [L(y \\vert z, x)] \\\\ \\nabla \\mathcal{J}(\\theta) \u0026= - \\nabla_\\theta \\sum_z \\pi(z \\vert x) L(y \\vert z, x) \\\\ \u0026= - \\sum_z \\big( L(y \\vert z, x) \\nabla_\\theta\\pi(z \\vert x) + \\pi(z \\vert x) \\nabla_\\theta L(y \\vert z, x) \\big) \\\\ \u0026= - \\mathbb{E}_{z\\sim\\pi(.\\vert x)} \\big( \\color{red}{L(y \\vert z, x)\\nabla_\\theta\\log\\pi(z \\vert x)} + \\nabla_\\theta L(y \\vert z, x) \\big) \\\\ \u0026\\approx - \\mathbb{E}_{z\\sim\\pi(.\\vert x)} \\big( \\underbrace{\\color{red}{R(y \\vert z, x)\\nabla_\\theta\\log\\pi(z \\vert x)}}_\\text{REINFORCE} + \\nabla_\\theta L(y \\vert z, x) \\big) \\end{aligned} $$ Essentially in training, given a passage $z$ sampled by the retriever, the reader is trained by gradient descent while the retriever is trained by REINFORCE using $L(y \\vert z, x)$ as the reward function. However, $L(y \\vert z, x)$ is not bounded and may introduce a lot of variance. The paper replaces the reward with a customized scoring function by comparing the ground truth $y$ and the answer extracted by the reader $\\hat{y}$:\n $$ R(y, \\hat{y} \\vert z) = \\begin{cases} 2 \u0026 \\text{if } y = \\hat{y}\\\\ f1(y, \\hat{y}) \u0026 \\text{if } y \\cap \\hat{y} = \\varnothing \\\\ -1 \u0026 \\text{otherwise} \\end{cases} $$ Fig. 7. The workflow of R^3 training process. (Image source: acl2020-openqa-tutorial/slides/part4) ORQA (\u0026ldquo;Open-Retrieval Question-Answering\u0026rdquo;; Lee et al., 2019) jointly learns a retriever + reader QA model to optimize marginal log-likelihood of obtaining correct answers in a supervised manner. No explicit \u0026ldquo;black-box\u0026rdquo; IR system is involved. Instead, it is capable of retrieving any text in an open corpus. During training, ORQA does not need ground-truth context passages (i.e. reading comprehension datasets) but only needs (question, answer) string pairs. Both retriever and reader components are based on BERT, but not shared.\nFig. 8. An illustration of the retriever component in ORQA. (Image source: replotted based on one slide in acl2020-openqa-tutorial/slides/part5) All the evidence blocks are ranked by a retrieval score, defined as the inner product of BERT embedding vectors of the [CLS] token of the question $x$ and the evidence block $z$. Note that the encoders for questions and context are independent.\n $$ \\begin{aligned} h_x \u0026= \\mathbf{W}_x \\text{BERT}_x(x)^{\\mathtt{[CLS]}} \\\\ h_z \u0026= \\mathbf{W}_z \\text{BERT}_z(z)^{\\mathtt{[CLS]}} \\\\ S_\\text{retr}(z, x) \u0026= h_x^\\top h_z \\end{aligned} $$ The retriever module is pretrained with Inverse Cloze Task (ICT), which is to predict the context given a sentence, opposite to the standard Cloze Task. The ICT objective is to maximize the retrieval score of the correct context $z$ given a random sentence $x$:\n $$ L_\\text{ICT} = p_\\text{early}(z \\vert x) = \\frac{\\exp(S_\\text{retr}(z, x))}{\\sum_{z'\\in\\text{BATCH}(\\mathcal{Z})} \\exp(S_\\text{retr}(z', x))} $$ where $\\text{BATCH}(\\mathcal{Z})$ is the set of evidence blocks in the same batch used as sampled negatives.\nAfter such pretraining, the BERT retriever is expected to have representations good enough for evidence retrieval. Only the question encoder needs to be fine-tuned for answer extraction. In other words, the evidence block encoder (i.e., $\\mathbf{W}_z$ and $\\text{BERT}_z$) is fixed and thus all the evidence block encodings can be pre-computed with support for fast Maximum Inner Product Search (MIPS).\nFig. 9. An illustration of the reader component in ORQA. (Image source: acl2020-openqa-tutorial/slides/part5) The reader follows the same design as in the original BERT RC experiments. It learns in a supervised manner, while the parameters of the evidence block encoder are fixed and all other parameters are fine-tuned. Given a question $x$ and a gold answer string $y$, the reader loss contains two parts:\n $$ \\mathcal{L}(x, y) = \\mathcal{L}_\\text{early}(x, y) + \\mathcal{L}_\\text{full}(x, y) $$ (1) Find all correct text spans within top $k$ evidence blocks and optimize for the marginal likelihood of a text span $s$ that matches the true answer $y$:\n $$ \\begin{aligned} h_s \u0026= \\text{BERT}_R(x, y)^{(\\text{START}(s))} \\\\ h_e \u0026= \\text{BERT}_R(x, y)^{(\\text{END}(s))} \\\\ S_\\text{read}(z, s, x) \u0026= \\text{MLP}([h_s; h_e]) \\\\ p(z, s \\vert x) \u0026= \\frac{\\exp(S_\\text{read}(z, s, x))}{\\sum_{z'\\in\\text{TOP}(k)} \\sum_{s'\\in z'} \\exp(S_\\text{read}(z', s', x))} \\\\ L_\\text{full}(x, y) \u0026= - \\log \\sum_{\\substack{z \\in \\text{TOP}(k)\\\\ s \\in z}} \\sum_{y=\\text{TEXT}(s)} p(z, s \\vert x) \\end{aligned} $$ where $y=\\text{TEXT}(s)$ indicates whether the answer $y$ matches the text span $s$. $\\text{TOP}(k)$ is the top $k$ retrieved blocks according to $S_\\text{retr}(z, x)$. The paper sets $k=5$.\n(2) At the early stage of learning, when the retriever is not strong enough, it is possible none of the top $k$ blocks contains the answer. To avoid such sparse learning signals, ORQA considers a larger set of $c$ evidence blocks for more aggressive learning. The paper has $c=5000$.\n $$ L_\\text{early}(x, y) = -\\log \\sum_{\\substack{z\\in \\text{TOP}(c)\\\\y\\in\\text{TEXT}(z)}} p_\\text{early}(z\\vert x) = -\\log \\sum_{\\substack{z\\in \\text{TOP}(c)\\\\y\\in\\text{TEXT}(z)}} \\frac{\\exp(S_\\text{retr}(z, x)}{\\sum_{z'\\in\\text{TOP}(c)} \\exp(S_\\text{retr}(z', x)} $$ Some issues in SQuAD dataset were discussed in the ORQA paper:\n \u0026quot; The notable drop between development and test accuracy for SQuAD is a reflection of an artifact in the dataset\u0026mdash;its 100k questions are derived from only 536 documents. Therefore, good retrieval targets are highly correlated between training examples, violating the IID assumption, and making it unsuitable for learned retrieval. We strongly suggest that those who are interested in end-to-end open-domain QA models no longer train and evaluate with SQuAD for this reason.\u0026quot;\n REALM (\u0026ldquo;Retrieval-Augmented Language Model pre-training\u0026rdquo;; Guu et al., 2020) also jointly trains retriever + reader by optimizing the marginal likelihood of obtaining the true answer:\n $$ p(y \\vert x) = \\sum_{z \\in \\mathcal{Z}} \\underbrace{p(y \\vert x, z)}_\\text{reader} \\underbrace{p(z \\vert x)}_\\text{retriever} \\approx \\sum_{z \\in \\text{TOP}_k(\\mathcal{Z})} p(y \\vert x, z) p(z \\vert x) $$ Fig. 10. REALM is first unsupervised pre-trained with salient spans masking and then fine-tuned with QA data. (Image source: Guu et al., 2020). REALM computes two probabilities, $p(z \\vert x)$ and $p(y \\vert x, z)$, same as ORQA. However, different from ICT in ORQA, REALM upgrades the unsupervised pre-training step with several new design decisions, leading towards better retrievals. REALM pre-trains the model with Wikipedia or CC-News corpus.\n Use salient span masking. Named entities and dates are identified. Then one of these \u0026ldquo;salient spans\u0026rdquo; is selected and masked. Salient span masking is a special case of MLM and works out well for QA tasks. Add an empty null document. Because not every question demands a context document. No trivial retrieval. The context document should not be same as the selected sentence with a masked span. Apply the same ICT loss as in ORQA to encourage learning when the retrieval quality is still poor at the early stage of training. \u0026ldquo;Among all systems, the most direct comparison with REALM is ORQA (Lee et al., 2019), where the fine-tuning setup, hyperparameters and training data are identical. The improvement of REALM over ORQA is purely due to better pre-training methods.\u0026rdquo; \u0026mdash; from REALM paper.\n Both unsupervised pre-training and supervised fine-tuning optimize the same log-likelihood $\\log p(y \\vert x)$. Because the parameters of the retriever encoder for evidence documents are also updated in the process, the index for MIPS is changing. REALM asynchronously refreshes the index with the updated encoder parameters every several hundred training steps.\nBalachandran, et al. (2021) found that REALM is significantly undertrained and REALM++ achieves great EM accuracy improvement (3-5%) by scaling up the model training with larger batch size and more retrieved documents for the reader to process.\nDPR (\u0026ldquo;Dense Passage Retriever\u0026rdquo;; Karpukhin et al., 2020, code) argues that ICT pre-training could be too computationally expensive and the ORQA\u0026rsquo;s context encoder might be sub-optimal because it is not fine-tuned with question-answer pairs. DPR aims to resolve these two issues by only training a dense dual-encoder architecture for retrieval only from a small number of Q/A pairs, without any pre-training.\nSame as previous work, DPR uses the dot-product (L2 distance or cosine similarity also works) of BERT representations as retrieval score. The loss function for training the dual-encoder is the NLL of the positive passage, which essentially takes the same formulation as ICT loss of ORQA. Note that both of them consider other passages in the same batch as the negative samples, named in-batch negative sampling. The main difference is that DPR relies on supervised QA data, while ORQA trains with ICT on unsupervised corpus. At the inference time, DPR uses FAISS to run fast MIPS.\nDPR did a set of comparison experiments involving several different types of negatives:\n Random: any random passage from the corpus; BM25: top passages returned by BM25 which don\u0026rsquo;t contain the answer but match most question tokens; In-batch negative sampling (\u0026ldquo;gold\u0026rdquo;): positive passages paired with other questions which appear in the training set. DPR found that using gold passages from the same mini-batch and one negative passage with high BM25 score works the best. To further improve the retrieval results, DPR also explored a setting where a BM25 score and a dense embedding retrieval score are linearly combined to serve as a new ranking function.\nOpen-book QA: Retriever-Generator Compared to the retriever-reader approach, the retriever-generator also has 2 stages but the second stage is to generate free text directly to answer the question rather than to extract start/end position in a retrieved passage. Some paper also refer to this as Generative question answering.\nFig. 11. The retriever + generator QA framework combines a document retrieval system with a general language model. A pretrained LM has a great capacity of memorizing knowledge in its parameters, as shown above. However, they cannot easily modify or expand their memory, cannot straightforwardly provide insights into their predictions, and may produce non-existent illusion.\nPetroni et al. (2020) studied how the retrieved relevant context can help a generative language model produce better answers. They found:\n Augmenting queries with relevant contexts dramatically improves the pretrained LM on unsupervised machine reading capabilities. An off-the-shelf IR system is sufficient for BERT to match the performance of a supervised ODQA baseline; BERT\u0026rsquo;s NSP pre-training strategy is a highly effective unsupervised mechanism in dealing with noisy and irrelevant contexts. They pair the BERT model with different types of context, including adversarial (unrelated context), retrieved (by BM25), and generative (by an autoregressive language model of 1.4N parameters, trained on CC-NEWS). The model is found to be robust to adversarial context, but only when the question and the context are provided as two segments (e.g. separated by [SEP]). One hypothesis is related to NSP task: \u0026ldquo;BERT might learn to not condition across segments for masked token prediction if the NSP score is low, thereby implicitly detecting irrelevant and noisy contexts.\u0026rdquo;\nRAG (\u0026ldquo;Retrieval-Augmented Generation\u0026rdquo;; Lewis et al., 2020) combines pre-trained parametric (language model) and non-parametric memory (external knowledge index) together for language generation. RAG can be fine-tuned on any seq2seq task, whereby both the retriever and the sequence generator are jointly learned. They found that unconstrained generation outperforms previous extractive approaches.\nRAG consists of a retriever model $p_\\eta(z \\vert x)$ and a generator model $p_\\theta(y_i \\vert x, z, y_{1:i-1})$:\n The retriever uses the input sequence $x$ to retrieve text passages $z$, implemented as a DPR retriever. $\\log p_\\eta(z \\vert x) \\propto E_z(z)^\\top E_x(x)$. The generator uses $z$ as additional context when generating the target sequence $y$, where the context and the question are simply concatenated. Depending on whether using the same or different retrieved documents for each token generation, there are two versions of RAG:\n $$ \\begin{aligned} p_\\text{RAG-seq}(y \\vert x) \u0026= \\sum_{z \\in \\text{TOP}_k(p_\\eta(.\\vert x))} p_\\eta(z \\vert x) \\prod_i^N p_\\theta(y_i \\vert x, z, y_{1:i-1}) \\\\ p_\\text{RAG-token}(y \\vert x) \u0026= \\prod_i^N \\sum_{z \\in \\text{TOP}_k(p_\\eta(.\\vert x))} p_\\eta(z_i\\vert x) p_\\theta(y_i \\vert x, z_i, y_{1:i-1}) \\end{aligned} $$ The retriever + generator in RAG is jointly trained to minimize the NLL loss, $\\mathcal{L}_\\text{RAG} = \\sum_j -\\log p(y_j \\vert x_j)$. Updating the passage encoder $E_z(.)$ is expensive as it requires the model to re-index the documents for fast MIPS. RAG does not find fine-tuning $E_z(.)$ necessary (like in ORQA) and only updates the query encoder + generator.\nFig. 12. An illustration of retrieval-augmented generation (RAG) architecture. (Image source: Lewis et al., 2020) At decoding/test time, RAG-token can be evaluated via a beam search. RAG-seq cannot be broken down into a set of per-token likelihood, so it runs beam search for each candidate document $z$ and picks the one with optimal $p_\\theta(y_i \\vert x, z, y_{1:i-1})$.\nThe Fusion-in-Decoder approach, proposed by Izacard \u0026amp; Grave (2020) is also based on a pre-trained T5. It works similar to RAG but differently for how the context is integrated into the decoder.\n Retrieve top $k$ related passage of 100 words each, using BM25 or DPR. Each retrieved passage and its title are concatenated with the question using special tokens like question:, title: and context: to indicate the content differences. Each retrieved passage is processed independently and later combined in the decoder. Processing passages independently in the encoder allows us to parallelize the computation. OTOH, processing them jointly encourages better aggregation of multiple pieces of evidence. The aggregation part is missing in extractive approaches. Note that they did fine-tune the pretrained LM independently for each dataset.\nClosed-book QA: Generative Language Model Big language models have been pre-trained on a large collection of unsupervised textual corpus. Given enough parameters, these models are able to memorize some factual knowledge within parameter weights. Therefore, we can use these models to do question-answering without explicit context, just like in a closed-book exam. The pre-trained language models produce free text to respond to questions, no explicit reading comprehension.\nFig. 13. The amount of computation used for training big language models of different sizes is getting big. (Image source: Brown et al., 2020). Roberts et al. (2020) measured the practical utility of a language model by fine-tuning a pre-trained model to answer questions without access to any external context or knowledge. They fine-tuned the T5 language model (same architecture as the original Transformer) to answer questions without inputting any additional information or context. Such setup enforces the language model to answer questions based on \u0026ldquo;knowledge\u0026rdquo; that it internalized during pre-training.\nFig. 14. T5 is first pre-trained with salient span masking and then fine-tuned for each QA dataset to produce answers in free text. (Image source: Roberts et al. 2020) The original T5 models were pre-trained on a multi-task mixture including an unsupervised \u0026ldquo;masked language modeling\u0026rdquo; (MLM) tasks on the C4 (\u0026ldquo;Colossal Clean Crawled Corpus\u0026rdquo;) dataset as well as fine-tuned altogether with supervised translation, summarization, classification, and reading comprehension tasks. Roberts, et al. (2020) took a pre-trained T5 model and continued pre-training with salient span masking over Wikipedia corpus, which has been found to substantially boost the performance for ODQA. Then they fine-tuned the model for each QA datasets independently.\nWith a pre-trained T5 language model + continue pre-training with salient spans masking + fine-tuning for each QA dataset,\n It can attain competitive results in open-domain question answering without access to external knowledge. A larger model can obtain better performance. For example, a T5 with 11B parameters is able to match the performance with DPR with 3 BERT-base models, each with 330M parameters. Interestingly, fine-tuning is not strictly necessary. GPT3 (Brown et al., 2020) has been evaluated on the closed book question answering task without any gradient updates or fine-tuning. During evaluation, the few-shot, one-shot and zero-shot settings here only refer to how many demonstrations are provided as context in the text input:\n \u0026ldquo;few-shot learning\u0026rdquo;: GPT3 is allowed to take as many demonstrations as what can fit into the model\u0026rsquo;s context window (typically 10 to 100). \u0026ldquo;one-shot learning\u0026rdquo;: only one demonstration is provided. \u0026ldquo;zero-shot learning\u0026rdquo;: no demonstrations are allowed and only an instruction in natural language is given to the model. The performance grows with the model size. On the TriviaQA dataset, GPT3 evaluation with demonstrations can match or exceed the performance of SOTA baseline with fine-tuning.\nFig. 15. GPT3's performance on TriviaQA grows smoothly with the model size. More demonstrations lead to better performance. (Image source: Brown et al., 2020). Check out this cool example in OpenAI API playground viewer. The model is able to answer factal questions in short answer and not to make up things when the model does not know the answer. I added the last two questions and asked the model to respond with A:. The API is still in beta version, so you might need to apply to get on the wait list.\nQ: Who is Batman? A: Batman is a fictional comic book character. ## Q: What is torsalplexity? A: ? ## Q: What is Devz9? A: ? ## Q: Who is George Lucas? A: George Lucas is American film director and producer famous for creating Star Wars. ## Q: What is the capital of California? A: Sacramento. ## Q: What orbits the Earth? A: The Moon. ## Q: Who is Fred Rickerson? A: ? ## Q: What is an atom? A: An atom is a tiny particle that makes up everything. ## Q: Who is Alvan Muntz? A: ? ## Q: What is Kozar-09? A: ? ## Q: How many moons does Mars have? A: Two, Phobos and Deimos. ## Q: What is COVID-19? A: ? ## Q: What is H1N1? A: H1N1 is a strain of influenza. Related Techniques Fast Maximum Inner Product Search (MIPS) MIPS (maximum inner product search) is a crucial component in many open-domain question answering models. In retriever + reader/generator framework, a large number of passages from the knowledge source are encoded and stored in a memory. A retrieval model is able to query the memory to identify the top relevant passages which have the maximum inner product with the question\u0026rsquo;s embedding.\nWe need fast MIPS because the number of precomputed passage representations can be gigantic. There are several ways to achieve fast MIPS at run time, such as asymmetric LSH, data-dependent hashing, and FAISS.\nLanguage Model Pre-training Two pre-training tasks are especially helpful for QA tasks, as we have discussed above.\n Inverse Cloze Task (proposed by ORQA): The goal of Cloze Task is to predict masked-out text based on its context. The prediction of Inverse Cloze Task (ICT) is in the reverse direction, aiming to predict the context given a sentence. In the context of QA tasks, a random sentence can be treated as a pseudo-question, and its context can be treated as pseudo-evidence.\n Salient Spans Masking (proposed by REALM): Salient span masking is a special case for MLM task in language model training. First, we find salient spans by using a tagger to identify named entities and a regular expression to identify dates. Then one of the detected salient spans is selected and masked. The task is to predict this masked salient span.\n Summary Model Retriever Reader / Generator Pre-training / Fine-tuning End2end DrQA TF-IDF Bi-directional LSTM \u0026ndash; No BERTserini Aserini + BM25 BERT without softmax layer Fine-tune with SQuAD No Multi-passage BERT ElasticSearch + BM25 Multi-passage BERT + Passage ranker No R^3 Classic IR + Match-LSTM Match-LSTM Yes ORQA Dot product of BERT embeddings BERT-RC Inverse cloze task Yes REALM Dot product of BERT embeddings BERT-RC Salient span masking Yes DPR Dot product of BERT embeddings BERT-RC supervised training with QA pairs Yes DenSPI Classic + Neural IR \u0026ndash; Yes T5 + SSM \u0026ndash; T5 SSM on CommonCrawl data + Fine-tuning on QA data Yes GPT3 \u0026ndash; GPT3 NSP on CommonCrawl data Yes RAG DPR retriever BART Yes Fusion-in-Decoder BM25 / DPR retriever Tranformer No Fig. 16. A comparison of performance of several QA models on common QA datasets. On TriviaQA, two columns of results are reported, on the open domain test set (left) and on the hidden test set (right). (Image source: Izacard \u0026 Grave, 2020). Citation Cited as:\n Weng, Lilian. (Oct 2020). How to build an open-domain question answering system? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-10-29-odqa/.\n Or\n@article{weng2020odqa, title = \u0026quot;How to Build an Open-Domain Question Answering System?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Oct\u0026quot; url = \u0026quot;https://lilianweng.github.io/posts/2020-10-29-odqa/\u0026quot; } Appendix: QA Datasets SQuAD 2.0: the Stanford QA dataset. RACE: a reading comprehension dataset collected from English Examinations that are created for middle school and high school students. TREC QA: the TREC QA collections. MS MARCO: a QA dataset featuring 100,000 real Bing questions and a human generated answer. CuratedTREC: based on the benchmarks from the TREC QA tasks that have been curated by Baudis \u0026amp; Sedivy (2015). Google Natural Questions: contains real user questions issued to Google search, and answers found from Wikipedia by annotators. WebQuestions: designed for knowledge-base QA with answers restricted to Freebase entities. WikiQA: Bing query logs were used as the source of questions. Each question is then linked to a Wikipedia page that potentially contains the answer. WikiMovies: contains movie-related questions from the OMDb and MovieLens databases and where the questions can be answered using Wikipedia pages. WikiReading: to predict textual values from the structured knowledge base Wikidata by reading the text of the corresponding Wikipedia articles. TriviaQA: a reading comprehension dataset containing 95K question-answer pairs authored by trivia enthusiasts and independently gathered multiple evidence documents per question. Jeopardy! Questions: contains 200,000+ Jeopardy! questions. DeepMind Q\u0026amp;A Dataset: question/answer pairs from CNN and Daily Mail articles. bAbi: a rich collection of datasets for text understanding by Facebook. FEVER: for fact extraction and verification. SearchQA: question-answer pairs were crawled from from J! Archive, and then augmented with text snippets from Google. Quasar-T: a collection of open-domain trivia questions and their answers obtained from various internet sources. Quiz bowl: contains data from a trivia competition called quiz bowl. AmbigNQ: ambiguous questions selected from NQ-OPEN dataset. QA-Overlap: a collections of overlapped answers/questions between train and test set for Natural Questions, TriviaQA, and WebQuestions. References [1] Danqi Chen \u0026amp; Scott Yih. \u0026ldquo;ACL2020 Tutorial: Open-Domain Question Answering\u0026rdquo; July 2020.\n[2] Danqi Chen, et al. \u0026ldquo;Reading Wikipedia to Answer Open-Domain Questions\u0026rdquo; ACL 2017. | code\n[3] Shuohang Wang, et al. \u0026ldquo;R^3: Reinforced Ranker-Reader for Open-Domain Question Answering\u0026rdquo; AAAI 2018.\n[4] Jimmy Lin. \u0026ldquo;The neural hype and comparisons against weak baselines.\u0026quot; ACM SIGIR Forum. Vol. 52. No. 2. 2019.\n[5] Wei Yang, et al. \u0026ldquo;End-to-End Open-Domain Question Answering with BERTserini\u0026rdquo; NAACL 2019.\n[6] Christopher Clark \u0026amp; Matt Gardner. \u0026ldquo;Simple and Effective Multi-Paragraph Reading Comprehension.\u0026quot; arXiv:1710.10723 (2017).\n[7] Rodrigo Nogueira \u0026amp; Kyunghyun Cho. \u0026ldquo;Passage Re-ranking with BERT.\u0026quot; arXiv preprint arXiv:1901.04085 (2019). | code\n[8] Zhiguo Wang, et al. \u0026ldquo;Multi-passage BERT: A globally normalized BERT model for open-domain question answering.\u0026quot; EMNLP 2019.\n[9] Minjoon Seo et al. \u0026ldquo;Real-time open-domain question answering with dense-sparse phrase index.\u0026quot; ACL 2019.\n[10] Kenton Lee, et al. \u0026ldquo;Latent Retrieval for Weakly Supervised Open Domain Question Answering\u0026rdquo; ACL 2019.\n[11] Kelvin Guu, et al. \u0026ldquo;REALM: Retrieval-Augmented Language Model Pre-Training\u0026rdquo; arXiv:2002.08909 (2020).\n[12] Vladimir Karpukhin et al. \u0026ldquo;Dense passage retrieval for open-domain question answering.\u0026quot;. EMNLP 2020. | code\n[13] Patrick Lewis et al. \u0026ldquo;Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks\u0026rdquo; arXiv:2005.11401 (2020).\n[14] Adam Roberts, et al. \u0026ldquo;How Much Knowledge Can You Pack Into the Parameters of a Language Model?\u0026quot; EMNLP 2020.\n[15] Tom Brown, et al. \u0026ldquo;Language models are few-shot learners.\u0026quot; arXiv:2005.14165 (2020).\n[16] Fabio Petroni, et al. \u0026ldquo;How Context Affects Language Models' Factual Predictions\u0026rdquo; AKBC 2020.\n[17] Gautier Izacard \u0026amp; Edouard Grave. \u0026ldquo;Leveraging passage retrieval with generative models for open domain question answering.\u0026quot; arXiv:2007.01282 (2020).\n[18] \u0026ldquo;Dive into deep learning: Beam search\u0026rdquo;\n[19] Patrick Lewis, et al. \u0026ldquo;Question and Answer Test-Train Overlap in Open-Domain Question Answering Datasets\u0026rdquo; arXiv:2008.02637 (2020). | data\n[20] Hervé Jegou, et al. \u0026ldquo;Faiss: A library for efficient similarity search\u0026rdquo; Mar 2017.\n[21] Vidhisha Balachandran, et al. \u0026ldquo;Simple and Efficient ways to Improve REALM.\u0026quot; arXiv:2104.08710 (2021).\n","permalink":"https://lilianweng.github.io/posts/2020-10-29-odqa/","summary":"[Updated on 2020-11-12: add an example on closed-book factual QA using OpenAI API (beta).\nA model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as working as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.\nDisclaimers given so many papers in the wild:\n Assume we have access to a powerful pretrained language model.","title":"How to Build an Open-Domain Question Answering System?"},{"content":"Although most popular and successful model architectures are designed by human experts, it doesn\u0026rsquo;t mean we have explored the entire network architecture space and settled down with the best option. We would have a better chance to find the optimal solution if we adopt a systematic and automatic way of learning high-performance model architectures.\nAutomatically learning and evolving network topologies is not a new idea (Stanley \u0026amp; Miikkulainen, 2002). In recent years, the pioneering work by Zoph \u0026amp; Le 2017 and Baker et al. 2017 has attracted a lot of attention into the field of Neural Architecture Search (NAS), leading to many interesting ideas for better, faster and more cost-efficient NAS methods.\nAs I started looking into NAS, I found this nice survey very helpful by Elsken, et al 2019. They characterize NAS as a system with three major components, which is clean \u0026amp; concise, and also commonly adopted in other NAS papers.\n Search space: The NAS search space defines a set of operations (e.g. convolution, fully-connected, pooling) and how operations can be connected to form valid network architectures. The design of search space usually involves human expertise, as well as unavoidably human biases. Search algorithm: A NAS search algorithm samples a population of network architecture candidates. It receives the child model performance metrics as rewards (e.g. high accuracy, low latency) and optimizes to generate high-performance architecture candidates. Evaluation strategy: We need to measure, estimate, or predict the performance of a large number of proposed child models in order to obtain feedback for the search algorithm to learn. The process of candidate evaluation could be very expensive and many new methods have been proposed to save time or computation resources. Fig. 1. Three main components of Neural Architecture Search (NAS) models. (Image source: Elsken, et al. 2019 with customized annotation in red) Search Space The NAS search space defines a set of basic network operations and how operations can be connected to construct valid network architectures.\nSequential Layer-wise Operations The most naive way to design the search space for neural network architectures is to depict network topologies, either CNN or RNN, with a list of sequential layer-wise operations, as seen in the early work of Zoph \u0026amp; Le 2017 \u0026amp; Baker et al. 2017. The serialization of network representation requires a decent amount of expert knowledge, since each operation is associated with different layer-specific parameters and such associations need to be hardcoded. For example, after predicting a conv op, the model should output kernel size, stride size, etc; or after predicting an FC op, we need to see the number of units as the next prediction.\nFig. 2. (Top) A sequential representation of CNN. (Bottom) A sequential representation of the tree structure of a recurrent cell. (Image source: Zoph \u0026 Le 2017) To make sure the generated architecture is valid, additional rules might be needed (Zoph \u0026amp; Le 2017):\n If a layer is not connected to any input layer then it is used as the input layer; At the final layer, take all layer outputs that have not been connected and concatenate them; If one layer has many input layers, then all input layers are concatenated in the depth dimension; If input layers to be concatenated have different sizes, we pad the small layers with zeros so that the concatenated layers have the same sizes. The skip connection can be predicted as well, using an attention-style mechanism. At layer $i$ , an anchor point is added with $i−1$ content-based sigmoids to indicate which of the previous layers to be connected. Each sigmoid takes as input the hidden states of the current node $h_i$ and $i-1$ previous nodes $h_j, j=1, \\dots, i-1$ .\n $$ P(\\text{Layer j is an input to layer i}) = \\text{sigmoid}(v^\\top \\tanh(\\mathbf{W}_\\text{prev} h_j + \\mathbf{W}_\\text{curr} h_i)) $$ The sequential search space has a lot of representation power, but it is very large and consumes a ton of computation resources to exhaustively cover the search space. In the experiments by Zoph \u0026amp; Le 2017, they were running 800 GPUs in parallel for 28 days and Baker et al. 2017 restricted the search space to contain at most 2 FC layers.\nCell-based Representation Inspired by the design of using repeated modules in successful vision model architectures (e.g. Inception, ResNet), the NASNet search space (Zoph et al. 2018) defines the architecture of a conv net as the same cell getting repeated multiple times and each cell contains several operations predicted by the NAS algorithm. A well-designed cell module enables transferability between datasets. It is also easy to scale down or up the model size by adjusting the number of cell repeats.\nPrecisely, the NASNet search space learns two types of cells for network construction:\n Normal Cell: The input and output feature maps have the same dimension. Reduction Cell: The output feature map has its width and height reduced by half. Fig. 3. The NASNet search space constrains the architecture as a repeated stack of cells. The cell architecture is optimized via NAS algorithms. (Image source: Zoph et al. 2018) The predictions for each cell are grouped into $B$ blocks ($B=5$ in the NASNet paper), where each block has 5 prediction steps made by 5 distinct softmax classifiers corresponding to discrete choices of the elements of a block. Note that the NASNet search space does not have residual connections between cells and the model only learns skip connections on their own within blocks.\nFig. 4. (a) Each cell consists of $B$ blocks and each block is predicted by 5 discrete decisions. (b) An concrete example of what operations can be chosen in each decision step. During the experiments, they discovered that a modified version of DropPath, named ScheduledDropPath, significantly improves the final performance of NASNet experiments. DropPath stochastically drops out paths (i.e. edges with operations attached in NASNet) with a fixed probability. ScheduledDropPath is DropPath with a linearly increasing probability of path dropping during training time.\nElsken, et al (2019) point out three major advantages of the NASNet search space:\n The search space size is reduced drastically; The motif-based architecture can be more easily transferred to different datasets. It demonstrates a strong proof of a useful design pattern of repeatedly stacking modules in architecture engineering. For example, we can build strong models by stacking residual blocks in CNN or stacking multi-headed attention blocks in Transformer. Hierarchical Structure To take advantage of already discovered well-designed network motifs, the NAS search space can be constrained as a hierarchical structure, as in Hierarchical NAS (HNAS; (Liu et al 2017)). It starts with a small set of primitives, including individual operations like convolution operation, pooling, identity, etc. Then small sub-graphs (or \u0026ldquo;motifs\u0026rdquo;) that consist of primitive operations are recursively used to form higher-level computation graphs.\nA computation motif at level $\\ell=1, \\dots, L$ can be represented by $(G^{(\\ell)}, \\mathcal{O}^{(\\ell)})$, where:\n $\\mathcal{O}^{(\\ell)}$ is a set of operations, $\\mathcal{O}^{(\\ell)} = \\{ o^{(\\ell)}_1, o^{(\\ell)}_2, \\dots \\}$ $G^{(\\ell)}$ is an adjacency matrix, where the entry $G_{ij}=k$ indicates that operation $o^{(\\ell)}_k$ is placed between node $i$ and $j$. The node indices follow topological ordering in DAG, where the index $1$ is the source and the maximal index is the sink node. Fig. 5. (Top) Three level-1 primitive operations are composed into a level-2 motif. (Bottom) Three level-2 motifs are plugged into a base network structure and assembled into a level-3 motif. (Image source: Liu et al 2017) To build a network according to the hierarchical structure, we start from the lowest level $\\ell=1$ and recursively define the $m$-th motif operation at level $\\ell$ as\n $$ o^{(\\ell)}_m = \\text{assemble}\\Big( G_m^{(\\ell)}, \\mathcal{O}^{(\\ell-1)} \\Big) $$ A hierarchical representation becomes $\\Big( \\big\\{ \\{ G_m^{(\\ell)} \\}_{m=1}^{M_\\ell} \\big\\}_{\\ell=2}^L, \\mathcal{O}^{(1)} \\Big), \\forall \\ell=2, \\dots, L$, where $\\mathcal{O}^{(1)}$ contains a set of primitive operations.\nThe $\\text{assemble}()$ process is equivalent to sequentially compute the feature map of node $i$ by aggregating all the feature maps of its predecessor node $j$ following the topological ordering:\n $$ x_i = \\text{merge} \\big[ \\{ o^{(\\ell)}_{G^{(\\ell)}_{ij}}(x_j) \\}_{j where $\\text{merge}[]$ is implemented as depth-wise concatenation in the paper.\nSame as NASNet, experiments in Liu et al (2017) focused on discovering good cell architecture within a predefined \u0026ldquo;macro\u0026rdquo; structure with repeated modules. They showed that the power of simple search methods (e.g. random search or evolutionary algorithms) can be substantially enhanced using well-designed search spaces.\nCai et al (2018b) propose a tree-structure search space using path-level network transformation. Each node in a tree structure defines an allocation scheme for splitting inputs for child nodes and a merge scheme for combining results from child nodes. The path-level network transformation allows replacing a single layer with a multi-branch motif if its corresponding merge scheme is add or concat.\nFig. 6. An illustration of transforming a single layer to a tree-structured motif via path-level transformation operations. (Image source: Cai et al. 2018b) Memory-bank Representation A memory-bank representation of feed-forward networks is proposed by Brock et al. (2017) in SMASH. Instead of a graph of operations, they view a neural network as a system with multiple memory blocks which can read and write. Each layer operation is designed to: (1) read from a subset of memory blocks; (2) computes results; finally (3) write the results into another subset of blocks. For example, in a sequential model, a single memory block would get read and overwritten consistently.\nFig. 7. Memory-bank representation of several popular network architecture blocks. (Image source: Brock et al. 2017) Search Algorithms NAS search algorithms sample a population of child networks. It receives the child models' performance metrics as rewards and learns to generate high-performance architecture candidates. You may a lot in common with the field of hyperparameter search.\nRandom Search Random search is the most naive baseline. It samples a valid architecture candidate from the search space at random and no learning model is involved. Random search has proved to be quite useful in hyperparameter search (Bergstra \u0026amp; Bengio 2012). With a well-designed search space, random search could be a very challenging baseline to beat.\nReinforcement Learning The initial design of NAS (Zoph \u0026amp; Le 2017) involves a RL-based controller for proposing child model architectures for evaluation. The controller is implemented as a RNN, outputting a variable-length sequence of tokens used for configuring a network architecture.\nFig. 8. A high level overview of NAS, containing a RNN controller and a pipeline for evaluating child models. (Image source: Zoph \u0026 Le 2017) The controller is trained as a RL task using REINFORCE.\n Action space: The action space is a list of tokens for defining a child network predicted by the controller (See more in the above section). The controller outputs action, $a_{1:T}$, where $T$ is the total number of tokens. Reward: The accuracy of a child network that can be achieved at convergence is the reward for training the controller, $R$. Loss: NAS optimizes the controller parameters $\\theta$ with a REINFORCE loss. We want to maximize the expected reward (high accuracy) with the gradient as follows. The nice thing here with policy gradient is that it works even when the reward is non-differentiable. $$ \\nabla_{\\theta} J(\\theta) = \\sum_{t=1}^T \\mathbb{E}[\\nabla_{\\theta} \\log P(a_t \\vert a_{1:(t-1)}; \\theta) R ] $$ MetaQNN (Baker et al. 2017) trains an agent to sequentially choose CNN layers using Q-learning with an $\\epsilon$-greedy exploration strategy and experience replay. The reward is the validation accuracy as well.\n $$ Q^{(t+1)}(s_t, a_t) = (1 - \\alpha)Q^{(t)}(s_t, a_t) + \\alpha (R_t + \\gamma \\max_{a \\in \\mathcal{A}} Q^{(t)}(s_{t+1}, a')) $$ where a state $s_t$ is a tuple of layer operation and related parameters. An action $a$ determines the connectivity between operations. The Q-value is proportional to how confident we are in two connected operations leading to high accuracy.\nFig. 9. Overview of MetaQNN - designing CNN models with Q-Learning. (Image source: Baker et al. 2017) Evolutionary Algorithms NEAT (short for NeuroEvolution of Augmenting Topologies) is an approach for evolving neural network topologies with genetic algorithm (GA), proposed by Stanley \u0026amp; Miikkulainen in 2002. NEAT evolves both connection weights and network topology together. Each gene encodes the full information for configuring a network, including node weights and edges. The population grows by applying mutation of both weights and connections, as well as crossover between two parent genes. For more in neuroevolution, please refer to the in-depth survey by Stanley et al. (2019).\nFig. 10. Mutations in the NEAT algorithm. (Image source: Fig 3 \u0026 4 in Stanley \u0026 Miikkulainen, 2002) Real et al. (2018) adopt the evolutionary algorithms (EA) as a way to search for high-performance network architectures, named AmoebaNet. They apply the tournament selection method, which at each iteration picks a best candidate out of a random set of samples and places its mutated offspring back into the population. When the tournament size is $1$, it is equivalent to random selection.\nAmoebaNet modified the tournament selection to favor younger genotypes and always discard the oldest models within each cycle. Such an approach, named aging evolution, allows AmoebaNet to cover and explore more search space, rather than to narrow down on good performance models too early.\nPrecisely, in every cycle of the tournament selection with aging regularization (See Figure 11):\n Sample $S$ models from the population and the one with highest accuracy is chosen as parent. A child model is produced by mutating parent. Then the child model is trained, evaluated and added back into the population. The oldest model is removed from the population. Fig. 11. The algorithm of aging evolution. (Image source: Real et al. 2018) Two types of mutations are applied:\n Hidden state mutation: randomly chooses a pairwise combination and rewires a random end such that there is no loop in the graph. Operation mutation: randomly replaces an existing operation with a random one. Fig. 12. Two types of mutations in AmoebaNet. (Image source: Real et al. 2018) In their experiments, EA and RL work equally well in terms of the final validation accuracy, but EA has better anytime performance and is able to find smaller models. Here using EA in NAS is still expensive in terms of computation, as each experiment took 7 days with 450 GPUs.\nHNAS (Liu et al 2017) also employs the evolutionary algorithms (the original tournament selection) as their search strategy. In the hierarchical structure search space, each edge is an operation. Thus genotype mutation in their experiments is applied by replacing a random edge with a different operation. The replacement set includes an none op, so it can alter, remove and add an edge. The initial set of genotypes is created by applying a large number of random mutations on \u0026ldquo;trivial\u0026rdquo; motifs (all identity mappings).\nProgressive Decision Process Constructing a model architecture is a sequential process. Every additional operator or layer brings extra complexity. If we guide the search model to start the investigation from simple models and gradually evolve to more complex architectures, it is like to introduce \u0026ldquo;curriculum\u0026rdquo; into the search model\u0026rsquo;s learning process.\nProgressive NAS (PNAS; Liu, et al 2018) frames the problem of NAS as a progressive procedure for searching models of increasing complexity. Instead of RL or EA, PNAS adopts a Sequential Model-based Bayesian Optimization (SMBO) as the search strategy. PNAS works similar to A* search, as it searches for models from simple to hard while simultaneously learning a surrogate function to guide the search.\n A* search algorithm (\u0026ldquo;best-first search\u0026rdquo;) is a popular algorithm for path finding. The problem is framed as finding a path of smallest cost from a specific starting node to a given target node in a weighted graph. At each iteration, A* finds a path to extend by minimizing: $f(n)=g(n)+h(n)$, where $n$ is the next node, $g(n)$ is the cost from start to $n$, and $h(n)$ is the heuristic function that estimates the minimum cost of going from node $n$ to the goal.\n PNAS uses the NASNet search space. Each block is specified as a 5-element tuple and PNAS only considers the element-wise addition as the step 5 combination operator, no concatenation. Differently, instead of setting the number of blocks $B$ at a fixed number, PNAS starts with $B=1$, a model with only one block in a cell, and gradually increases $B$.\nThe performance on a validation set is used as feedback to train a surrogate model for predicting the performance of novel architectures. With this predictor, we can thus decide which models should be prioritized to be evaluated next. Since the performance predictor should be able to handle various-sized inputs, accuracy, and sample-efficient, they ended up using an RNN model.\nFig. 13. The algorithm of Progressive NAS. (Image source: Liu, et al 2018) Gradient descent Using gradient descent to update the architecture search model requires an effort to make the process of choosing discrete operations differentiable. These approaches usually combine the learning of both architecture parameters and network weights together into one model. See more in the section on the \u0026ldquo;one-shot\u0026rdquo; approach.\nEvaluation Strategy We need to measure, estimate or predict the performance of every child model in order to obtain feedback for optimizing the search algorithm. The process of candidate evaluation could be very expensive and many new evaluation methods have been proposed to save time or computation. When evaluating a child model, we mostly care about its performance measured as accuracy on a validation set. Recent work has started looking into other factors of a model, such as model size and latency, as certain devices may have limitations on memory or demand fast response time.\nTraining from Scratch The most naive approach is to train every child network independently from scratch until convergence and then measure its accuracy on a validation set (Zoph \u0026amp; Le 2017). It provides solid performance numbers, but one complete train-converge-evaluate loop only generates a single data sample for training the RL controller (let alone RL is known to be sample-inefficient in general). Thus it is very expensive in terms of computation consumption.\nProxy Task Performance There are several approaches for using a proxy task performance as the performance estimator of a child network, which is generally cheaper and faster to calculate:\n Train on a smaller dataset. Train for fewer epochs. Train and evaluate a down-scaled model in the search stage. For example, once a cell structure is learned, we can play with the number of cell repeats or scale up the number of filters (Zoph et al. 2018). Predict the learning curve. Baker et al (2018) model the prediction of validation accuracies as a time-series regression problem. The features for the regression model ($\\nu$-support vector machine regressions; $\\nu$-SVR) include the early sequences of accuracy per epoch, architecture parameters, and hyperparameters. Parameter Sharing Instead of training every child model independently from scratch. You may ask, ok, what if we fabricate dependency between them and find a way to reuse weights? Some researchers succeeded to make such approaches work.\nInspired by Net2net transformation, Cai et al (2017) proposed Efficient Architecture Search (EAS). EAS sets up an RL agent, known as a meta-controller, to predict function-preserving network transformation so as to grow the network depth or layer width. Because the network is growing incrementally, the weights of previously validated networks can be reused for further exploration. With inherited weights, newly constructed networks only need some light-weighted training.\nA meta-controller learns to generate network transformation actions given the current network architecture, which is specified with a variable-length string. In order to handle architecture configuration of a variable length, the meta-controller is implemented as a bi-directional recurrent network. Multiple actor networks output different transformation decisions:\n Net2WiderNet operation allows to replace a layer with a wider layer, meaning more units for fully-connected layers, or more filters for convolutional layers, while preserving the functionality. Net2DeeperNet operation allows to insert a new layer that is initialized as adding an identity mapping between two layers so as to preserve the functionality. Fig. 14. Overview of the RL based meta-controller in Efficient Architecture Search (NAS). After encoding the architecture configuration, it outputs net2net transformation actions through two separate actor networks. (Image source: Cai et al 2017) With similar motivation, Efficient NAS (ENAS; Pham et al. 2018) speeds up NAS (i.e. 1000x less) by aggressively sharing parameters among child models. The core motivation behind ENAS is the observation that all of the sampled architecture graphs can be viewed as sub-graphs of a larger supergraph. All the child networks are sharing weights of this supergraph.\nFig. 15. (Left) The graph represents the entire search space for a 4-node recurrent cell, but only connections in red are active. (Middle) An example of how the left active sub-graph can be translated into a child model architecture. (Right) The network parameters produced by an RNN controller for the architecture in the middle. (Image source: Pham et al. 2018) ENAS alternates between training the shared model weights $\\omega$ and training the controller $\\theta$:\n The parameters of the controller LSTM $\\theta$ are trained with REINFORCE, where the reward $R(\\mathbf{m}, \\omega)$ is computed on the validation set. The shared parameters of the child models $\\omega$ are trained with standard supervised learning loss. Note that different operators associated with the same node in the supergraph would have their own distinct parameters. Prediction-Based A routine child model evaluation loop is to update model weights via standard gradient descent. SMASH (Brock et al. 2017) proposes a different and interesting idea: Can we predict the model weights directly based on the network architecture parameters?\nThey employ a HyperNet (Ha et al 2016) to directly generate the weights of a model conditioned on an encoding of its architecture configuration. Then the model with HyperNet-generated weights is validated directly. Note that we don\u0026rsquo;t need extra training for every child model but we do need to train the HyperNet.\nFig. 16. The algorithm of SMASH. (Image source: Brock et al. 2017) The correlation between model performance with SMASH-generated weights and true validation errors suggests that predicted weights can be used for model comparison, to some extent. We do need a HyperNet of large enough capacity, as the correlation would be corrupted if the HyperNet model is too small compared to the child model size.\nFig. 17. The algorithm of SMASH. (Image source: Brock et al. 2017) SMASH can be viewed as another way to implement the idea of parameter sharing. One problem of SMASH as pointed out by Pham et al. (2018) is: The usage of HyperNet restricts the weights of SMASH child models to a low-rank space, because weights are generated via tensor products. In comparison, ENAS has no such restrictions.\nOne-Shot Approach: Search + Evaluation Running search \u0026amp; evaluation independently for a large population of child models is expensive. We have seen promising approaches like Brock et al. (2017) or Pham et al. (2018), where training a single model is enough for emulating any child model in the search space.\nThe one-shot architecture search extends the idea of weight sharing and further combines the learning of architecture generation together with weight parameters. The following approaches all treat child architectures as different sub-graphs of a supergraph with shared weights between common edges in the supergraph.\nBender et al (2018) construct a single large over-parameterized network, known as the One-Shot model, such that it contains every possible operation in the search space. With ScheduledDropPath (the dropout rate is increased over time, which is $r^{1/k}$ at the end of training, where $0 \u0026lt; r \u0026lt; 1$ is a hyperparam and $k$ is the number of incoming paths) and some carefully designed tricks (e.g. ghost batch normalization, L2 regularization only on the active architecture), the training of such a giant model can be stabilized enough and used for evaluating any child model sampled from the supergraph.\nFig. 18. The architecture of the One-Shot model in Bender et al 2018. Each cell has $N$ choice blocks and each choice block can select up to 2 operations. Solid edges are used in every architecture, where dash lines are optional. (Image source: Bender et al 2018) Once the one-shot model is trained, it is used for evaluating the performance of many different architectures sampled at random by zeroing out or removing some operations. This sampling process can be replaced by RL or evolution.\nThey observed that the difference between the accuracy measured with the one-shot model and the accuracy of the same architecture after a small fine-tuning could be very large. Their hypothesis is that the one-shot model automatically learns to focus on the most useful operations in the network and comes to rely on these operations when they are available. Thus zeroing out useful operations lead to big reduction in model accuracy, while removing less important components only causes a small impact \u0026mdash; Therefore, we see a larger variance in scores when using the one-shot model for evaluation.\nFig. 19. A stratified sample of models with different one-shot model accuracy versus their true validation accuracy as stand-alone models. (Image source: Bender et al 2018) Clearly designing such a search graph is not a trivial task, but it demonstrates a strong potential with the one-shot approach. It works well with only gradient descent and no additional algorithm like RL or EA is a must.\nSome believe that one main cause for inefficiency in NAS is to treat the architecture search as a black-box optimization and thus we fall into methods like RL, evolution, SMBO, etc. If we shift to rely on standard gradient descent, we could potentially make the search process more effectively. As a result, Liu et al (2019) propose Differentiable Architecture Search (DARTS). DARTS introduces a continuous relaxation on each path in the search supergraph, making it possible to jointly train architecture parameters and weights via gradient descent.\nLet\u0026rsquo;s use the directed acyclic graph (DAG) representation here. A cell is a DAG consisting of a topologically ordered sequence of $N$ nodes. Each node has a latent representation $x_i$ to be learned. Each edge $(i, j)$ is tied to some operation $o^{(i,j)} \\in \\mathcal{O}$ that transforms $x_j$ to compose $x_i$:\n $$ x_i = \\sum_{j To make the search space continuous, DARTS relaxes the categorical choice of a particular operation as a softmax over all the operations and the task of architecture search is reduced to learn a set of mixing probabilities $\\alpha = \\{ \\alpha^{(i,j)} \\}$.\n $$ \\bar{o}^{(i,j)}(x) = \\sum_{o\\in\\mathcal{O}} \\frac{\\exp(\\alpha_{ij}^o)}{\\sum_{o'\\in\\mathcal{O}} \\exp(\\alpha^{o'}_{ij})} o(x) $$ where $\\alpha_{ij}$ is a vector of dimension $\\vert \\mathcal{O} \\vert$, containing weights between nodes $i$ and $j$ over different operations.\nThe bilevel optimization exists as we want to optimize both the network weights $w$ and the architecture representation $\\alpha$:\n $$ \\begin{aligned} \\min_\\alpha \u0026 \\mathcal{L}_\\text{validate} (w^*(\\alpha), \\alpha) \\\\ \\text{s.t.} \u0026 w^*(\\alpha) = \\arg\\min_w \\mathcal{L}_\\text{train} (w, \\alpha) \\end{aligned} $$ At step $k$, given the current architecture parameters $\\alpha_{k−1}$, we first optimize weights $w_k$ by moving $w_{k−1}$ in the direction of minimizing the training loss $\\mathcal{L}_\\text{train}(w_{k−1}, \\alpha_{k−1})$ with a learning rate $\\xi$. Next, while keeping the newly updated weights $w_k$ fixed, we update the mixing probabilities so as to minimize the validation loss after a single step of gradient descent w.r.t. the weights:\n $$ J_\\alpha = \\mathcal{L}_\\text{val}(w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}), \\alpha_{k-1}) $$ The motivation here is that we want to find an architecture with a low validation loss when its weights are optimized by gradient descent and the one-step unrolled weights serve as the surrogate for $w^∗(\\alpha)$.\n Side note: Earlier we have seen similar formulation in MAML where the two-step optimization happens between task losses and the meta-learner update, as well as framing Domain Randomization as a bilevel optimization for better transfer in the real environment.\n Fig. 20. An illustration of how DARTS applies continuous relaxation on edges in DAG supergraph and identifies the final model. (Image source: Liu et al 2019) $$ \\begin{aligned} \\text{Let }w'_k \u0026= w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}) \u0026 \\\\ J_\\alpha \u0026= \\mathcal{L}_\\text{val}(w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}), \\alpha_{k-1}) = \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \u0026 \\\\ \\nabla_\\alpha J_\\alpha \u0026= \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\nabla_\\alpha \\alpha_{k-1} + \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})\\nabla_\\alpha w'_k \u0026 \\\\\u0026 \\text{; multivariable chain rule}\\\\ \u0026= \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) + \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\big( - \\xi \\color{red}{\\nabla^2_{\\alpha, w} \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1})} \\big) \u0026 \\\\ \u0026\\approx \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) - \\xi \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\color{red}{\\frac{\\nabla_\\alpha \\mathcal{L}_\\text{train}(w_k^+, \\alpha_{k-1}) - \\nabla_\\alpha \\mathcal{L}_\\text{train}(w_k^-, \\alpha_{k-1}) }{2\\epsilon}} \u0026 \\\\ \u0026 \\text{; apply numerical differentiation approximation} \\end{aligned} $$ where the red part is using numerical differentiation approximation where $w_k^+ = w_k + \\epsilon \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})$ and $w_k^- = w_k - \\epsilon \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})$.\nFig. 21. The algorithm overview of DARTS. (Image source: Liu et al 2019) As another idea similar to DARTS, Stochastic NAS (Xie et al., 2019) applies a continuous relaxation by employing the concrete distribution (CONCRETE = CONtinuous relaxations of disCRETE random variables; Maddison et al 2017) and reparametrization tricks. The goal is same as DARTS, to make the discrete distribution differentiable and thus enable optimization by gradient descent.\nDARTS is able to greatly reduce the cost of GPU hours. Their experiments for searching for CNN cells have $N=7$ and only took 1.5 days with a single GPU. However, it suffers from the high GPU memory consumption issue due to its continuous representation of network architecture. In order to fit the model into the memory of a single GPU, they picked a small $N$.\nTo constrain the GPU memory consumption, ProxylessNAS (Cai et al., 2019) views NAS as a path-level pruning process in DAG and binarizes the architecture parameters to force only one path to be active between two nodes at a time. The probabilities for an edge being either masked out or not are then learned by sampling a few binarized architectures and using BinaryConnect (Courbariaux et al., 2015) to update the corresponding probabilities. ProxylessNAS demonstrates a strong connection between NAS and model compression. By using path-level compression, it is able to save memory consumption by one order of magnitude.\nLet\u0026rsquo;s continue with the graph representation. In a DAG adjacency matrix $G$ where $G_{ij}$ represents an edge between node $i$ and $j$ and its value can be chosen from the set of $\\vert \\mathcal{O} \\vert$ candidate primitive operations, $\\mathcal{O} = \\{ o_1, \\dots \\}$. The One-Shot model, DARTS and ProxylessNAS all consider each edge as a mixture of operations, $m_\\mathcal{O}$, but with different tweaks.\nIn One-Shot, $m_\\mathcal{O}(x)$ is the sum of all the operations. In DARTS, it is a weighted sum where weights are softmax over a real-valued architecture weighting vector $\\alpha$ of length $\\vert \\mathcal{O} \\vert$. ProxylessNAS transforms the softmax probabilities of $\\alpha$ into a binary gate and uses the binary gate to keep only one operation active at a time.\n $$ \\begin{aligned} m^\\text{one-shot}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} o_i(x) \\\\ m^\\text{DARTS}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} p_i o_i(x) = \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\exp(\\alpha_i)}{\\sum_j \\exp(\\alpha_j)} o_i(x) \\\\ m^\\text{binary}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} g_i o_i(x) = \\begin{cases} o_1(x) \u0026 \\text{with probability }p_1, \\\\ \\dots \u0026\\\\ o_{\\vert \\mathcal{O} \\vert}(x) \u0026 \\text{with probability }p_{\\vert \\mathcal{O} \\vert} \\end{cases} \\\\ \\text{ where } g \u0026= \\text{binarize}(p_1, \\dots, p_N) = \\begin{cases} [1, 0, \\dots, 0] \u0026 \\text{with probability }p_1, \\\\ \\dots \u0026 \\\\ [0, 0, \\dots, 1] \u0026 \\text{with probability }p_N. \\\\ \\end{cases} \\end{aligned} $$ Fig. 22. ProxylessNAS has two training steps running alternatively. (Image source: Cai et al., 2019) ProxylessNAS runs two training steps alternatively:\n When training weight parameters $w$, it freezes the architecture parameters $\\alpha$ and stochastically samples binary gates $g$ according to the above $m^\\text{binary}_\\mathcal{O}(x)$. The weight parameters can be updated with standard gradient descent. When training architecture parameters $\\alpha$, it freezes $w$, resets the binary gates and then updates $\\alpha$ on the validation set. Following the idea of BinaryConnect, the gradient w.r.t. architecture parameters can be approximately estimated using $\\partial \\mathcal{L} / \\partial g_i$ in replacement for $\\partial \\mathcal{L} / \\partial p_i$: $$ \\begin{aligned} \\frac{\\partial \\mathcal{L}}{\\partial \\alpha_i} \u0026= \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial p_j} \\frac{\\partial p_j}{\\partial \\alpha_i} \\approx \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\partial p_j}{\\partial \\alpha_i} = \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\partial \\frac{e^{\\alpha_j}}{\\sum_k e^{\\alpha_k}}}{\\partial \\alpha_i} \\\\ \u0026= \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\sum_k e^{\\alpha_k} (\\mathbf{1}_{i=j} e^{\\alpha_j}) - e^{\\alpha_j} e^{\\alpha_i} }{(\\sum_k e^{\\alpha_k})^2} = \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} p_j (\\mathbf{1}_{i=j} -p_i) \\end{aligned} $$ Instead of BinaryConnect, REINFORCE can also be used for parameter updates with the goal for maximizing the reward, while no RNN meta-controller is involved.\nComputing $\\partial \\mathcal{L} / \\partial g_i$ needs to calculate and store $o_i(x)$, which requires $\\vert \\mathcal{O} \\vert$ times GPU memory. To resolve this issue, they factorize the task of choosing one path out of $N$ into multiple binary selection tasks (Intuition: \u0026ldquo;if a path is the best choice, it should be better than any other path\u0026rdquo;). At every update step, only two paths are sampled while others are masked. These two selected paths are updated according to the above equation and then scaled properly so that other path weights are unchanged. After this process, one of the sampled paths is enhanced (path weight increases) and the other is attenuated (path weight decreases), while all other paths stay unaltered.\nBesides accuracy, ProxylessNAS also considers latency as an important metric to optimize, as different devices might have very different requirements on inference time latency (e.g. GPU, CPU, mobile). To make latency differentiable, they model latency as a continuous function of the network dimensions. The expected latency of a mixed operation can be written as $\\mathbb{E}[\\text{latency}] = \\sum_j p_j F(o_j)$, where $F(.)$ is a latency prediction model:\nFig. 23. Add a differentiable latency loss into the training of ProxylessNAS. (Image source: Cai et al., 2019) What\u0026rsquo;s the Future? So far we have seen many interesting new ideas on automating the network architecture engineering through neural architecture search and many have achieved very impressive performance. However, it is a bit hard to do inference on why some architecture work well and how we can develop modules generalizable across tasks rather than being very dataset-specific.\nAs also noted in Elsken, et al (2019):\n \u0026ldquo;\u0026hellip;, so far it provides little insights into why specific architectures work well and how similar the architectures derived in independent runs would be. Identifying common motifs, providing an understanding why those motifs are important for high performance, and investigating if these motifs generalize over different problems would be desirable.\u0026rdquo;\n In the meantime, purely focusing on improvement over validation accuracy might not be enough (Cai et al., 2019). Devices like mobile phones for daily usage in general have limited memory and computation power. While AI applications are on the way to affect our daily life, it is unavoidable to be more device-specific.\nAnother interesting investigation is to consider unlabelled dataset and self-supervised learning for NAS. The size of labelled dataset is always limited and it is not easy to tell whether such a dataset has biases or big deviation from the real world data distribution.\nLiu et al (2020) delve into the question \u0026ldquo;Can we find high-quality neural architecture without human-annotated labels?\u0026quot; and proposed a new setup called Unsupervised Neural Architecture Search (UnNAS). The quality of the architecture needs to be estimated in an unsupervised fashion during the search phase. The paper experimented with three unsupervised pretext tasks: image rotation prediction, colorization, and solving the jigsaw puzzle.\nThey observed in a set of UnNAS experiments that:\n High rank correlation between supervised accuracy and pretext accuracy on the same dataset. Typically the rank correlation is higher than 0.8, regardless of the dataset, the search space, and the pretext task. High rank correlation between supervised accuracy and pretext accuracy across datasets. Better pretext accuracy translates to better supervised accuracy. Performance of UnNAS architecture is comparable to supervised counterparts, though not better yet. One hypothesis is that the architecture quality is correlated with image statistics. Because CIFAR-10 and ImageNet are all on the natural images, they are comparable and the results are transferable. UnNAS could potentially enable a much larger amount of unlabelled data into the search phase which captures image statistics better.\nHyperparameter search is a long-standing topic in the ML community. And NAS automates architecture engineering. Gradually we are trying to automate processes in ML which usually demand a lot of human efforts. Taking even one more step further, is it possible to automatically discover ML algorithms? AutoML-Zero (Real et al 2020) investigates this idea. Using aging evolutionary algorithms, AutoML-Zero automatically searches for whole ML algorithms using little restriction on the form with only simple mathematical operations as building blocks.\nIt learns three component functions. Each function only adopts very basic operations.\n Setup: initialize memory variables (weights). Learn: modify memory variables Predict: make a prediction from an input $x$. Fig. 24. Algorithm evaluation on one task (Image source: Real et al 2020) Three types of operations are considered when mutating a parent genotype:\n Insert a random instruction or remove an instruction at a random location in a component function; Randomize all the instructions in a component function; Modify one of the arguments of an instruction by replacing it with a random choice (e.g. \u0026ldquo;swap the output address\u0026rdquo; or \u0026ldquo;change the value of a constant\u0026rdquo;) Fig. 25. An illustration of evolutionary progress on projected binary CIFAR-10 with example code. (Image source: Real et al 2020) Appendix: Summary of NAS Papers Model name Search space Search algorithms Child model evaluation NEAT (2002) - Evolution (Genetic algorithm) - NAS (2017) Sequential layer-wise ops RL (REINFORCE) Train from scratch until convergence MetaQNN (2017) Sequential layer-wise ops RL (Q-learning with $\\epsilon$-greedy) Train for 20 epochs HNAS (2017) Hierarchical structure Evolution (Tournament selection) Train for a fixed number of iterations NASNet (2018) Cell-based RL (PPO) Train for 20 epochs AmoebaNet (2018) NASNet search space Evolution (Tournament selection with aging regularization) Train for 25 epochs EAS (2018a) Network transformation RL (REINFORCE) 2-stage training PNAS (2018) Reduced version of NASNet search space SMBO; Progressive search for architectures of increasing complexity Train for 20 epochs ENAS (2018) Both sequential and cell-based search space RL (REINFORCE) Train one model with shared weights SMASH (2017) Memory-bank representation Random search HyperNet predicts weights of evaluated architectures. One-Shot (2018) An over-parameterized one-shot model Random search (zero out some paths at random) Train the one-shot model DARTS (2019) NASNet search space Gradient descent (Softmax weights over operations) ProxylessNAS (2019) Tree structure architecture Gradient descent (BinaryConnect) or REINFORCE SNAS (2019) NASNet search space Gradient descent (concrete distribution) Citation Cited as:\n Weng, Lilian. (Aug 2020). Neural architecture search. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-08-06-nas/.\n Or\n@article{weng2020nas, title = \u0026quot;Neural Architecture Search\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Aug\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-08-06-nas/\u0026quot; } Reference [1] Thomas Elsken, Jan Hendrik Metzen, Frank Hutter. \u0026ldquo;Neural Architecture Search: A Survey\u0026rdquo; JMLR 20 (2019) 1-21.\n[2] Kenneth O. Stanley, et al. \u0026ldquo;Designing neural networks through neuroevolution\u0026rdquo; Nature Machine Intelligence volume 1, pages 24–35 (2019).\n[3] Kenneth O. Stanley \u0026amp; Risto Miikkulainen. \u0026ldquo;Evolving Neural Networks through Augmenting Topologies\u0026rdquo; Evolutionary Computation 10(2): 99-127 (2002).\n[4] Barret Zoph, Quoc V. Le. \u0026ldquo;Neural architecture search with reinforcement learning\u0026rdquo; ICLR 2017.\n[5] Bowen Baker, et al. \u0026ldquo;Designing Neural Network Architectures using Reinforcement Learning\u0026rdquo; ICLR 2017.\n[6] Bowen Baker, et al. \u0026ldquo;Accelerating neural architecture search using performance prediction\u0026rdquo; ICLR Workshop 2018.\n[7] Barret Zoph, et al. \u0026ldquo;Learning transferable architectures for scalable image recognition\u0026rdquo; CVPR 2018.\n[8] Hanxiao Liu, et al. \u0026ldquo;Hierarchical representations for efficient architecture search.\u0026quot; ICLR 2018.\n[9] Esteban Real, et al. \u0026ldquo;Regularized Evolution for Image Classifier Architecture Search\u0026rdquo; arXiv:1802.01548 (2018).\n[10] Han Cai, et al. [\u0026ldquo;Efficient architecture search by network transformation\u0026rdquo;] AAAI 2018a.\n[11] Han Cai, et al. \u0026ldquo;Path-Level Network Transformation for Efficient Architecture Search\u0026rdquo; ICML 2018b.\n[12] Han Cai, Ligeng Zhu \u0026amp; Song Han. \u0026ldquo;ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware\u0026rdquo; ICLR 2019.\n[13] Chenxi Liu, et al. \u0026ldquo;Progressive neural architecture search\u0026rdquo; ECCV 2018.\n[14] Hieu Pham, et al. \u0026ldquo;Efficient neural architecture search via parameter sharing\u0026rdquo; ICML 2018.\n[15] Andrew Brock, et al. \u0026ldquo;SMASH: One-shot model architecture search through hypernetworks.\u0026quot; ICLR 2018.\n[16] Gabriel Bender, et al. \u0026ldquo;Understanding and simplifying one-shot architecture search.\u0026quot; ICML 2018.\n[17] Hanxiao Liu, Karen Simonyan, Yiming Yang. \u0026ldquo;DARTS: Differentiable Architecture Search\u0026rdquo; ICLR 2019.\n[18] Sirui Xie, Hehui Zheng, Chunxiao Liu, Liang Lin. \u0026ldquo;SNAS: Stochastic Neural Architecture Search\u0026rdquo; ICLR 2019.\n[19] Chenxi Liu et al. \u0026ldquo;Are Labels Necessary for Neural Architecture Search?\u0026quot; ECCV 2020.\n[20] Esteban Real, et al. \u0026ldquo;AutoML-Zero: Evolving Machine Learning Algorithms From Scratch\u0026rdquo; ICML 2020.\n","permalink":"https://lilianweng.github.io/posts/2020-08-06-nas/","summary":"Although most popular and successful model architectures are designed by human experts, it doesn\u0026rsquo;t mean we have explored the entire network architecture space and settled down with the best option. We would have a better chance to find the optimal solution if we adopt a systematic and automatic way of learning high-performance model architectures.\nAutomatically learning and evolving network topologies is not a new idea (Stanley \u0026amp; Miikkulainen, 2002). In recent years, the pioneering work by Zoph \u0026amp; Le 2017 and Baker et al.","title":"Neural Architecture Search"},{"content":"[Updated on 2020-06-17: Add \u0026ldquo;exploration via disagreement\u0026rdquo; in the \u0026ldquo;Forward Dynamics\u0026rdquo; section.\nExploitation versus exploration is a critical topic in Reinforcement Learning. We\u0026rsquo;d like the RL agent to find the best solution as fast as possible. However, in the meantime, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern RL algorithms that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more like an open topic.\nI would like to discuss several common exploration strategies in Deep RL here. As this is a very big topic, my post by no means can cover all the important subtopics. I plan to update it periodically and keep further enriching the content gradually in time.\nClassic Exploration Strategies As a quick recap, let\u0026rsquo;s first go through several classic exploration algorithms that work out pretty well in the multi-armed bandit problem or simple tabular RL.\n Epsilon-greedy: The agent does random exploration occasionally with probability $\\epsilon$ and takes the optimal action most of the time with probability $1-\\epsilon$. Upper confidence bounds: The agent selects the greediest action to maximize the upper confidence bound $\\hat{Q}_t(a) + \\hat{U}_t(a)$, where $\\hat{Q}_t(a)$ is the average rewards associated with action $a$ up to time $t$ and $\\hat{U}_t(a)$ is a function reversely proportional to how many times action $a$ has been taken. See here for more details. Boltzmann exploration: The agent draws actions from a boltzmann distribution (softmax) over the learned Q values, regulated by a temperature parameter $\\tau$. Thompson sampling: The agent keeps track of a belief over the probability of optimal actions and samples from this distribution. See here for more details. The following strategies could be used for better exploration in deep RL training when neural networks are used for function approximation:\n Entropy loss term: Add an entropy term $H(\\pi(a \\vert s))$ into the loss function, encouraging the policy to take diverse actions. Noise-based Exploration: Add noise into the observation, action or even parameter space (Fortunato, et al. 2017, Plappert, et al. 2017). Key Exploration Problems Good exploration becomes especially hard when the environment rarely provides rewards as feedback or the environment has distracting noise. Many exploration strategies are proposed to solve one or both of the following problems.\nThe Hard-Exploration Problem The \u0026ldquo;hard-exploration\u0026rdquo; problem refers to exploration in an environment with very sparse or even deceptive reward. It is difficult because random exploration in such scenarios can rarely discover successful states or obtain meaningful feedback.\nMontezuma\u0026rsquo;s Revenge is a concrete example for the hard-exploration problem. It remains as a few challenging games in Atari for DRL to solve. Many papers use Montezuma\u0026rsquo;s Revenge to benchmark their results.\nThe Noisy-TV Problem The \u0026ldquo;Noisy-TV\u0026rdquo; problem started as a thought experiment in Burda, et al (2018). Imagine that an RL agent is rewarded with seeking novel experience, a TV with uncontrollable \u0026amp; unpredictable random noise outputs would be able to attract the agent\u0026rsquo;s attention forever. The agent obtains new rewards from noisy TV consistently, but it fails to make any meaningful progress and becomes a \u0026ldquo;couch potato\u0026rdquo;.\nFig. 1. An agent is rewarded with novel experience in the experiment. If a maze has a noisy TC set up, the agent would be attracted and stop moving in the maze. (Image source: OpenAI Blog: \"Reinforcement Learning with Prediction-Based Rewards\") Intrinsic Rewards as Exploration Bonuses One common approach to better exploration, especially for solving the hard-exploration problem, is to augment the environment reward with an additional bonus signal to encourage extra exploration. The policy is thus trained with a reward composed of two terms, $r_t = r^e_t + \\beta r^i_t$, where $\\beta$ is a hyperparameter adjusting the balance between exploitation and exploration.\n $r^e_t$ is an extrinsic reward from the environment at time $t$, defined according to the task in hand. $r^i_t$ is an intrinsic exploration bonus at time $t$. This intrinsic reward is somewhat inspired by intrinsic motivation in psychology (Oudeyer \u0026amp; Kaplan, 2008). Exploration driven by curiosity might be an important way for children to grow and learn. In other words, exploratory activities should be rewarding intrinsically in the human mind to encourage such behavior. The intrinsic rewards could be correlated with curiosity, surprise, familiarity of the state, and many other factors.\nSame ideas can be applied to RL algorithms. In the following sections, methods of bonus-based exploration rewards are roughly grouped into two categories:\n Discovery of novel states Improvement of the agent\u0026rsquo;s knowledge about the environment. Count-based Exploration If we consider intrinsic rewards as rewarding conditions that surprise us, we need a way to measure whether a state is novel or appears often. One intuitive way is to count how many times a state has been encountered and to assign a bonus accordingly. The bonus guides the agent\u0026rsquo;s behavior to prefer rarely visited states to common states. This is known as the count-based exploration method.\nLet $N_n(s)$ be the empirical count function that tracks the real number of visits of a state $s$ in the sequence of $s_{1:n}$. Unfortunately, using $N_n(s)$ for exploration directly is not practical, because most of the states would have $N_n(s)=0$, especially considering that the state space is often continuous or high-dimensional. We need an non-zero count for most states, even when they haven\u0026rsquo;t been seen before.\nCounting by Density Model Bellemare, et al. (2016) used a density model to approximate the frequency of state visits and a novel algorithm for deriving a pseudo-count from this density model. Let\u0026rsquo;s first define a conditional probability over the state space, $\\rho_n(s) = \\rho(s \\vert s_{1:n})$ as the probability of the $(n+1)$-th state being $s$ given the first $n$ states are $s_{1:n}$. To measure this empirically, we can simply use $N_n(s)/n$.\nLet\u0026rsquo;s also define a recoding probability of a state $s$ as the probability assigned by the density model to $s$ after observing a new occurrence of $s$, $\\rho'_n(s) = \\rho(s \\vert s_{1:n}s)$.\nThe paper introduced two concepts to better regulate the density model, a pseudo-count function $\\hat{N}_n(s)$ and a pseudo-count total $\\hat{n}$. As they are designed to imitate an empirical count function, we would have:\n $$ \\rho_n(s) = \\frac{\\hat{N}_n(s)}{\\hat{n}} \\leq \\rho'_n(s) = \\frac{\\hat{N}_n(s) + 1}{\\hat{n} + 1} $$ The relationship between $\\rho_n(x)$ and $\\rho'_n(x)$ requires the density model to be learning-positive: for all $s_{1:n} \\in \\mathcal{S}^n$ and all $s \\in \\mathcal{S}$, $\\rho_n(s) \\leq \\rho'_n(s)$. In other words, After observing one instance of $s$, the density model\u0026rsquo;s prediction of that same $s$ should increase. Apart from being learning-positive, the density model should be trained completely online with non-randomized mini-batches of experienced states, so naturally we have $\\rho'_n = \\rho_{n+1}$.\nThe pseudo-count can be computed from $\\rho_n(s)$ and $\\rho'_n(s)$ after solving the above linear system:\n $$ \\hat{N}_n(s) = \\hat{n} \\rho_n(s) = \\frac{\\rho_n(s)(1 - \\rho'_n(s))}{\\rho'_n(s) - \\rho_n(s)} $$ Or estimated by the prediction gain (PG):\n $$ \\hat{N}_n(s) \\approx (e^{\\text{PG}_n(s)} - 1)^{-1} = (e^{\\log \\rho'_n(s) - \\log \\rho(s)} - 1)^{-1} $$ A common choice of a count-based intrinsic bonus is $r^i_t = N(s_t, a_t)^{-1/2}$ (as in MBIE-EB; Strehl \u0026amp; Littman, 2008). The pseudo-count-based exploration bonus is shaped in a similar form, $r^i_t = \\big(\\hat{N}_n(s_t, a_t) + 0.01 \\big)^{-1/2}$.\nExperiments in Bellemare et al., (2016) adopted a simple CTS (Context Tree Switching) density model to estimate pseudo-counts. The CTS model takes as input a 2D image and assigns to it a probability according to the product of location-dependent L-shaped filters, where the prediction of each filter is given by a CTS algorithm trained on past images. The CTS model is simple but limited in expressiveness, scalability, and data efficiency. In a following-up paper, Georg Ostrovski, et al. (2017) improved the approach by training a PixelCNN (van den Oord et al., 2016) as the density model.\nThe density model can also be a Gaussian Mixture Model as in Zhao \u0026amp; Tresp (2018). They used a variational GMM to estimate the density of trajectories (e.g. concatenation of a sequence of states) and its predicted probabilities to guide prioritization in experience replay in off-policy setting.\nCounting after Hashing Another idea to make it possible to count high-dimensional states is to map states into hash codes so that the occurrences of states become trackable (Tang et al. 2017). The state space is discretized with a hash function $\\phi: \\mathcal{S} \\mapsto \\mathbb{Z}^k$. An exploration bonus $r^{i}: \\mathcal{S} \\mapsto \\mathbb{R}$ is added to the reward function, defined as $r^{i}(s) = {N(\\phi(s))}^{-1/2}$, where $N(\\phi(s))$ is an empirical count of occurrences of $\\phi(s)$.\nTang et al. (2017) proposed to use Locality-Sensitive Hashing (LSH) to convert continuous, high-dimensional data to discrete hash codes. LSH is a popular class of hash functions for querying nearest neighbors based on certain similarity metrics. A hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. (See how LSH is used in Transformer improvement if interested.) SimHash is a type of computationally efficient LSH and it measures similarity by angular distance:\n $$ \\phi(s) = \\text{sgn}(A g(s)) \\in \\{-1, 1\\}^k $$ where $A \\in \\mathbb{R}^{k \\times D}$ is a matrix with each entry drawn i.i.d. from a standard Gaussian and $g: \\mathcal{S} \\mapsto \\mathbb{R}^D$ is an optional preprocessing function. The dimension of binary codes is $k$, controlling the granularity of the state space discretization. A higher $k$ leads to higher granularity and fewer collisions.\nFig. 2. Algorithm of count-based exploration through hashing high-dimensional states by SimHash. (Image source: Tang et al. 2017) For high-dimensional images, SimHash may not work well on the raw pixel level. Tang et al. (2017) designed an autoencoder (AE) which takes as input states $s$ to learn hash codes. It has one special dense layer composed of $k$ sigmoid functions as the latent state in the middle and then the sigmoid activation values $b(s)$ of this layer are binarized by rounding to their closest binary numbers $\\lfloor b(s)\\rceil \\in \\{0, 1\\}^D$ as the binary hash codes for state $s$. The AE loss over $n$ states includes two terms:\n $$ \\mathcal{L}(\\{s_n\\}_{n=1}^N) = \\underbrace{-\\frac{1}{N} \\sum_{n=1}^N \\log p(s_n)}_\\text{reconstruction loss} + \\underbrace{\\frac{1}{N} \\frac{\\lambda}{K} \\sum_{n=1}^N\\sum_{i=1}^k \\min \\big \\{ (1-b_i(s_n))^2, b_i(s_n)^2 \\big\\}}_\\text{sigmoid activation being closer to binary} $$ One problem with this approach is that dissimilar inputs $s_i, s_j$ may be mapped to identical hash codes but the AE still reconstructs them perfectly. One can imagine replacing the bottleneck layer $b(s)$ with the hash codes $\\lfloor b(s)\\rceil$, but then gradients cannot be back-propagated through the rounding function. Injecting uniform noise could mitigate this effect, as the AE has to learn to push the latent variable far apart to counteract the noise.\nPrediction-based Exploration The second category of intrinsic exploration bonuses are rewarded for improvement of the agent\u0026rsquo;s knowledge about the environment. The agent\u0026rsquo;s familiarity with the environment dynamics can be estimated through a prediction model. This idea of using a prediction model to measure curiosity was actually proposed quite a long time ago (Schmidhuber, 1991).\nForward Dynamics Learning a forward dynamics prediction model is a great way to approximate how much knowledge our model has obtained about the environment and the task MDPs. It captures an agent\u0026rsquo;s capability of predicting the consequence of its own behavior, $f: (s_t, a_t) \\mapsto s_{t+1}$. Such a model cannot be perfect (e.g. due to partial observation), the error $e(s_t, a_t) = | f(s_t, a_t) - s_{t+1} |^2_2$ can be used for providing intrinsic exploration rewards. The higher the prediction error, the less familiar we are with that state. The faster the error rate drops, the more learning progress signals we acquire.\nIntelligent Adaptive Curiosity (IAC; Oudeyer, et al. 2007) sketched an idea of using a forward dynamics prediction model to estimate learning progress and assigned intrinsic exploration reward accordingly.\nIAC relies on a memory which stores all the experiences encountered by the robot, $M=\\{(s_t, a_t, s_{t+1})\\}$ and a forward dynamics model $f$. IAC incrementally splits the state space (i.e. sensorimotor space in the context of robotics, as discussed in the paper) into separate regions based on the transition samples, using a process similar to how a decision tree is split: The split happens when the number of samples is larger than a threshold, and the variance of states in each leaf should be minimal. Each tree node is characterized by its exclusive set of samples and has its own forward dynamics predictor $f$, named \u0026ldquo;expert\u0026rdquo;.\nThe prediction error $e_t$ of an expert is pushed into a list associated with each region. The learning progress is then measured as the difference between the mean error rate of a moving window with offset $\\tau$ and the current moving window. The intrinsic reward is defined for tracking the learning progress: $r^i_t = \\frac{1}{k}\\sum_{i=0}^{k-1}(e_{t-i-\\tau} - e_{t-i})$, where $k$ is the moving window size. So the larger prediction error rate decrease we can achieve, the higher intrinsic reward we would assign to the agent. In other words, the agent is encouraged to take actions to quickly learn about the environment.\nFig. 3. Architecture of the IAC (Intelligent Adaptive Curiosity) module: the intrinsic reward is assigned w.r.t the learning progress in reducing prediction error of the dynamics model. (Image source: Oudeyer, et al. 2007) Stadie et al. (2015) trained a forward dynamics model in the encoding space defined by $\\phi$, $f_\\phi: (\\phi(s_t), a_t) \\mapsto \\phi(s_{t+1})$. The model\u0026rsquo;s prediction error at time $T$ is normalized by the maximum error up to time $t$, $\\bar{e}_t = \\frac{e_t}{\\max_{i \\leq t} e_i}$, so it is always between 0 and 1. The intrinsic reward is defined accordingly: $r^i_t = (\\frac{\\bar{e}_t(s_t, a_t)}{t \\cdot C})$, where $C \u0026gt; 0$ is a decay constant.\nEncoding the state space via $\\phi(.)$ is necessary, as experiments in the paper have shown that a dynamics model trained directly on raw pixels has very poor behavior \u0026mdash; assigning same exploration bonuses to all the states. In Stadie et al. (2015), the encoding function $\\phi$ is learned via an autocoder (AE) and $\\phi(.)$ is one of the output layers in AE. The AE can be statically trained using a set of images collected by a random agent, or dynamically trained together with the policy where the early frames are gathered using $\\epsilon$-greedy exploration.\nInstead of autoencoder, Intrinsic Curiosity Module (ICM; Pathak, et al., 2017) learns the state space encoding $\\phi(.)$ with a self-supervised inverse dynamics model. Predicting the next state given the agent\u0026rsquo;s own action is not easy, especially considering that some factors in the environment cannot be controlled by the agent or do not affect the agent. ICM believes that a good state feature space should exclude such factors because they cannot influence the agent\u0026rsquo;s behavior and thus the agent has no incentive for learning them. By learning an inverse dynamics model $g: (\\phi(s_t), \\phi(s_{t+1})) \\mapsto a_t$, the feature space only captures those changes in the environment related to the actions of our agent, and ignores the rest.\nGiven a forward model $f$, an inverse dynamics model $g$ and an observation $(s_t, a_t, s_{t+1})$:\n $$ g_{\\psi_I}(\\phi(s_t), \\phi(s_{t+1})) = \\hat{a}_t \\quad f_{\\psi_F}(\\phi(s_t), a_t) = \\hat{\\phi}(s_{t+1}) \\quad r_t^i = \\| \\hat{\\phi}(s_{t+1}) - \\phi(s_{t+1}) \\|_2^2 $$ Such $\\phi(.)$ is expected to be robust to uncontrollable aspects of the environment.\nFig. 4. ICM (Intrinsic Curiosity Module) assigns the forward dynamics prediction error to the agent as the intrinsic reward. This dynamics model operates in a state encoding space learned through an inverse dynamics model to exclude environmental factors that do not affect the agent's behavior. (Image source: Pathak, et al. 2017) Burda, Edwards \u0026amp; Pathak, et al. (2018) did a set of large-scale comparison experiments on purely curiosity-driven learning, meaning that only intrinsic rewards are provided to the agent. In this study, the reward is $r_t = r^i_t = | f(s_t, a_t) - \\phi(s_{t+1})|_2^2$. A good choice of $\\phi$ is crucial to learning forward dynamics, which is expected to be compact, sufficient and stable, making the prediction task more tractable and filtering out irrelevant observation.\nIn comparison of 4 encoding functions:\n Raw image pixels: No encoding, $\\phi(x) = x$. Random features (RF): Each state is compressed through a fixed random neural network. VAE: The probabilistic encoder is used for encoding, $\\phi(x) = q(z \\vert x)$. Inverse dynamic features (IDF): The same feature space as used in ICM. All the experiments have the reward signals normalized by a running estimation of standard deviation of the cumulative returns. And all the experiments are running in an infinite horizon setting to avoid \u0026ldquo;done\u0026rdquo; flag leaking information.\nFig. 5. The mean reward in different games when training with only curiosity signals, generated by different state encoding functions. (Image source: Burda, Edwards \u0026 Pathak, et al. 2018) Interestingly random features turn out to be quite competitive, but in feature transfer experiments (i.e. train an agent in Super Mario Bros level 1-1 and then test it in another level), learned IDF features can generalize better.\nThey also compared RF and IDF in an environment with a noisy TV on. Unsurprisingly the noisy TV drastically slows down the learning and extrinsic rewards are much lower in time.\nFig. 6. Experiments using RF and IDF feature encoding in an environment with noisy TV on or off. The plot tracks extrinsic reward per episode as the training progresses. (Image source: Burda, Edwards \u0026 Pathak, et al. 2018) The forward dynamics optimization can be modeled via variational inference as well. VIME (short for \u0026ldquo;Variational information maximizing exploration\u0026rdquo;; Houthooft, et al. 2017) is an exploration strategy based on maximization of information gain about the agent\u0026rsquo;s belief of environment dynamics. How much additional information has been obtained about the forward dynamics can be measured as the reduction in entropy.\nLet $\\mathcal{P}$ be the environment transition function, $p(s_{t+1}\\vert s_t, a_t; \\theta)$ be the forward prediction model, parameterized by $\\theta \\in \\Theta$, and $\\xi_t = \\{s_1, a_1, \\dots, s_t\\}$ be the trajectory history. We would like to reduce the entropy after taking a new action and observing the next state, which is to maximize the following:\n $$ \\begin{aligned} \u0026\\sum_t H(\\Theta \\vert \\xi_t, a_t) - H(\\Theta \\vert S_{t+1}, \\xi_t, a_t) \\\\ =\u0026 I(\\Theta; S_{t+1} \\vert \\xi_t, a_t) \\quad \\scriptstyle{\\text{; because } I(X; Y) = I(X) - I(X \\vert Y)} \\\\ =\u0026 \\mathbb{E}_{s_{t+1} \\sim \\mathcal{P}(.\\vert\\xi_t,a_t)} [D_\\text{KL}(p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \\| p(\\theta \\vert \\xi_t, a_t))] \\quad \\scriptstyle{\\text{; because } I(X; Y) = \\mathbb{E}_Y [D_\\text{KL} (p_{X \\vert Y} \\| p_X)]} \\\\ =\u0026 \\mathbb{E}_{s_{t+1} \\sim \\mathcal{P}(.\\vert\\xi_t,a_t)} [D_\\text{KL}(p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \\| p(\\theta \\vert \\xi_t))] \\quad \\scriptstyle{\\text{; because } \\theta \\text{ does not depend on } a_t} \\end{aligned} $$ While taking expectation over the new possible states, the agent is expected to take a new action to increase the KL divergence (\u0026ldquo;information gain\u0026rdquo;) between its new belief over the prediction model to the old one. This term can be added into the reward function as an intrinsic reward: $r^i_t = D_\\text{KL} [p(\\theta \\vert \\xi_t, a_t, s_{t+1}) | p(\\theta \\vert \\xi_t))]$.\nHowever, computing the posterior $p(\\theta \\vert \\xi_t, a_t, s_{t+1})$ is generally intractable.\n $$ \\begin{aligned} p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \u0026= \\frac{p(\\theta \\vert \\xi_t, a_t) p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{p(s_{t+1}\\vert\\xi_t, a_t)} \\\\ \u0026= \\frac{p(\\theta \\vert \\xi_t) p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{p(s_{t+1}\\vert\\xi_t, a_t)} \u0026 \\scriptstyle{\\text{; because action doesn't affect the belief.}} \\\\ \u0026= \\frac{\\color{red}{p(\\theta \\vert \\xi_t)} p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{\\int_\\Theta p(s_{t+1}\\vert\\xi_t, a_t; \\theta) \\color{red}{p(\\theta \\vert \\xi_t)} d\\theta} \u0026 \\scriptstyle{\\text{; red part is hard to compute directly.}} \\end{aligned} $$ Since it is difficult to compute $p(\\theta\\vert\\xi_t)$ directly, a natural choice is to approximate it with an alternative distribution $q_\\phi(\\theta)$. With variational lower bound, we know the maximization of $q_\\phi(\\theta)$ is equivalent to maximizing $p(\\xi_t\\vert\\theta)$ and minimizing $D_\\text{KL}[q_\\phi(\\theta) | p(\\theta)]$.\nUsing the approximation distribution $q$, the intrinsic reward becomes:\n $$ r^i_t = D_\\text{KL} [q_{\\phi_{t+1}}(\\theta) \\| q_{\\phi_t}(\\theta))] $$ where $\\phi_{t+1}$ represents $q$\u0026rsquo;s parameters associated with the new relief after seeing $a_t$ and $s_{t+1}$. When used as an exploration bonus, it is normalized by division by the moving median of this KL divergence value.\nHere the dynamics model is parameterized as a Bayesian neural network (BNN), as it maintains a distribution over its weights. The BNN weight distribution $q_\\phi(\\theta)$ is modeled as a fully factorized Gaussian with $\\phi = \\{\\mu, \\sigma\\}$ and we can easily sample $\\theta \\sim q_\\phi(.)$. After applying a second-order Taylor expansion, the KL term $D_\\text{KL}[q_{\\phi + \\lambda \\Delta\\phi}(\\theta) | q_{\\phi}(\\theta)]$ can be estimated using Fisher Information Matrix $\\mathbf{F}_\\phi$, which is easy to compute, because $q_\\phi$ is factorized Gaussian and thus the covariance matrix is only a diagonal matrix. See more details in the paper, especially section 2.3-2.5.\nAll the methods above depend on a single prediction model. If we have multiple such models, we could use the disagreement among models to set the exploration bonus (Pathak, et al. 2019). High disagreement indicates low confidence in prediction and thus requires more exploration. Pathak, et al. (2019) proposed to train a set of forward dynamics models and to use the variance over the ensemble of model outputs as $r_t^i$. Precisely, they encode the state space with random feature and learn 5 models in the ensemble.\nFig. 7. Illustration of training architecture for self-supervised exploration via disagreement. (Image source: Pathak, et al. 2019) Because $r^i_t$ is differentiable, the intrinsic reward in the model could be directly optimized through gradient descent so as to inform the policy agent to change actions. This differentiable exploration approach is very efficient but limited by having a short exploration horizon.\nRandom Networks But, what if the prediction task is not about the environment dynamics at all? It turns out when the prediction is for a random task, it still can help exploration.\nDORA (short for \u0026ldquo;Directed Outreaching Reinforcement Action-Selection\u0026rdquo;; Fox \u0026amp; Choshen, et al. 2018) is a novel framework that injects exploration signals based on a newly introduced, task-independent MDP. The idea of DORA depends on two parallel MDPs:\n One is the original task MDP; The other is an identical MDP but with no reward attached: Rather, every state-action pair is designed to have value 0. The Q-value learned for the second MDP is called E-value. If the model cannot perfectly predict E-value to be zero, it is still missing information. Initially E-value is assigned with value 1. Such positive initialization can encourage directed exploration for better E-value prediction. State-action pairs with high E-value estimation don\u0026rsquo;t have enough information gathered yet, at least not enough to exclude their high E-values. To some extent, the logarithm of E-values can be considered as a generalization of visit counters.\nWhen using a neural network to do function approximation for E-value, another value head is added to predict E-value and it is simply expected to predict zero. Given a predicted E-value $E(s_t, a_t)$, the exploration bonus is $r^i_t = \\frac{1}{\\sqrt{-\\log E(s_t, a_t)}}$.\nSimilar to DORA, Random Network Distillation (RND; Burda, et al. 2018) introduces a prediction task independent of the main task. The RND exploration bonus is defined as the error of a neural network $\\hat{f}(s_t)$ predicting features of the observations given by a fixed randomly initialized neural network $f(s_t)$. The motivation is that given a new state, if similar states have been visited many times in the past, the prediction should be easier and thus has lower error. The exploration bonus is $r^i(s_t) = |\\hat{f}(s_t; \\theta) - f(s_t) |_2^2$.\nFig. 8. How RND (Random Network Distillation) works for providing an intrinsic reward. The features $O_{i+1} \\mapsto f_{i+1}$ are generated by a fixed random neural network. (Image source: OpenAI Blog: \"Reinforcement Learning with Prediction-Based Rewards\") Two factors are important in RND experiments:\n Non-episodic setting results in better exploration, especially when not using any extrinsic rewards. It means that the return is not truncated at \u0026ldquo;Game over\u0026rdquo; and intrinsic return can spread across multiple episodes. Normalization is important since the scale of the reward is tricky to adjust given a random neural network as a prediction target. The intrinsic reward is normalized by division by a running estimate of the standard deviations of the intrinsic return. The RND setup works well for resolving the hard-exploration problem. For example, maximizing the RND exploration bonus consistently finds more than half of the rooms in Montezuma\u0026rsquo;s Revenge.\nPhysical Properties Different from games in simulators, some RL applications like Robotics need to understand objects and intuitive reasoning in the physical world. Some prediction tasks require the agent to perform a sequence of interactions with the environment and to observe the corresponding consequences, such as estimating some hidden properties in physics (e.g. mass, friction, etc).\nMotivated by such ideas, Denil, et al. (2017) found that DRL agents can learn to perform necessary exploration to discover such hidden properties. Precisely they considered two experiments:\n \u0026ldquo;Which is heavier?\u0026quot; \u0026mdash; The agent has to interact with the blocks and infer which one is heavier. \u0026ldquo;Towers\u0026rdquo; \u0026mdash; The agent needs to infer how many rigid bodies a tower is composed of by knocking it down. The agent in the experiments first goes through an exploration phase to interact with the environment and to collect information. Once the exploration phase ends, the agent is asked to output a labeling action to answer the question. Then a positive reward is assigned to the agent if the answer is correct; otherwise a negative one is assigned. Because the answer requires a decent amount of interactions with items in the scene, the agent has to learn to efficiently play around so as to figure out the physics and the correct answer. The exploration naturally happens.\nIn their experiments, the agent is able to learn in both tasks with performance varied by the difficulty of the task. Although the paper didn\u0026rsquo;t use the physics prediction task to provide intrinsic reward bonus along with extrinsic reward associated with another learning task, rather it focused on the exploration tasks themselves. I do enjoy the idea of encouraging sophisticated exploration behavior by predicting hidden physics properties in the environment.\nMemory-based Exploration Reward-based exploration suffers from several drawbacks:\n Function approximation is slow to catch up. Exploration bonus is non-stationary. Knowledge fading, meaning that states cease to be novel and cannot provide intrinsic reward signals in time. Methods in this section rely on external memory to resolve disadvantages of reward bonus-based exploration.\nEpisodic Memory As mentioned above, RND is better running in an non-episodic setting, meaning the prediction knowledge is accumulated across multiple episodes. The exploration strategy, Never Give Up (NGU; Badia, et al. 2020a), combines an episodic novelty module that can rapidly adapt within one episode with RND as a lifelong novelty module.\nPrecisely, the intrinsic reward in NGU consists of two exploration bonuses from two modules, within one episode and across multiple episodes, respectively.\nThe short-term per-episode reward is provided by an episodic novelty module. It contains an episodic memory $M$, a dynamically-sized slot-based memory, and an IDF (inverse dynamics features) embedding function $\\phi$, same as the feature encoding in ICM\n At every step the current state embedding $\\phi(s_t)$ is added into $M$.\n The intrinsic bonus is determined by comparing how similar the current observation is to the content of $M$. A larger difference results in a larger bonus.\n $$ r^\\text{episodic}_t \\approx \\frac{1}{\\sqrt{\\sum_{\\phi_i \\in N_k} K(\\phi(x_t), \\phi_i)} + c} $$ where $K(x, y)$ is a kernel function for measuring the distance between two samples. $N_k$ is a set of $k$ nearest neighbors in $M$ according to $K(., .)$. $c$ is a small constant to keep the denominator non-zero. In the paper, $K(x, y)$ is configured to be the inverse kernel:\n $$ K(x, y) = \\frac{\\epsilon}{\\frac{d^2(x, y)}{d^2_m} + \\epsilon} $$ where $d(.,.)$ is Euclidean distance between two samples and $d_m$ is a running average of the squared Euclidean distance of the k-th nearest neighbors for better robustness. $\\epsilon$ is a small constant.\n Fig. 9. The architecture of NGU's embedding function (left) and reward generator (right). (Image source: Badia, et al. 2020a) The long-term across-episode novelty relies on RND prediction error in life-long novelty module. The exploration bonus is $\\alpha_t = 1 + \\frac{e^\\text{RND}(s_t) - \\mu_e}{\\sigma_e}$ where $\\mu_e$ and $\\sigma_e$ are running mean and std dev for RND error $e^\\text{RND}(s_t)$.\n However in the conclusion section of the RND paper, I noticed the following statement:\n\u0026ldquo;We find that the RND exploration bonus is sufficient to deal with local exploration, i.e. exploring the consequences of short-term decisions, like whether to interact with a particular object, or avoid it. However global exploration that involves coordinated decisions over long time horizons is beyond the reach of our method. \u0026quot;\nAnd this confuses me a bit how RND can be used as a good life-long novelty bonus provider. If you know why, feel free to leave a comment below.\n The final combined intrinsic reward is $r^i_t = r^\\text{episodic}_t \\cdot \\text{clip}(\\alpha_t, 1, L)$, where $L$ is a constant maximum reward scalar.\nThe design of NGU enables it to have two nice properties:\n Rapidly discourages revisiting the same state within the same episode; Slowly discourages revisiting states that have been visited many times across episodes. Later, built on top of NGU, DeepMind proposed \u0026ldquo;Agent57\u0026rdquo; (Badia, et al. 2020b), the first deep RL agent that outperforms the standard human benchmark on all 57 Atari games. Two major improvements in Agent57 over NGU are:\n A population of policies are trained in Agent57, each equipped with a different exploration parameter pair $\\{(\\beta_j, \\gamma_j)\\}_{j=1}^N$. Recall that given $\\beta_j$, the reward is constructed as $r_{j,t} = r_t^e + \\beta_j r^i_t$ and $\\gamma_j$ is the reward discounting factor. It is natural to expect policies with higher $\\beta_j$ and lower $\\gamma_j$ to make more progress early in training, while the opposite would be expected as training progresses. A meta-controller (sliding-window UCB bandit algorithm) is trained to select which policies should be prioritized. The second improvement is a new parameterization of Q-value function that decomposes the contributions of the intrinsic and extrinsic rewards in a similar form as the bundled reward: $Q(s, a; \\theta_j) = Q(s, a; \\theta_j^e) + \\beta_j Q(s, a; \\theta_j^i)$. During training, $Q(s, a; \\theta_j^e)$ and $Q(s, a; \\theta_j^i)$ are optimized separately with rewards $r_j^e$ and $r_j^i$, respectively. Fig. 10. A pretty cool illustration of techniques developed in time since DQN in 2015, eventually leading to Agent57. (Image source: DeepMind Blog: \"Agent57: Outperforming the human Atari benchmark\") Instead of using the Euclidean distance to measure closeness of states in episodic memory, Savinov, et al. (2019) took the transition between states into consideration and proposed a method to measure the number of steps needed to visit one state from other states in memory, named Episodic Curiosity (EC) module. The novelty bonus depends on reachability between states.\n At the beginning of each episode, the agent starts with an empty episodic memory $M$. At every step, the agent compares the current state with saved states in memory to determine novelty bonus: If the current state is novel (i.e., takes more steps to reach from observations in memory than a threshold), the agent gets a bonus. The current state is added into the episodic memory if the novelty bonus is high enough. (Imagine that if all the states were added into memory, any new state could be added within 1 step.) Repeat 1-3 until the end of this episode. Fig. 11. The nodes in the graph are states, the edges are possible transitions. The blue nodes are states in memory. The green nodes are reachable from the memory within $k = 2$ steps (not novel). The orange nodes are further away, so they are considered as novel states. (Image source: Savinov, et al. 2019) In order to estimate reachability between states, we need to access the transition graph, which is unfortunately not entirely known. Thus, Savinov, et al. (2019) trained a siamese neural network to predict how many steps separate two states. It contains one embedding network $\\phi: \\mathcal{S} \\mapsto \\mathbb{R}^n$ to first encode the states to feature vectors and then one comparator network $C: \\mathbb{R}^n \\times \\mathbb{R}^n \\mapsto [0, 1]$ to output a binary label on whether two states are close enough (i.e., reachable within $k$ steps) in the transition graph, $C(\\phi(s_i), \\phi(s_j)) \\mapsto [0, 1]$.\nAn episodic memory buffer $M$ stores embeddings of some past observations within the same episode. A new observation will be compared with existing state embeddings via $C$ and the results are aggregated (e.g. max, 90th percentile) to provide a reachability score $C^M(\\phi(s_t))$. The exploration bonus is $r^i_t = \\big(C' - C^M(f(s_t))\\big)$, where $C'$ is a predefined threshold for determining the sign of the reward (e.g. $C'=0.5$ works well for fixed-duration episodes). High bonus is awarded to new states when they are not easily reachable from states in the memory buffer.\nThey claimed that the EC module can overcome the noisy-TV problem.\nFig. 12. The architecture of episodic curiosity (EC) module for intrinsic reward generation. (Image source: Savinov, et al. 2019) Direct Exploration Go-Explore (Ecoffet, et al., 2019) is an algorithm aiming to solve the \u0026ldquo;hard-exploration\u0026rdquo; problem. It is composed of the following two phases.\nPhase 1 (\u0026ldquo;Explore until solved\u0026rdquo;) feels quite like Dijkstra\u0026rsquo;s algorithm for finding shortest paths in a graph. Indeed, no neural network is involved in phase 1. By maintaining a memory of interesting states as well as trajectories leading to them, the agent can go back (given a simulator is deterministic) to promising states and continue doing random exploration from there. The state is mapped into a short discretized code (named \u0026ldquo;cell\u0026rdquo;) in order to be memorized. The memory is updated if a new state appears or a better/shorter trajectory is found. When selecting which past states to return to, the agent might select one in the memory uniformly or according to heuristics like recency, visit count, count of neighbors in the memory, etc. This process is repeated until the task is solved and at least one solution trajectory is found.\nThe above found high-performance trajectories would not work well on evaluation envs with any stochasticity. Thus, Phase 2 (\u0026ldquo;Robustification\u0026rdquo;) is needed to robustify the solution via imitation learning. They adopted Backward Algorithm, in which the agent is started near the last state in the trajectory and then runs RL optimization from there.\nOne important note in phase 1 is: In order to go back to a state deterministically without exploration, Go-Explore depends on a resettable and deterministic simulator, which is a big disadvantage.\nTo make the algorithm more generally useful to environments with stochasticity, an enhanced version of Go-Explore (Ecoffet, et al., 2020), named policy-based Go-Explore was proposed later.\n Instead of resetting the simulator state effortlessly, the policy-based Go-Explore learns a goal-conditioned policy and uses that to access a known state in memory repeatedly. The goal-conditioned policy is trained to follow the best trajectory that previously led to the selected states in memory. They include a Self-Imitation Learning (SIL; Oh, et al. 2018) loss to help extract as much information as possible from successful trajectories. Also, they found sampling from policy works better than random actions when the agent returns to promising states to continue exploration. Another improvement in policy-based Go-Explore is to make the downscaling function of images to cells adjustable. It is optimized so that there would be neither too many nor too few cells in the memory. Fig. 13. An overview of the Go-Explore algorithm. (Image source: Ecoffet, et al., 2020) After vanilla Go-Explore, Yijie Guo, et al. (2019) proposed DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning), which shared a similar idea as policy-based Go-Explore above. DTSIL maintains a memory of diverse demonstrations collected during training and uses them to train a trajectory-conditioned policy via SIL. They prioritize trajectories that end with a rare state during sampling.\nFig. 14. Algorithm of DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning). (Image source: Yijie Guo, et al. 2019) The similar approach is also seen in Guo, et al. (2019). The main idea is to store goals with high uncertainty in memory so that later the agent can revisit these goal states with a goal-conditioned policy repeatedly. In each episode, the agent flips a coin (probability 0.5) to decide whether it will act greedily w.r.t. the policy or do directed exploration by sampling goals from the memory.\nFig. 15. Different components in directed exploration with function approximation. (Image source: Guo, et al. 2019) The uncertainty measure of a state can be something simple like count-based bonuses or something complex like density or bayesian models. The paper trained a forward dynamics model and took its prediction error as the uncertainty metric.\nQ-Value Exploration Inspired by Thompson sampling, Bootstrapped DQN (Osband, et al. 2016) introduces a notion of uncertainty in Q-value approximation in classic DQN by using the bootstrapping method. Bootstrapping is to approximate a distribution by sampling with replacement from the same population multiple times and then aggregate the results.\nMultiple Q-value heads are trained in parallel but each only consumes a bootstrapped sub-sampled set of data and each has its own corresponding target network. All the Q-value heads share the same backbone network.\nFig. 16. The algorithm of Bootstrapped DQN. (Image source: Osband, et al. 2016) At the beginning of one episode, one Q-value head is sampled uniformly and acts for collecting experience data in this episode. Then a binary mask is sampled from the masking distribution $m \\sim \\mathcal{M}$ and decides which heads can use this data for training. The choice of masking distribution $\\mathcal{M}$ determines how bootstrapped samples are generated; For example,\n If $\\mathcal{M}$ is an independent Bernoulli distribution with $p=0.5$, this corresponds to the double-or-nothing bootstrap. If $\\mathcal{M}$ always returns an all-one mask, the algorithm reduces to an ensemble method. However, this kind of exploration is still restricted, because uncertainty introduced by bootstrapping fully relies on the training data. It is better to inject some prior information independent of the data. This \u0026ldquo;noisy\u0026rdquo; prior is expected to drive the agent to keep exploring when the reward is sparse. The algorithm of adding random prior into bootstrapped DQN for better exploration (Osband, et al. 2018) depends on Bayesian linear regression. The core idea of Bayesian regression is: We can \u0026ldquo;generate posterior samples by training on noisy versions of the data, together with some random regularization\u0026rdquo;.\nLet $\\theta$ be the Q function parameter and $\\theta^-$ for the target Q, the loss function using a randomized prior function $p$ is:\n $$ \\mathcal{L}(\\theta, \\theta^{-}, p, \\mathcal{D}; \\gamma) = \\sum_{t\\in\\mathcal{D}}\\Big( r_t + \\gamma \\max_{a'\\in\\mathcal{A}} (\\underbrace{Q_{\\theta^-} + p)}_\\text{target Q}(s'_t, a') - \\underbrace{(Q_\\theta + p)}_\\text{Q to optimize}(s_t, a_t) \\Big)^2 $$ Varitional Options Options are policies with termination conditions. There are a large set of options available in the search space and they are independent of an agent\u0026rsquo;s intentions. By explicitly including intrinsic options into modeling, the agent can obtain intrinsic rewards for exploration.\nVIC (short for \u0026ldquo;Variational Intrinsic Control\u0026rdquo;; Gregor, et al. 2017) is such a framework for providing the agent with intrinsic exploration bonuses based on modeling options and learning policies conditioned on options. Let $\\Omega$ represent an option which starts from $s_0$ and ends at $s_f$. An environment probability distribution $p^J(s_f \\vert s_0, \\Omega)$ defines where an option $\\Omega$ terminates given a starting state $s_0$. A controllability distribution $p^C(\\Omega \\vert s_0)$ defines the probability distribution of options we can sample from. And by definition we have $p(s_f, \\Omega \\vert s_0) = p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0)$.\nWhile choosing options, we would like to achieve two goals:\n Achieve a diverse set of the final states from $s_0$ ⇨ Maximization of $H(s_f \\vert s_0)$. Know precisely which state a given option $\\Omega$ can end with ⇨ Minimization of $H(s_f \\vert s_0, \\Omega)$. Combining them, we get mutual information $I(\\Omega; s_f \\vert s_0)$ to maximize:\n $$ \\begin{aligned} I(\\Omega; s_f \\vert s_0) \u0026= H(s_f \\vert s_0) - H(s_f \\vert s_0, \\Omega) \\\\ \u0026= - \\sum_{s_f} p(s_f \\vert s_0) \\log p(s_f \\vert s_0) + \\sum_{s_f, \\Omega} p(s_f, \\Omega \\vert s_0) \\log \\frac{p(s_f, \\Omega \\vert s_0)}{p^C(\\Omega \\vert s_0)} \\\\ \u0026= - \\sum_{s_f} p(s_f \\vert s_0) \\log p(s_f \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log p^J(s_f \\vert s_0, \\Omega) \\\\ \\end{aligned} $$ Because mutual information is symmetric, we can switch $s_f$ and $\\Omega$ in several places without breaking the equivalence. Also because $p(\\Omega \\vert s_0, s_f)$ is difficult to observe, let us replace it with an approximation distribution $q$. According to the variational lower bound, we would have $I(\\Omega; s_f \\vert s_0) \\geq I^{VB}(\\Omega; s_f \\vert s_0)$.\n $$ \\begin{aligned} I(\\Omega; s_f \\vert s_0) \u0026= I(s_f; \\Omega \\vert s_0) \\\\ \u0026= - \\sum_{\\Omega} p(\\Omega \\vert s_0) \\log p(\\Omega \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log \\color{red}{p(\\Omega \\vert s_0, s_f)}\\\\ I^{VB}(\\Omega; s_f \\vert s_0) \u0026= - \\sum_{\\Omega} p(\\Omega \\vert s_0) \\log p(\\Omega \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log \\color{red}{q(\\Omega \\vert s_0, s_f)} \\\\ I(\\Omega; s_f \\vert s_0) \u0026\\geq I^{VB}(\\Omega; s_f \\vert s_0) \\end{aligned} $$ Fig. 17. The algorithm for VIC (Variational Intrinsic Control). (Image source: Gregor, et al. 2017) Here $\\pi(a \\vert \\Omega, s)$ can be optimized with any RL algorithm. The option inference function $q(\\Omega \\vert s_0, s_f)$ is doing supervised learning. The prior $p^C$ is updated so that it tends to choose $\\Omega$ with higher rewards. Note that $p^C$ can also be fixed (e.g. a Gaussian). Various $\\Omega$ will result in different behavior through learning. Additionally, Gregor, et al. (2017) observed that it is difficult to make VIC with explicit options work in practice with function approximation and therefore they also proposed another version of VIC with implicit options.\nDifferent from VIC which models $\\Omega$ conditioned only on the start and end states, VALOR (short for \u0026ldquo;Variational Auto-encoding Learning of Options by Reinforcement\u0026rdquo;; Achiam, et al. 2018) relies on the whole trajectory to extract the option context $c$, which is sampled from a fixed Gaussian distribution. In VALOR:\n A policy acts as an encoder, translating contexts from a noise distribution into trajectories A decoder attempts to recover the contexts from the trajectories, and rewards the policies for making contexts easier to distinguish. The decoder never sees the actions during training, so the agent has to interact with the environment in a way that facilitates communication with the decoder for better prediction. Also, the decoder recurrently takes in a sequence of steps in one trajectory to better model the correlation between timesteps. Fig. 18. The decoder of VALOR is a biLSTM which takes $N = 11$ equally spaced observations from one trajectory as inputs. (Image source: Achiam, et al. 2018) DIAYN (\u0026ldquo;Diversity is all you need\u0026rdquo;; Eysenbach, et al. 2018) has the idea lying in the same direction, although with a different name \u0026mdash; DIAYN models the policies conditioned on a latent skill variable. See my previous post for more details.\nCitation Cited as:\n Weng, Lilian. (Jun 2020). Exploration strategies in deep reinforcement learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-06-07-exploration-drl/.\n Or\n@article{weng2020exploration, title = \u0026quot;Exploration Strategies in Deep Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Jun\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-06-07-exploration-drl/\u0026quot; } Reference [1] Pierre-Yves Oudeyer \u0026amp; Frederic Kaplan. \u0026ldquo;How can we define intrinsic motivation?\u0026quot; Conf. on Epigenetic Robotics, 2008.\n[2] Marc G. Bellemare, et al. \u0026ldquo;Unifying Count-Based Exploration and Intrinsic Motivation\u0026rdquo;. NIPS 2016.\n[3] Georg Ostrovski, et al. \u0026ldquo;Count-Based Exploration with Neural Density Models\u0026rdquo;. PMLR 2017.\n[4] Rui Zhao \u0026amp; Volker Tresp. \u0026ldquo;Curiosity-Driven Experience Prioritization via Density Estimation\u0026rdquo;. NIPS 2018.\n[5] Haoran Tang, et al. \u0026quot;#Exploration: A Study of Count-Based Exploration for Deep Reinforcement Learning\u0026rdquo;. NIPS 2017.\n[6] Jürgen Schmidhuber. \u0026ldquo;A possibility for implementing curiosity and boredom in model-building neural controllers\u0026rdquo; 1991.\n[7] Pierre-Yves Oudeyer, et al. \u0026ldquo;Intrinsic Motivation Systems for Autonomous Mental Development\u0026rdquo; IEEE Transactions on Evolutionary Computation, 2007.\n[8] Bradly C. Stadie, et al. \u0026ldquo;Incentivizing Exploration In Reinforcement Learning With Deep Predictive Models\u0026rdquo;. ICLR 2016.\n[9] Deepak Pathak, et al. \u0026ldquo;Curiosity-driven Exploration by Self-supervised Prediction\u0026rdquo;. CVPR 2017.\n[10] Yuri Burda, Harri Edwards \u0026amp; Deepak Pathak, et al. \u0026ldquo;Large-Scale Study of Curiosity-Driven Learning\u0026rdquo;. arXiv 1808.04355 (2018).\n[11] Joshua Achiam \u0026amp; Shankar Sastry. \u0026ldquo;Surprise-Based Intrinsic Motivation for Deep Reinforcement Learning\u0026rdquo; NIPS 2016 Deep RL Workshop.\n[12] Rein Houthooft, et al. \u0026ldquo;VIME: Variational information maximizing exploration\u0026rdquo;. NIPS 2016.\n[13] Leshem Choshen, Lior Fox \u0026amp; Yonatan Loewenstein. \u0026ldquo;DORA the explorer: Directed outreaching reinforcement action-selection\u0026rdquo;. ICLR 2018\n[14] Yuri Burda, et al. \u0026ldquo;Exploration by Random Network Distillation\u0026rdquo; ICLR 2019.\n[15] OpenAI Blog: \u0026ldquo;Reinforcement Learning with Prediction-Based Rewards\u0026rdquo; Oct, 2018.\n[16] Misha Denil, et al. \u0026ldquo;Learning to Perform Physics Experiments via Deep Reinforcement Learning\u0026rdquo;. ICLR 2017.\n[17] Ian Osband, et al. \u0026ldquo;Deep Exploration via Bootstrapped DQN\u0026rdquo;. NIPS 2016.\n[18] Ian Osband, John Aslanides \u0026amp; Albin Cassirer. \u0026ldquo;Randomized Prior Functions for Deep Reinforcement Learning\u0026rdquo;. NIPS 2018.\n[19] Karol Gregor, Danilo Jimenez Rezende \u0026amp; Daan Wierstra. \u0026ldquo;Variational Intrinsic Control\u0026rdquo;. ICLR 2017.\n[20] Joshua Achiam, et al. \u0026ldquo;Variational Option Discovery Algorithms\u0026rdquo;. arXiv 1807.10299 (2018).\n[21] Benjamin Eysenbach, et al. \u0026ldquo;Diversity is all you need: Learning skills without a reward function.\u0026quot;. ICLR 2019.\n[22] Adrià Puigdomènech Badia, et al. \u0026ldquo;Never Give Up (NGU): Learning Directed Exploration Strategies\u0026rdquo; ICLR 2020.\n[23] Adrià Puigdomènech Badia, et al. \u0026ldquo;Agent57: Outperforming the Atari Human Benchmark\u0026rdquo;. arXiv 2003.13350 (2020).\n[24] DeepMind Blog: \u0026ldquo;Agent57: Outperforming the human Atari benchmark\u0026rdquo; Mar 2020.\n[25] Nikolay Savinov, et al. \u0026ldquo;Episodic Curiosity through Reachability\u0026rdquo; ICLR 2019.\n[26] Adrien Ecoffet, et al. \u0026ldquo;Go-Explore: a New Approach for Hard-Exploration Problems\u0026rdquo;. arXiv 1901.10995 (2019).\n[27] Adrien Ecoffet, et al. \u0026ldquo;First return then explore\u0026rdquo;. arXiv 2004.12919 (2020).\n[28] Junhyuk Oh, et al. \u0026ldquo;Self-Imitation Learning\u0026rdquo;. ICML 2018.\n[29] Yijie Guo, et al. \u0026ldquo;Self-Imitation Learning via Trajectory-Conditioned Policy for Hard-Exploration Tasks\u0026rdquo;. arXiv 1907.10247 (2019).\n[30] Zhaohan Daniel Guo \u0026amp; Emma Brunskill. \u0026ldquo;Directed Exploration for Reinforcement Learning\u0026rdquo;. arXiv 1906.07805 (2019).\n[31] Deepak Pathak, et al. “Self-Supervised Exploration via Disagreement.” ICML 2019.\n","permalink":"https://lilianweng.github.io/posts/2020-06-07-exploration-drl/","summary":"[Updated on 2020-06-17: Add \u0026ldquo;exploration via disagreement\u0026rdquo; in the \u0026ldquo;Forward Dynamics\u0026rdquo; section.\nExploitation versus exploration is a critical topic in Reinforcement Learning. We\u0026rsquo;d like the RL agent to find the best solution as fast as possible. However, in the meantime, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern RL algorithms that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more like an open topic.","title":"Exploration Strategies in Deep Reinforcement Learning"},{"content":"[Updated on 2023-01-27: After almost three years, I did a big refactoring update of this post to incorporate a bunch of new Transformer models since 2020. The enhanced version of this post is here: The Transformer Family Version 2.0. Please refer to that post on this topic.] \nIt has been almost two years since my last post on attention. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size. $h$ The number of heads in multi-head attention layer. $L$ The segment length of input sequence. $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$ The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size. $\\mathbf{W}^k \\in \\mathbb{R}^{d \\times d_k}$ The key weight matrix. $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_k}$ The query weight matrix. $\\mathbf{W}^v \\in \\mathbb{R}^{d \\times d_v}$ The value weight matrix. Often we have $d_k = d_v = d$. $\\mathbf{W}^k_i, \\mathbf{W}^q_i \\in \\mathbb{R}^{d \\times d_k/h}; \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ The weight matrices per head. $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ The output weight matrix. $\\mathbf{Q} = \\mathbf{X}\\mathbf{W}^q \\in \\mathbb{R}^{L \\times d_k}$ The query embedding inputs. $\\mathbf{K} = \\mathbf{X}\\mathbf{W}^k \\in \\mathbb{R}^{L \\times d_k}$ The key embedding inputs. $\\mathbf{V} = \\mathbf{X}\\mathbf{W}^v \\in \\mathbb{R}^{L \\times d_v}$ The value embedding inputs. $S_i$ A collection of key positions for the $i$-th query $\\mathbf{q}_i$ to attend to. $\\mathbf{A} \\in \\mathbb{R}^{L \\times L}$ The self-attention matrix between a input sequence of lenght $L$ and itself. $\\mathbf{A} = \\text{softmax}(\\mathbf{Q}\\mathbf{K}^\\top / \\sqrt{d_k})$. $a_{ij} \\in \\mathbf{A}$ The scalar attention score between query $\\mathbf{q}_i$ and key $\\mathbf{k}_j$. $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ position encoding matrix, where the $i$-th row $\\mathbf{p}_i$ is the positional encoding for input $\\mathbf{x}_i$. Attention and Self-Attention Attention is a mechanism in the neural network that a model can learn to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.\nSelf-attention is a type of attention mechanism where the model makes prediction for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.\nThere are various forms of attention / self-attention, Transformer (Vaswani et al., 2017) relies on the scaled dot-product attention: given a query matrix $\\mathbf{Q}$, a key matrix $\\mathbf{K}$ and a value matrix $\\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:\n $$ \\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q} {\\mathbf{K}}^\\top}{\\sqrt{d_k}})\\mathbf{V} $$ And for a query and a key vector $\\mathbf{q}_i, \\mathbf{k}_j \\in \\mathbb{R}^d$ (row vectors in query and key matrices), we have a scalar score:\n $$ a_{ij} = \\text{softmax}(\\frac{\\mathbf{q}_i {\\mathbf{k}_j}^\\top}{\\sqrt{d_k}}) = \\frac{\\exp(\\mathbf{q}_i {\\mathbf{k}_j}^\\top)}{ \\sqrt{d_k} \\sum_{r \\in S_i} \\exp(\\mathbf{q}_i {\\mathbf{k}_r}^\\top) } $$ where $S_i$ is a collection of key positions for the $i$-th query to attend to.\nSee my old post for other types of attention if interested.\nMulti-Head Self-Attention The multi-head self-attention module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.\n $$ \\begin{aligned} \\text{MultiHeadAttention}(\\mathbf{X}_q, \\mathbf{X}_k, \\mathbf{X}_v) \u0026= [\\text{head}_1; \\dots; \\text{head}_h] \\mathbf{W}^o \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{X}_v\\mathbf{W}^v_i) \\end{aligned} $$ where $[.;.]$ is a concatenation operation. $\\mathbf{W}^q_i, \\mathbf{W}^k_i \\in \\mathbb{R}^{d \\times d_k/h}, \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ are weight matrices to map input embeddings of size $L \\times d$ into query, key and value matrices. And $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ is the output linear transformation. All the weights should be learned during training.\nFig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in Vaswani, et al., 2017) Transformer The Transformer (which will be referred to as \u0026ldquo;vanilla Transformer\u0026rdquo; to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later simplified Transformer was shown to achieve great performance in language modeling tasks, like in encoder-only BERT or decoder-only GPT.\nEncoder-Decoder Architecture\nThe encoder generates an attention-based representation with capability to locate a specific piece of information from a large context. It consists of a stack of 6 identity modules, each containing two submodules, a multi-head self-attention layer and a point-wise fully connected feed-forward network. By point-wise, it means that it applies the same linear transformation (with same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.\nThe function of Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future.\nFig. 2. The architecture of the vanilla Transformer model. (Image source: Figure 17) Positional Encoding\nBecause self-attention operation is permutation invariant, it is important to use proper positional encodingto provide order information to the model. The positional encoding $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ has the same dimension as the input embedding, so it can be added on the input directly. The vanilla Transformer considered two types of encodings:\n(1) Sinusoidal positional encoding is defined as follows, given the token position $i=1,\\dots,L$ and the dimension $\\delta=1,\\dots,d$:\n $$ \\text{PE}(i,\\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ In this way each dimension of the positional encoding corresponds to a sinusoid of different wavelengths in different dimensions, from $2\\pi$ to $10000 \\cdot 2\\pi$.\nFig. 3. Sinusoidal positional encoding with $L=32$ and $d=128$. The value is between -1 (black) and 1 (white) and the value 0 is in gray. (2) Learned positional encoding, as its name suggested, assigns each element with a learned column vector which encodes its absolute position (Gehring, et al. 2017).\nQuick Follow-ups\nFollowing the vanilla Transformer, Al-Rfou et al. (2018) added a set of auxiliary losses to enable training a deep Transformer model on character-level language modeling which outperformed LSTMs. Several types of auxiliary tasks are used:\n Instead of producing only one prediction at the sequence end, every immediate position is also asked to make a correct prediction, forcing the model to predict given smaller contexts (e.g. first couple tokens at the beginning of a context window). Each intermediate Transformer layer is used for making predictions as well. Lower layers are weighted to contribute less and less to the total loss as training progresses. Each position in the sequence can predict multiple targets, i.e. two or more predictions of the future tokens. Fig. 4. Auxiliary prediction tasks used in deep Transformer for character-level language modeling. (Image source: Al-Rfou et al. (2018)) Adaptive Computation Time (ACT) Adaptive Computation Time (short for ACT; Graves, 2016) is a mechanism for dynamically deciding how many computational steps are needed in a recurrent neural network. Here is a cool tutorial on ACT from distill.pub.\nLet\u0026rsquo;s say, we have a RNN model $\\mathcal{R}$ composed of input weights $W_x$, a parametric state transition function $\\mathcal{S}(.)$, a set of output weights $W_y$ and an output bias $b_y$. Given an input sequence $(x_1, \\dots, x_L)$, the output sequence $(y_1, \\dots, y_L)$ is computed by:\n $$ s_t = \\mathcal{S}(s_{t-1}, W_x x_t), \\quad y_t = W_y s_t + b_y\\quad\\text{for }t=1, \\dots, L $$ ACT enables the above RNN setup to perform a variable number of steps at each input element. Multiple computational steps lead to a sequence of intermediate states $(s_t^1, \\dots, s_t^{N(t)})$ and outputs $(y_t^1, \\dots, y_t^{N(t)})$ \u0026mdash; they all share the same state transition function $\\mathcal{S}(.)$, as well as the same output weights $W_y$ and bias $b_y$:\n $$ \\begin{aligned} s_t^0 \u0026= s_{t-1} \\\\ s_t^n \u0026= \\mathcal{S}(s_{t}^{n-1}, x_t^n) = \\mathcal{S}(s_{t}^{n-1}, x_t + \\delta_{n,1}) \\text{ for } n=1, \\dots, N(t)\\\\ y_t^n \u0026= W_y s_t^n + b_y \\end{aligned} $$ where $\\delta_{n,1}$ is a binary flag indicating whether the input step has been incremented.\nThe number of steps $N(t)$ is determined by an extra sigmoidal halting unit $h$, with associated weight matrix $W_h$ and bias $b_h$, outputting a halting probability $p_t^n$ at immediate step $n$ for $t$-th input element:\n $$ h_t^n = \\sigma(W_h s_t^n + b_h) $$ In order to allow the computation to halt after a single step, ACT introduces a small constant $\\epsilon$ (e.g. 0.01), so that whenever the cumulative probability goes above $1-\\epsilon$, the computation stops.\n $$ \\begin{aligned} N(t) \u0026= \\min(\\min\\{n': \\sum_{n=1}^{n'} h_t^n \\geq 1 -\\epsilon\\}, M) \\\\ p_t^n \u0026= \\begin{cases} h_t^n \u0026 \\text{if }n where $M$ is an upper limit for the number of immediate steps allowed.\nThe final state and output are mean-field updates:\n $$ s_t = \\sum_{n=1}^{N(t)} p_t^n s_t^n,\\quad y_t = \\sum_{n=1}^{N(t)} p_t^n y_t^n $$ Fig. 5. The computation graph of a RNN with ACT mechanism. (Image source: Graves, 2016) To avoid unnecessary pondering over each input, ACT adds a ponder cost $\\mathcal{P}(x) = \\sum_{t=1}^L N(t) + R(t) $ in the loss function to encourage a smaller number of intermediate computational steps.\nImproved Attention Span The goal of improving attention span is to make the context that can be used in self-attention longer, more efficient and flexible.\nLonger Attention Span (Transformer-XL) The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments.\nThis context segmentation causes several issues:\n The model cannot capture very long term dependencies. It is hard to predict the first few tokens in each segment given no or thin context. The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens. Transformer-XL (Dai et al., 2019; \u0026ldquo;XL\u0026rdquo; means \u0026ldquo;extra long\u0026rdquo;) solves the context segmentation problem with two main modifications:\n Reusing hidden states between segments. Adopting a new positional encoding that is suitable for reused states. Hidden State Reuse\nThe recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.\nFig. 6. A comparison between the training phrase of vanilla Transformer \u0026 Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019). Let\u0026rsquo;s label the hidden state of the $n$-th layer for the $(\\tau + 1)$-th segment in the model as $\\mathbf{h}_{\\tau+1}^{(n)} \\in \\mathbb{R}^{L \\times d}$. In addition to the hidden state of the last layer for the same segment $\\mathbf{h}_{\\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment $\\mathbf{h}_{\\tau}^{(n)}$. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.\n $$ \\begin{aligned} \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \u0026= [\\text{stop-gradient}(\\mathbf{h}_{\\tau}^{(n-1)}) \\circ \\mathbf{h}_{\\tau+1}^{(n-1)}] \\\\ \\mathbf{Q}_{\\tau+1}^{(n)} \u0026= \\mathbf{h}_{\\tau+1}^{(n-1)}\\mathbf{W}^q \\\\ \\mathbf{K}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^k \\\\ \\mathbf{V}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^v \\\\ \\mathbf{h}_{\\tau+1}^{(n)} \u0026= \\text{transformer-layer}(\\mathbf{Q}_{\\tau+1}^{(n)}, \\mathbf{K}_{\\tau+1}^{(n)}, \\mathbf{V}_{\\tau+1}^{(n)}) \\end{aligned} $$ Note that both key and value rely on the extended hidden state, while the query only consumes hidden state at current step. The concatenation operation $[. \\circ .]$ is along the sequence length dimension.\nRelative Positional Encoding\nIn order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding. If using the same approach by vanilla Transformer and encoding the absolute position, the previous and current segments will be assigned with the same encoding, which is undesired.\nTo keep the positional information flow coherently across segments, Transformer-XL encodes the relative position instead, as it could be sufficient enough to know the position offset for making good predictions, i.e. $i-j$, between one key vector $\\mathbf{k}_{\\tau, j}$ and its query $\\mathbf{q}_{\\tau, i}$.\nIf omitting the scalar $1/\\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:\n $$ \\begin{aligned} a_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top = (\\mathbf{x}_i + \\mathbf{p}_i)\\mathbf{W}^q ((\\mathbf{x}_j + \\mathbf{p}_j)\\mathbf{W}^k)^\\top \\\\ \u0026= \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top \\end{aligned} $$ Transformer-XL reparameterizes the above four terms as follows:\n $$ a_{ij}^\\text{rel} = \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{content-based addressing} + \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{content-dependent positional bias} + \\underbrace{ \\color{red}{\\mathbf{u}} \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{global content bias} + \\underbrace{ \\color{red}{\\mathbf{v}} \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{global positional bias} $$ Replace $\\mathbf{p}_j$ with relative positional encoding $\\mathbf{r}_{i-j} \\in \\mathbf{R}^{d}$; Replace $\\mathbf{p}_i\\mathbf{W}^q$ with two trainable parameters $\\mathbf{u}$ (for content) and $\\mathbf{v}$ (for location) in two different terms; Split $\\mathbf{W}^k$ into two matrices, $\\mathbf{W}^k_E$ for content information and $\\mathbf{W}^k_R$ for location information. Adaptive Attention Span One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further sometime than others; or one attention head may had different attention pattern from the other. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.\nThis is the motivation for Adaptive Attention Span. Sukhbaatar, et al., (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 7) and thus the optimal span would be trained separately per head.\nFig. 7. Two attention heads in the same model, A \u0026 B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B look further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019) Given the $i$-th token, we need to compute the attention weights between this token and other keys at positions $j \\in S_i$, where $S_i$ defineds the $i$-th token\u0026rsquo;s context window.\n $$ \\begin{aligned} e_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top \\\\ a_{ij} \u0026= \\text{softmax}(e_{ij}) = \\frac{\\exp(e_{ij})}{\\sum_{r=i-s}^{i-1} \\exp(e_{ir})} \\\\ \\mathbf{y}_i \u0026= \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{v}_r = \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{x}_r\\mathbf{W}^v \\end{aligned} $$ A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. $m_z$ is parameterized by $z \\in [0, s]$ and $z$ is to be learned:\n $$ m_z(x) = \\text{clamp}(\\frac{1}{R}(R+z-x), 0, 1) $$ where $R$ is a hyper-parameter which defines the softness of $m_z$.\nFig. 8. The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.) The soft mask function is applied to the softmax elements in the attention weights:\n $$ a_{ij} = \\frac{m_z(i-j)\\exp(s_{ij})}{\\sum_{r=i-s}^{i-1}m_z(i-r) \\exp(s_{ir})} $$ In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i=1, \\dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\\sum_{i=1}^h z^{(i)}$.\nUsing Adaptive Computation Time, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S \\sigma(\\mathbf{v} \\cdot \\mathbf{x}_t +b)$, where the vector $\\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.\nIn the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.\nLocalized Attention Span (Image Transformer) The original, also the most popular, use case for Transformer is to do language modeling. The text sequence is one-dimensional in a clearly defined chronological order and thus the attention span grows linearly with increased context size.\nHowever, if we want to use Transformer on images, it is unclear how to define the scope of context or the order. Image Transformer (Parmer, et al 2018) embraces a formulation of image generation similar to sequence modeling within the Transformer framework. Additionally, Image Transformer restricts the self-attention span to only local neighborhoods, so that the model can scale up to process more images in parallel and keep the likelihood loss tractable.\nThe encoder-decoder architecture remains for image-conditioned generation:\n The encoder generates a contextualized, per-pixel-channel representation of the source image; The decoder autoregressively generates an output image, one channel per pixel at each time step. Let\u0026rsquo;s label the representation of the current pixel to be generated as the query $\\mathbf{q}$. Other positions whose representations will be used for computing $\\mathbf{q}$ are key vector $\\mathbf{k}_1, \\mathbf{k}_2, \\dots$ and they together form a memory matrix $\\mathbf{M}$. The scope of $\\mathbf{M}$ defines the context window for pixel query $\\mathbf{q}$.\nImage Transformer introduced two types of localized $\\mathbf{M}$, as illustrated below.\nFig. 9. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in Parmer et al, 2018) (1) 1D Local Attention: The input image is flattened in the raster scanning order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as $\\mathbf{q}$ and a fixed number of additional pixels generated before this query block.\n(2) 2D Local Attention: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.\nLess Time and Memory Cost This section introduces several improvements made on Transformer to reduce the computation time and memory consumption.\nSparse Attention Matrix Factorization (Sparse Transformers) The compute and memory cost of the vanilla Transformer grows quadratically with sequence length and thus it is hard to be applied on very long sequences.\nSparse Transformer (Child et al., 2019) introduced factorized self-attention, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.\nGiven a set of attention connectivity pattern $\\mathcal{S} = \\{S_1, \\dots, S_n\\}$, where each $S_i$ records a set of key positions that the $i$-th query vector attends to.\n $$ \\begin{aligned} \\text{Attend}(\\mathbf{X}, \\mathcal{S}) \u0026= \\Big( a(\\mathbf{x}_i, S_i) \\Big)_{i \\in \\{1, \\dots, L\\}} \\\\ \\text{ where } a(\\mathbf{x}_i, S_i) \u0026= \\text{softmax}\\Big(\\frac{(\\mathbf{x}_i \\mathbf{W}^q)(\\mathbf{x}_j \\mathbf{W}^k)_{j \\in S_i}^\\top}{\\sqrt{d_k}}\\Big) (\\mathbf{x}_j \\mathbf{W}^v)_{j \\in S_i} \\end{aligned} $$ Note that although the size of $S_i$ is not fixed, $a(\\mathbf{x}_i, S_i)$ is always of size $d_v$ and thus $\\text{Attend}(\\mathbf{X}, \\mathcal{S}) \\in \\mathbb{R}^{L \\times d_v}$.\nIn anto-regressive models, one attention span is defined as $S_i = \\{j: j \\leq i\\}$ as it allows each token to attend to all the positions in the past.\nIn factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \\leq i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.\nPrecisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A^{(m)}_i \\subset S_i, m = 1,\\dots, p$. Therefore the path between the output position $i$ and any $j$ has a maximum length $p + 1$. For example, if $(j, a, b, c, \\dots, i)$ is a path of indices between $i$ and $j$, we would have $j \\in A_a^{(1)}, a \\in A_b^{(2)}, b \\in A_c^{(3)}, \\dots$, so on and so forth.\nSparse Factorized Attention\nSparse Transformer proposed two types of fractorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.\nFig. 10. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few of extra annotations.) (1) Strided attention with stride $\\ell \\sim \\sqrt{n}$. This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\\ell$ pixels in the raster scanning order (naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{ t, t+1, \\dots, i\\} \\text{, where } t = \\max(0, i - \\ell) \\\\ A_i^{(2)} \u0026= \\{j: (i-j) \\mod \\ell = 0\\} \\end{aligned} $$ (2) Fixed attention. A small set of tokens summarize previous locations and propagate that information to all future locations.\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{j: \\lfloor \\frac{j}{\\ell} \\rfloor = \\lfloor \\frac{i}{\\ell} \\rfloor \\} \\\\ A_i^{(2)} \u0026= \\{j: j \\mod \\ell \\in \\{\\ell-c, \\dots, \\ell-1\\} \\} \\end{aligned} $$ where $c$ is a hyperparameter. If $c=1$, it restricts the representation whereas many depend on a few positions. The paper chose $c\\in \\{ 8, 16, 32 \\}$ for $\\ell \\in \\{ 128, 256 \\}$.\nUse Factorized Self-Attention in Transformer\nThere are three ways to use sparse factorized attention patterns in Transformer architecture:\n One attention type per residual block and then interleave them, $\\text{attention}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, A^{(n \\mod p)}) \\mathbf{W}^o$, where $n$ is the index of the current residual block. Set up a single head which attends to locations that all the factorized heads attend to, $\\text{attention}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, \\cup_{m=1}^p A^{(m)}) \\mathbf{W}^o $. Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt a pattern presented above, 1 or 2. =\u0026gt; This option often performs the best. Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention \u0026amp; FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the paper for more details.\nLocality-Sensitive Hashing (Reformer) The improvements proposed by the Reformer model (Kitaev, et al. 2020) aim to solve the following pain points in Transformer:\n Memory in a model with $N$ layers is $N$-times larger than in a single-layer model because we need to store activations for back-propagation. The intermediate FF layers are often quite large. The attention matrix on sequences of length $L$ often requires $O(L^2)$ in both memory and time. Reformer proposed two main changes:\n Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $O(L^2)$ to $O(L\\log L)$. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers). Locality-Sensitive Hashing Attention\nIn $\\mathbf{Q} \\mathbf{K}^\\top$ part of the attention formula, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query $\\mathbf{q}_i \\in \\mathbf{Q}$, we are looking for row vectors in $\\mathbf{K}$ closest to $\\mathbf{q}_i$. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.\nA hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. The Reformer adopts a hashing scheme as such, given a fixed random matrix $\\mathbf{R} \\in \\mathbb{R}^{d \\times b/2}$ (where $b$ is a hyperparam), the hash function is $h(x) = \\arg\\max([xR; −xR])$.\n$$ \\mathbf{o}_i = \\sum_{j \\in S_i} \\exp(\\mathbf{q}_i \\cdot \\mathbf{k}_j - Z(i, S_i)) \\mathbf{v}_j \\text{, where } S_i = \\{j: j \\leq i\\} $$ -- Fig. 11. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in Kitaev, et al. 2020). In LSH attention, a query can only attend to positions in the same hashing bucket, $S_i = \\{j: h(\\mathbf{q}_i) = h(\\mathbf{k}_j)\\}$. It is carried out in the following process, as illustrated in Fig. 11:\n (a) The attention matrix for full attention is often sparse. (b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets. (c) Set $\\mathbf{Q} = \\mathbf{K}$ (precisely $\\mathbf{k}_j = \\mathbf{q}_j / |\\mathbf{q}_j|$), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this \u0026ldquo;shared-QK\u0026rdquo; config does not affect the performance of the Transformer. (d) Apply batching where chunks of $m$ consecutive queries are grouped together. Fig. 12. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020). Reversible Residual Network\nAnother improvement by Reformer is to use reversible residual layers (Gomez et al. 2017). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.\nGiven a layer $x \\mapsto y$, the normal residual layer does $y = x + F(x)$, but the reversible layer splits both input and output into pairs $(x_1, x_2) \\mapsto (y_1, y_2)$ and then executes the following:\n $$ y_1 = x_1 + F(x_2),\\; y_2 = x_2 + G(y_1) $$ and reversing is easy:\n $$ x_2 = y_2 - G(y_1), \\; x_1 = y_1 − F(x_2) $$ Reformer applies the same idea to Transformer by combination attention ($F$) and feed-forward layers ($G$) within a reversible net block:\n $$ Y_1 = X_1 + \\text{Attention}(X_2), \\; Y_2 = X_2 + \\text{FeedForward}(Y_1) $$ The memory can be further reduced by chunking the feed-forward computation:\n $$ Y_2 = [Y_2^{(1)}; \\dots; Y_2^{(c)}] = [X_2^{(1)} + \\text{FeedForward}(Y_1^{(1)}); \\dots; X_2^{(c)} + \\text{FeedForward}(Y_1^{(c)})] $$ The resulting reversible Transformer does not need to store activation in every layer.\nMake it Recurrent (Universal Transformer) The Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN.\nRather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, an Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.\nOn a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.\nFig. 13. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019). Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\\mathbf{H}^t \\in \\mathbb{R}^{L \\times d}$ at step $t$ for an adjustable number of steps. At step 0, $\\mathbf{H}^0$ is initialized to be same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.\n $$ \\begin{aligned} \\mathbf{A}^t \u0026= \\text{LayerNorm}(\\mathbf{H}^{t-1} + \\text{MultiHeadAttention}(\\mathbf{H}^{t-1} + \\mathbf{P}^t) \\\\ \\mathbf{H}^t \u0026= \\text{LayerNorm}(\\mathbf{A}^{t-1} + \\text{Transition}(\\mathbf{A}^t)) \\end{aligned} $$ where $\\text{Transition}(.)$ is either a separable convolution or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of $\\mathbf{A}^t$ individually) affine transformation + one ReLU.\nThe positional encoding $\\mathbf{P}^t$ uses sinusoidal position signal but with an additional time dimension:\n $$ \\text{PE}(i, t, \\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\sin(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\cos(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ Fig. 14. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation $\\mathbf{H}^T$. (Image source: Figure 2 in Dehghani, et al. 2019) In the adaptive version of Universal Transformer, the number of recurrent steps $T$ is dynamically determined by ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.\nStabilization for RL (GTrXL) The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can for sure benefit from these traits. However, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. It could be quite challenging to stabilize and train a LSTM agent by itself, after all.\nThe Gated Transformer-XL (GTrXL; Parisotto, et al. 2019) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:\n The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism. $$ \\begin{aligned} r \u0026= \\sigma(W_r^{(l)} y + U_r^{(l)} x) \\\\ z \u0026= \\sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\\\ \\hat{h} \u0026= \\tanh(W_g^{(l)} y + U_g^{(l)} (r \\odot x)) \\\\ g^{(l)}(x, y) \u0026= (1-z)\\odot x + z\\odot \\hat{h} \\end{aligned} $$ The gating function parameters are explicitly initialized to be close to an identity map - this is why there is a $b_g$ term. A $b_g \u0026gt; 0$ greatly helps with the learning speedup.\nFig. 15. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in Parisotto, et al. 2019) Citation Cited as:\n Weng, Lilian. (Apr 2020). The transformer family. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/.\n Or\n@article{weng2020transformer, title = \u0026quot;The Transformer Family\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Apr\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/\u0026quot; } Reference [1] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[2] Rami Al-Rfou, et al. \u0026ldquo;Character-level language modeling with deeper self-attention.\u0026quot; AAAI 2019.\n[3] Olah \u0026amp; Carter, \u0026ldquo;Attention and Augmented Recurrent Neural Networks\u0026rdquo;, Distill, 2016.\n[4] Sainbayar Sukhbaatar, et al. \u0026ldquo;Adaptive Attention Span in Transformers\u0026rdquo;. ACL 2019.\n[5] Rewon Child, et al. \u0026ldquo;Generating Long Sequences with Sparse Transformers\u0026rdquo; arXiv:1904.10509 (2019).\n[6] Nikita Kitaev, et al. \u0026ldquo;Reformer: The Efficient Transformer\u0026rdquo; ICLR 2020.\n[7] Alex Graves. (\u0026ldquo;Adaptive Computation Time for Recurrent Neural Networks\u0026rdquo;)[https://arxiv.org/abs/1603.08983]\n[8] Niki Parmar, et al. \u0026ldquo;Image Transformer\u0026rdquo; ICML 2018.\n[9] Zihang Dai, et al. \u0026ldquo;Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.\u0026quot; ACL 2019.\n[10] Aidan N. Gomez, et al. \u0026ldquo;The Reversible Residual Network: Backpropagation Without Storing Activations\u0026rdquo; NIPS 2017.\n[11] Mostafa Dehghani, et al. \u0026ldquo;Universal Transformers\u0026rdquo; ICLR 2019.\n[12] Emilio Parisotto, et al. \u0026ldquo;Stabilizing Transformers for Reinforcement Learning\u0026rdquo; arXiv:1910.06764 (2019).\n","permalink":"https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/","summary":"[Updated on 2023-01-27: After almost three years, I did a big refactoring update of this post to incorporate a bunch of new Transformer models since 2020. The enhanced version of this post is here: The Transformer Family Version 2.0. Please refer to that post on this topic.] \nIt has been almost two years since my last post on attention. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.","title":"The Transformer Family"},{"content":"[Updated on 2020-02-03: mentioning PCG in the \u0026ldquo;Task-Specific Curriculum\u0026rdquo; section. [Updated on 2020-02-04: Add a new \u0026ldquo;curriculum through distillation\u0026rdquo; section.\nIt sounds like an impossible task if we want to teach integral or derivative to a 3-year-old who does not even know basic arithmetics. That\u0026rsquo;s why education is important, as it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard. A curriculum makes learning difficult things easier and approachable for us humans. But, how about machine learning models? Can we train our models more efficiently with a curriculum? Can we design a curriculum to speed up learning?\nBack in 1993, Jeffrey Elman has proposed the idea of training neural networks with a curriculum. His early work on learning simple language grammar demonstrated the importance of such a strategy: starting with a restricted set of simple data and gradually increasing the complexity of training samples; otherwise the model was not able to learn at all.\nCompared to training without a curriculum, we would expect the adoption of the curriculum to expedite the speed of convergence and may or may not improve the final model performance. To design an efficient and effective curriculum is not easy. Keep in mind that, a bad curriculum may even hamper learning.\nNext, we will look into several categories of curriculum learning, as illustrated in Fig. 1. Most cases are applied to Reinforcement Learning, with a few exceptions on Supervised Learning.\nFig. 1. Five types of curriculum for reinforcement learning. In \u0026ldquo;The importance of starting small\u0026rdquo; paper (Elman 1993), I especially like the starting sentences and find them both inspiring and affecting:\n \u0026ldquo;Humans differ from other species along many dimensions, but two are particularly noteworthy. Humans display an exceptional capacity to learn; and humans are remarkable for the unusually long time it takes to reach maturity. The adaptive advantage of learning is clear, and it may be argued that, through culture, learning has created the basis for a non-genetically based transmission of behaviors which may accelerate the evolution of our species.\u0026rdquo;\n Indeed, learning is probably the best superpower we humans have.\nTask-Specific Curriculum Bengio, et al. (2009) provided a good overview of curriculum learning in the old days. The paper presented two ideas with toy experiments using a manually designed task-specific curriculum:\n Cleaner Examples may yield better generalization faster. Introducing gradually more difficult examples speeds up online training. It is plausible that some curriculum strategies could be useless or even harmful. A good question to answer in the field is: What could be the general principles that make some curriculum strategies work better than others? The Bengio 2009 paper hypothesized it would be beneficial to make learning focus on \u0026ldquo;interesting\u0026rdquo; examples that are neither too hard or too easy.\nIf our naive curriculum is to train the model on samples with a gradually increasing level of complexity, we need a way to quantify the difficulty of a task first. One idea is to use its minimal loss with respect to another model while this model is pretrained on other tasks (Weinshall, et al. 2018). In this way, the knowledge of the pretrained model can be transferred to the new model by suggesting a rank of training samples. Fig. 2 shows the effectiveness of the curriculum group (green), compared to control (random order; yellow) and anti (reverse the order; red) groups.\nFig. 2. Image classification accuracy on test image set (5 member classes of \"small mammals\" in CIFAR100). There are 4 experimental groups, (a) `curriculum`: sort the labels by the confidence of another trained classifier (e.g. the margin of an SVM); (b) `control-curriculum`: sort the labels randomly; (c) `anti-curriculum`: sort the labels reversely; (d) `None`: no curriculum. (Image source: Weinshall, et al. 2018) Zaremba \u0026amp; Sutskever (2014) did an interesting experiment on training LSTM to predict the output of a short Python program for mathematical ops without actually executing the code. They found curriculum is necessary for learning. The program\u0026rsquo;s complexity is controlled by two parameters, length ∈ [1, a] and nesting∈ [1, b]. Three strategies are considered:\n Naive curriculum: increase length first until reaching a; then increase nesting and reset length to 1; repeat this process until both reach maximum. Mix curriculum: sample length ~ [1, a] and nesting ~ [1, b] Combined: naive + mix. They noticed that combined strategy always outperformed the naive curriculum and would generally (but not always) outperform the mix strategy \u0026mdash; indicating that it is quite important to mix in easy tasks during training to avoid forgetting.\nProcedural content generation (PCG) is a popular approach for creating video games of various levels of difficulty. PCG involves algorithmic randomness and a heavy dose of human expertise in designing game elements and dependencies among them. Procedurally generated levels have been introduced into several benchmark environments for evaluating whether an RL agent can generalize to a new level that it is not trained on (meta-RL!), such as GVGAI, OpenAI CoinRun and Procgen benchmark. Using GVGAI, Justesen, et al. (2018) demonstrated that an RL policy can easily overfit to a specific game but training over a simple curriculum that grows the task difficulty together with the model performance helps its generalization to new human-designed levels. Similar results are also found in CoinRun (Cobbe, et al. 2018). POET (Wang et al, 2019) is another example for leveraging evolutionary algorithm and procedural generated game levels to improve RL generalization, which I\u0026rsquo;ve described in details in my meta-RL post.\nTo follow the curriculum learning approaches described above, generally we need to figure out two problems in the training procedure:\n Design a metric to quantify how hard a task is so that we can sort tasks accordingly. Provide a sequence of tasks with an increasing level of difficulty to the model during training. However, the order of tasks does not have to be sequential. In our Rubik\u0026rsquo;s cube paper (OpenAI et al, 2019), we depended on Automatic domain randomization (ADR) to generate a curriculum by growing a distribution of environments with increasing complexity. The difficulty of each task (i.e. solving a Rubik\u0026rsquo;s cube in a set of environments) depends on the randomization ranges of various environmental parameters. Even with a simplified assumption that all the environmental parameters are uncorrelated, we were able to create a decent curriculum for our robot hand to learn the task.\nTeacher-Guided Curriculum The idea of Automatic Curriculum Learning was proposed by Graves, et al. 2017 slightly earlier. It considers a $N$-task curriculum as an $N$-armed bandit problem and an adaptive policy which learns to optimize the returns from this bandit.\nTwo categories of learning signals have been considered in the paper:\n Loss-driven progress: the loss function change before and after one gradient update. This type of reward signals tracks the speed of the learning process, because the greatest task loss decrease is equivalent to the fastest learning. Complex-driven progress: the KL divergence between posterior and prior distribution over network weights. This type of learning signals are inspired by the MDL principle, \u0026ldquo;increasing the model complexity by a certain amount is only worthwhile if it compresses the data by a greater amount\u0026rdquo;. The model complexity is therefore expected to increase most in response to the model nicely generalizing to training examples. This framework of proposing curriculum automatically through another RL agent was formalized as Teacher-Student Curriculum Learning (TSCL; Matiisen, et al. 2017). In TSCL, a student is an RL agent working on actual tasks while a teacher agent is a policy for selecting tasks. The student aims to master a complex task that might be hard to learn directly. To make this task easier to learn, we set up the teacher agent to guide the student\u0026rsquo;s training process by picking proper sub-tasks.\nFig. 3. The setup of teacher-student curriculum learning. (Image source: Matiisen, et al. 2017 + my annotation in red.) In the process, the student should learn tasks which:\n can help the student make fastest learning progress, or are at risk of being forgotten. Note: The setup of framing the teacher model as an RL problem feels quite similar to Neural Architecture Search (NAS), but differently the RL model in TSCL operates on the task space and NAS operates on the main model architecture space.\n Training the teacher model is to solve a POMDP problem:\n The unobserved $s_t$ is the full state of the student model. The observed $o = (x_t^{(1)}, \\dots, x_t^{(N)})$ are a list of scores for $N$ tasks. The action $a$ is to pick on subtask. The reward per step is the score delta.$r_t = \\sum_{i=1}^N x_t^{(i)} - x_{t-1}^{(i)}$ (i.e., equivalent to maximizing the score of all tasks at the end of the episode). The method of estimating learning progress from noisy task scores while balancing exploration vs exploitation can be borrowed from the non-stationary multi-armed bandit problem \u0026mdash; use ε-greedy, or Thompson sampling.\nThe core idea, in summary, is to use one policy to propose tasks for another policy to learn better. Interestingly, both works above (in the discrete task space) found that uniformly sampling from all tasks is a surprisingly strong benchmark.\nWhat if the task space is continuous? Portelas, et al. (2019) studied a continuous teacher-student framework, where the teacher has to sample parameters from continuous task space to generate a learning curriculum. Given a newly sampled parameter $p$, the absolute learning progress (short for ALP) is measured as $\\text{ALP}_p = \\vert r - r_\\text{old} \\vert$, where $r$ is the episodic reward associated with $p$ and $r_\\text{old}$ is the reward associated with $p_\\text{old}$. Here, $p_\\text{old}$ is a previous sampled parameter closest to $p$ in the task space, which can be retrieved by nearest neighbor. Note that how this ALP score is different from learning signals in TSCL or Grave, et al. 2017 above: ALP score measures the reward difference between two tasks rather than performance at two time steps of the same task.\nOn top of the task parameter space, a Gaussian mixture model is trained to fit the distribution of $\\text{ALP}_p$ over $p$. ε-greedy is used when sampling the tasks: with some probability, sampling a random task; otherwise sampling proportionally to ALP score from the GMM model.\nFig. 4. The algorithm of ALP-GMM (absolute learning progress Gaussian mixture model). (Image source: Portelas, et al., 2019) Curriculum through Self-Play Different from the teacher-student framework, two agents are doing very different things. The teacher learns to pick a task for the student without any knowledge of the actual task content. What if we want to make both train on the main task directly? How about even make them compete with each other?\nSukhbaatar, et al. (2017) proposed a framework for automatic curriculum learning through asymmetric self-play. Two agents, Alice and Bob, play the same task with different goals: Alice challenges Bob to achieve the same state and Bob attempts to complete it as fast as he can.\nFig. 5. Illustration of the self-play setup when training two agents. The example task is MazeBase: An agent is asked to reach a goal flag in a maze with a light switch, a key and a wall with a door. Toggling the key switch can open or close the door and Turning off the light makes only the glowing light switch available to the agent. (Image source: Sukhbaatar, et al. 2017) Let us consider Alice and Bob as two separate copies for one RL agent trained in the same environment but with different brains. Each of them has independent parameters and loss objective. The self-play-driven training consists of two types of episodes:\n In the self-play episode, Alice alters the state from $s_0$ to $s_t$ and then Bob is asked to return the environment to its original state $s_0$ to get an internal reward. In the target task episode, Bob receives an external reward if he visits the target flag. Note that since B has to repeat the actions between the same pair of $(s_0, s_t)$ of A, this framework only works in reversible or resettable environments.\nAlice should learn to push Bob out of his comfort zone, but not give him impossible tasks. Bob\u0026rsquo;s reward is set as $R_B = -\\gamma t_B$ and Alice\u0026rsquo;s reward is $R_A = \\gamma \\max(0, t_B - t_A)$, where $t_B$ is the total time for B to complete the task, $t_A$ is the time until Alice performs the STOP action and $\\gamma$ is a scalar constant to rescale the reward to be comparable with the external task reward. If B fails a task, $t_B = t_\\max - t_A$. Both policies are goal-conditioned. The losses imply:\n B wants to finish a task asap. A prefers tasks that take more time of B. A does not want to take too many steps when B is failing. In this way, the interaction between Alice and Bob automatically builds a curriculum of increasingly challenging tasks. Meanwhile, as A has done the task herself before proposing the task to B, the task is guaranteed to be solvable.\nThe paradigm of A suggesting tasks and then B solving them does sound similar to the Teacher-Student framework. However, in asymmetric self-play, Alice, who plays a teacher role, also works on the same task to find challenging cases for Bob, rather than optimizes B\u0026rsquo;s learning process explicitly.\nAutomatic Goal Generation Often RL policy needs to be able to perform over a set of tasks. The goal should be carefully chosen so that at every training stage, it would not be too hard or too easy for the current policy. A goal $g \\in \\mathcal{G}$ can be defined as a set of states $S^g$ and a goal is considered as achieved whenever an agent arrives at any of those states.\nThe approach of Generative Goal Learning (Florensa, et al. 2018) relies on a Goal GAN to generate desired goals automatically. In their experiment, the reward is very sparse, just a binary flag for whether a goal is achieved or not and the policy is conditioned on goal,\n $$ \\begin{aligned} \\pi^{*}(a_t\\vert s_t, g) \u0026= \\arg\\max_\\pi \\mathbb{E}_{g\\sim p_g(.)} R^g(\\pi) \\\\ \\text{where }R^g(\\pi) \u0026= \\mathbb{E}_\\pi(.\\mid s_t, g) \\mathbf{1}[\\exists t \\in [1,\\dots, T]: s_t \\in S^g] \\end{aligned} $$ Here $R^g(\\pi)$ is the expected return, also equivalent to the success probability. Given sampled trajectories from the current policy, as long as any state belongs to the goal set, the return will be positive.\nTheir approach iterates through 3 steps until the policy converges:\n Label a set of goals based on whether they are at the appropriate level of difficulty for the current policy. The set of goals at the appropriate level of difficulty are named GOID (short for \u0026ldquo;Goals of Intermediate Difficulty\u0026rdquo;).$\\text{GOID}_i := \\{g : R_\\text{min} \\leq R^g(\\pi_i) \\leq R_\\text{max} \\} \\subseteq G$ Here $R_\\text{min}$ and $R_\\text{max}$ can be interpreted as a minimum and maximum probability of reaching a goal over T time-steps. Train a Goal GAN model using labelled goals from step 1 to produce new goals Use these new goals to train the policy, improving its coverage objective. The Goal GAN generates a curriculum automatically:\n Generator $G(z)$: produces a new goal. =\u0026gt; expected to be a goal uniformly sampled from $GOID$ set. Discriminator $D(g)$: evaluates whether a goal can be achieved. =\u0026gt; expected to tell whether a goal is from $GOID$ set. The Goal GAN is constructed similar to LSGAN (Least-Squared GAN; Mao et al., (2017)), which has better stability of learning compared to vanilla GAN. According to LSGAN, we should minimize the following losses for $D$ and $G$ respectively:\n $$ \\begin{aligned} \\mathcal{L}_\\text{LSGAN}(D) \u0026= \\frac{1}{2} \\mathbb{E}_{g \\sim p_\\text{data}(g)} [ (D(g) - b)^2] + \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - a)^2] \\\\ \\mathcal{L}_\\text{LSGAN}(G) \u0026= \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - c)^2] \\end{aligned} $$ where $a$ is the label for fake data, $b$ for real data, and $c$ is the value that $G$ wants $D$ to believe for fake data. In LSGAN paper\u0026rsquo;s experiments, they used $a=-1, b=1, c=0$.\nThe Goal GAN introduces an extra binary flag $y_b$ indicating whether a goal $g$ is real ($y_g = 1$) or fake ($y_g = 0$) so that the model can use negative samples for training:\n $$ \\begin{aligned} \\mathcal{L}_\\text{GoalGAN}(D) \u0026= \\frac{1}{2} \\mathbb{E}_{g \\sim p_\\text{data}(g)} [ (D(g) - b)^2 + (1-y_g) (D(g) - a)^2] + \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - a)^2] \\\\ \\mathcal{L}_\\text{GoalGAN}(G) \u0026= \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - c)^2] \\end{aligned} $$ Fig. 6. The algorithm of Generative Goal Learning. (Image source: (Florensa, et al. 2018) Following the same idea, Racaniere \u0026amp; Lampinen, et al. (2019) designs a method to make the objectives of goal generator more sophisticated. Their method contains three components, same as generative goal learning above:\n Solver/Policy $\\pi$: In each episode, the solver gets a goal $g$ at the beginning and get a single binary reward $R^g$ at the end. Judge/Discriminator $D(.)$: A classifier to predict the binary reward (whether goal can be achieved or not); precisely it outputs the logit of a probability of achieving the given goal, $\\sigma(D(g)) = p(R^g=1\\vert g)$, where $\\sigma$ is the sigmoid function. Setter/Generator $G(.)$: The goal setter takes as input a desired feasibility score $f \\in \\text{Unif}(0, 1)$ and generates $g = G(z, f)$, where the latent variable $z$ is sampled by $z \\sim \\mathcal{N}(0, I)$. The goal generator is designed to reversible, so $G^{-1}$ can map backwards from a goal $g$ to a latent $z = G^{-1}(g, f)$ The generator is optimized with three objectives:\n Goal validity: The proposed goal should be achievable by an expert policy. The corresponding generative loss is designed to increase the likelihood of generating goals that the solver policy has achieved before (like in HER). $\\mathcal{L}_\\text{val}$ is the negative log-likelihood of generated goals that have been solved by the solver in the past. $$ \\begin{align*} \\mathcal{L}_\\text{val} = \\mathbb{E}_{\\substack{ g \\sim \\text{ achieved by solver}, \\\\ \\xi \\in \\text{Uniform}(0, \\delta), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ -\\log p(G^{-1}(g + \\xi, f)) \\big] \\end{align*} $$ Goal feasibility: The proposed goal should be achievable by the current policy; that is, the level of difficulty should be appropriate. $\\mathcal{L}_\\text{feas}$ is the output probability by the judge model $D$ on the generated goal $G(z, f)$ should match the desired $f$. $$ \\begin{align*} \\mathcal{L}_\\text{feas} = \\mathbb{E}_{\\substack{ z \\in \\mathcal{N}(0, 1), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ D(G(z, f)) - \\sigma^{-1}(f)^2 \\big] \\end{align*} $$ Goal coverage: We should maximize the entropy of generated goals to encourage diverse goal and to improve the coverage over the goal space. $$ \\begin{align*} \\mathcal{L}_\\text{cov} = \\mathbb{E}_{\\substack{ z \\in \\mathcal{N}(0, 1), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ \\log p(G(z, f)) \\big] \\end{align*} $$ Their experiments showed complex environments require all three losses above. When the environment is changing between episodes, both the goal generator and the discriminator need to be conditioned on environmental observation to produce better results. If there is a desired goal distribution, an additional loss can be added to match a desired goal distribution using Wasserstein distance. Using this loss, the generator can push the solver toward mastering the desired tasks more efficiently.\nFig. 7. Training schematic for the (a) solver/policy, (b) judge/discriminator, and (c) setter/goal generator models. (Image source: Racaniere \u0026 Lampinen, et al., 2019) Skill-Based Curriculum Another view is to decompose what an agent is able to complete into a variety of skills and each skill set could be mapped into a task. Let\u0026rsquo;s imagine when an agent interacts with the environment in an unsupervised manner, is there a way to discover useful skills from such interaction and further build into the solutions for more complicated tasks through a curriculum?\nJabri, et al. (2019) developed an automatic curriculum, CARML (short for \u0026ldquo;Curricula for Unsupervised Meta-Reinforcement Learning\u0026rdquo;), by modeling unsupervised trajectories into a latent skill space, with a focus on training meta-RL policies (i.e. can transfer to unseen tasks). The setting of training environments in CARML is similar to DIAYN. Differently, CARML is trained on pixel-level observations but DIAYN operates on the true state space. An RL algorithm $\\pi_\\theta$, parameterized by $\\theta$, is trained via unsupervised interaction formulated as a CMP combined with a learned reward function $r$. This setting naturally works for the meta-learning purpose, since a customized reward function can be given only at the test time.\nFig. 8. An illustration of CARML, containing two steps: (1) organizing experiential data into the latent skill space; (2) meta-training the policy with the reward function constructed from the learned skills. (Image source: Jabri, et al 2019) CARML is framed as a variational Expectation-Maximization (EM).\n(1) E-Step: This is the stage for organizing experiential data. Collected trajectories are modeled with a mixture of latent components forming the basis of skills.\nLet $z$ be a latent task variable and $q_\\phi$ be a variational distribution of $z$, which could be a mixture model with discrete $z$ or a VAE with continuous $z$. A variational posterior $q_\\phi(z \\vert s)$ works like a classifier, predicting a skill given a state, and we would like to maximize $q_\\phi(z \\vert s)$ to discriminate between data produced by different skills as much as possible. In E-step, $q_\\phi$ is fitted to a set of trajectories produced by $\\pi_\\theta$.\nPrecisely, given a trajectory $\\tau = (s_1,\\dots,s_T)$, we would like to find $\\phi$ such that\n $$ \\max_\\phi \\mathbb{E}_{z\\sim q_\\phi(z)} \\big[ \\log q_\\phi(\\tau \\vert z) \\big] = \\max_\\phi \\mathbb{E}_{z\\sim q_\\phi(z)} \\big[ \\sum_{s_i \\in \\tau} \\log q_\\phi(s_i \\vert z) \\big] $$ A simplifying assumption is made here to ignore the order of states in one trajectory.\n(2) M-Step: This is the stage for doing meta-RL training with $\\pi_\\theta$. The learned skill space is considered as a training task distribution. CARML is agnostic to the type of meta-RL algorithm for policy parameter updates.\nGiven a trajectory $\\tau$, it makes sense for the policy to maximize the mutual information between $\\tau$ and $z$, $I(\\tau;z) = H(\\tau) - H(\\tau \\vert z)$, because:\n maximizing $H(\\tau)$ =\u0026gt; diversity in the policy data space; expected to be large. minimizing $H(\\tau \\vert z)$ =\u0026gt; given a certain skill, the behavior should be restricted; expected to be small. Then we have,\n $$ \\begin{aligned} I(\\tau; z) \u0026= \\mathcal{H}(z) - \\mathcal{H}(z \\vert s_1,\\dots, s_T) \\\\ \u0026\\geq \\mathbb{E}_{s \\in \\tau} [\\mathcal{H}(z) - \\mathcal{H}(z\\vert s)] \u0026 \\scriptstyle{\\text{; discard the order of states.}} \\\\ \u0026= \\mathbb{E}_{s \\in \\tau} [\\mathcal{H}(s_t) - \\mathcal{H}(s\\vert z)] \u0026 \\scriptstyle{\\text{; by definition of MI.}} \\\\ \u0026= \\mathbb{E}_{z\\sim q_\\phi(z), s\\sim \\pi_\\theta(s|z)} [\\log q_\\phi(s|z) - \\log \\pi_\\theta(s)] \\\\ \u0026\\approx \\mathbb{E}_{z\\sim q_\\phi(z), s\\sim \\pi_\\theta(s|z)} [\\color{green}{\\log q_\\phi(s|z) - \\log q_\\phi(s)}] \u0026 \\scriptstyle{\\text{; assume learned marginal distr. matches policy.}} \\end{aligned} $$ We can set the reward as $\\log q_\\phi(s \\vert z) - \\log q_\\phi(s)$, as shown in the red part in the equation above. In order to balance between task-specific exploration (as in red below) and latent skill matching (as in blue below) , a parameter $\\lambda \\in [0, 1]$ is added. Each realization of $z \\sim q_\\phi(z)$ induces a reward function $r_z(s)$ (remember that reward + CMP =\u0026gt; MDP) as follows:\n $$ \\begin{aligned} r_z(s) \u0026= \\lambda \\log q_\\phi(s|z) - \\log q_\\phi(s) \\\\ \u0026= \\lambda \\log q_\\phi(s|z) - \\log \\frac{q_\\phi(s|z) q_\\phi(z)}{q_\\phi(z|s)} \\\\ \u0026= \\lambda \\log q_\\phi(s|z) - \\log q_\\phi(s|z) - \\log q_\\phi(z) + \\log q_\\phi(z|s) \\\\ \u0026= (\\lambda - 1) \\log \\color{red}{q_\\phi(s|z)} + \\color{blue}{\\log q_\\phi(z|s)} + C \\end{aligned} $$ Fig. 9. The algorithm of CARML. (Image source: Jabri, et al 2019) Learning a latent skill space can be done in different ways, such as in Hausman, et al. 2018. The goal of their approach is to learn a task-conditioned policy, $\\pi(a \\vert s, t^{(i)})$, where $t^{(i)}$ is from a discrete list of $N$ tasks, $\\mathcal{T} = [t^{(1)}, \\dots, t^{(N)}]$. However, rather than learning $N$ separate solutions, one per task, it would be nice to learn a latent skill space so that each task could be represented in a distribution over skills and thus skills are reused between tasks. The policy is defined as $\\pi_\\theta(a \\vert s,t) = \\int \\pi_\\theta(a \\vert z,s,t) p_\\phi(z \\vert t)\\mathrm{d}z$, where $\\pi_\\theta$ and $p_\\phi$ are policy and embedding networks to learn, respectively. If $z$ is discrete, i.e. drawn from a set of $K$ skills, then the policy becomes a mixture of $K$ sub-policies. The policy training uses SAC and the dependency on $z$ is introduced in the entropy term.\nCurriculum through Distillation [I was thinking of the name of this section for a while, deciding between cloning, inheritance, and distillation. Eventually, I picked distillation because it sounds the coolest B-)]\nThe motivation for the progressive neural network (Rusu et al. 2016) architecture is to efficiently transfer learned skills between different tasks and in the meantime avoid catastrophic forgetting. The curriculum is realized through a set of progressively stacked neural network towers (or \u0026ldquo;columns\u0026rdquo;, as in the paper).\nA progressive network has the following structure:\n It starts with a single column containing $L$ layers of neurons, in which the corresponding activation layers are labelled as $h^{(1)}_i, i=1, \\dots, L$. We first train this single-column network for one task to convergence, achieving parameter config $\\theta^{(1)}$.\n Once switch to the next task, we need to add a new column to adapt to the new context while freezing $\\theta^{(1)}$ to lock down the learned skills from the previous task. The new column has activation layers labelled as $h^{(2)}_i, i=1, \\dots, L$, and parameters $\\theta^{(2)}$.\n Step 2 can be repeated with every new task. The $i$-th layer activation in the $k$-th column depends on the previous activation layers in all the existing columns:\n $$ h^{(k)}_i = f(W^{(k)}_i h^{(k)}_{i-1} + \\sum_{j where $W^{(k)}_i$ is the weight matrix of the layer $i$ in the column $k$; $U_i^{(k:j)}, j \u0026lt; k$ are the weight matrices for projecting the layer $i-1$ of the column $j$ to the layer $i$ of column $k$ ($ j \u0026lt; k $). The above weights matrices should be learned. $f(.)$ is a non-linear activation function by choice.\n Fig. 10. The progressive neural network architecture. (Image source: Rusu, et al. 2017) The paper experimented with Atari games by training a progressive network on multiple games to check whether features learned in one game can transfer to another. That is indeed the case. Though interestingly, learning a high dependency on features in the previous columns does not always indicate good transfer performance on the new task. One hypothesis is that features learned from the old task might introduce biases into the new task, leading to policy getting trapped in a sub-optimal solution. Overall, the progressive network works better than only fine-tuning the top layer and can achieve similar transfer performance as fine-tuning the entire network.\nOne use case for the progressive network is to do sim2real transfer (Rusu, et al. 2017), in which the first column is trained in simulator with a lot of samples and then the additional columns (could be for different real-world tasks) are added and trained with a few real data samples.\nCzarnecki, et al. (2018) proposed another RL training framework, Mix \u0026amp; Match (short for M\u0026amp;M) to provide curriculum through coping knowledge between agents. Given a sequence of agents from simple to complex, $\\pi_1, \\dots, \\pi_K$, each parameterized with some shared weights (e.g. by shared some lower common layers). M\u0026amp;M trains a mixture of agents, but only the final performance of the most complex one $\\pi_K$ matters.\nIn the meantime, M\u0026amp;M learns a categorical distribution $c \\sim \\text{Categorical}(1, \\dots, K \\vert \\alpha)$ with pmf $p(c=i) = \\alpha_i$ probability to pick which policy to use at a given time. The mixed M\u0026amp;M policy is a simple weighted sum: $\\pi_\\text{mm}(a \\vert s) = \\sum_{i=1}^K \\alpha_i \\pi_i(a \\vert s)$. Curriculum learning is realized by dynamically adjusting $\\alpha_i$, from $\\alpha_K=0$ to $\\alpha_K=1$. The tuning of $\\alpha$ can be manual or through population-based training.\nTo encourage cooperation rather than competition among policies, besides the RL loss $\\mathcal{L}_\\text{RL}$, another distillation-like loss $\\mathcal{L}_\\text{mm}(\\theta)$ is added. The knowledge transfer loss $\\mathcal{L}_\\text{mm}(\\theta)$ measures the KL divergence between two policies, $\\propto D_\\text{KL}(\\pi_{i}(. \\vert s) | \\pi_j(. \\vert s))$ for $i \u0026lt; j$. It encourages complex agents to match the simpler ones early on. The final loss is $\\mathcal{L} = \\mathcal{L}_\\text{RL}(\\theta \\vert \\pi_\\text{mm}) + \\lambda \\mathcal{L}_\\text{mm}(\\theta)$.\nFig. 11. The Mix \u0026 Match architecture for training a mixture of policies. (Image source: Czarnecki, et al., 2018) Citation Cited as:\n Weng, Lilian. (Jan 2020). Curriculum for reinforcement learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/.\n Or\n@article{weng2020curriculum, title = \u0026quot;Curriculum for Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/\u0026quot; } References [1] Jeffrey L. Elman. \u0026ldquo;Learning and development in neural networks: The importance of starting small.\u0026quot; Cognition 48.1 (1993): 71-99.\n[2] Yoshua Bengio, et al. \u0026ldquo;Curriculum learning.\u0026quot; ICML 2009.\n[3] Daphna Weinshall, Gad Cohen, and Dan Amir. \u0026ldquo;Curriculum learning by transfer learning: Theory and experiments with deep networks.\u0026quot; ICML 2018.\n[4] Wojciech Zaremba and Ilya Sutskever. \u0026ldquo;Learning to execute.\u0026quot; arXiv preprint arXiv:1410.4615 (2014).\n[5] Tambet Matiisen, et al. \u0026ldquo;Teacher-student curriculum learning.\u0026quot; IEEE Trans. on neural networks and learning systems (2017).\n[6] Alex Graves, et al. \u0026ldquo;Automated curriculum learning for neural networks.\u0026quot; ICML 2017.\n[7] Remy Portelas, et al. Teacher algorithms for curriculum learning of Deep RL in continuously parameterized environments. CoRL 2019.\n[8] Sainbayar Sukhbaatar, et al. \u0026ldquo;Intrinsic Motivation and Automatic Curricula via Asymmetric Self-Play.\u0026quot; ICLR 2018.\n[9] Carlos Florensa, et al. \u0026ldquo;Automatic Goal Generation for Reinforcement Learning Agents\u0026rdquo; ICML 2019.\n[10] Sebastien Racaniere \u0026amp; Andrew K. Lampinen, et al. \u0026ldquo;Automated Curriculum through Setter-Solver Interactions\u0026rdquo; ICLR 2020.\n[11] Allan Jabri, et al. \u0026ldquo;Unsupervised Curricula for Visual Meta-Reinforcement Learning\u0026rdquo; NeuriPS 2019.\n[12] Karol Hausman, et al. \u0026ldquo;Learning an Embedding Space for Transferable Robot Skills \u0026ldquo; ICLR 2018.\n[13] Josh Merel, et al. \u0026ldquo;Reusable neural skill embeddings for vision-guided whole body movement and object manipulation\u0026rdquo; arXiv preprint arXiv:1911.06636 (2019).\n[14] OpenAI, et al. \u0026ldquo;Solving Rubik\u0026rsquo;s Cube with a Robot Hand.\u0026quot; arXiv preprint arXiv:1910.07113 (2019).\n[15] Niels Justesen, et al. \u0026ldquo;Illuminating Generalization in Deep Reinforcement Learning through Procedural Level Generation\u0026rdquo; NeurIPS 2018 Deep RL Workshop.\n[16] Karl Cobbe, et al. \u0026ldquo;Quantifying Generalization in Reinforcement Learning\u0026rdquo; arXiv preprint arXiv:1812.02341 (2018).\n[17] Andrei A. Rusu et al. \u0026ldquo;Progressive Neural Networks\u0026rdquo; arXiv preprint arXiv:1606.04671 (2016).\n[18] Andrei A. Rusu et al. \u0026ldquo;Sim-to-Real Robot Learning from Pixels with Progressive Nets.\u0026quot; CoRL 2017.\n[19] Wojciech Marian Czarnecki, et al. \u0026ldquo;Mix \u0026amp; Match – Agent Curricula for Reinforcement Learning.\u0026quot; ICML 2018.\n","permalink":"https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/","summary":"[Updated on 2020-02-03: mentioning PCG in the \u0026ldquo;Task-Specific Curriculum\u0026rdquo; section. [Updated on 2020-02-04: Add a new \u0026ldquo;curriculum through distillation\u0026rdquo; section.\nIt sounds like an impossible task if we want to teach integral or derivative to a 3-year-old who does not even know basic arithmetics. That\u0026rsquo;s why education is important, as it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard.","title":"Curriculum for Reinforcement Learning"},{"content":"[Updated on 2020-01-09: add a new section on Contrastive Predictive Coding]. [Updated on 2020-04-13: add a \u0026ldquo;Momentum Contrast\u0026rdquo; section on MoCo, SimCLR and CURL.] [Updated on 2020-07-08: add a \u0026ldquo;Bisimulation\u0026rdquo; section on DeepMDP and DBC.] [Updated on 2020-09-12: add MoCo V2 and BYOL in the \u0026ldquo;Momentum Contrast\u0026rdquo; section.] [Updated on 2021-05-31: remove section on \u0026ldquo;Momentum Contrast\u0026rdquo; and add a pointer to a full post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;]\nGiven a task and enough labels, supervised learning can solve it really well. Good performance usually requires a decent amount of labels, but collecting manual labels is expensive (i.e. ImageNet) and hard to be scaled up. Considering the amount of unlabelled data (e.g. free text, all the images on the Internet) is substantially more than a limited number of human curated labelled datasets, it is kinda wasteful not to use them. However, unsupervised learning is not easy and usually works much less efficiently than supervised learning.\nWhat if we can get labels for free for unlabelled data and train unsupervised dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form to predict only a subset of information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as self-supervised learning.\nThis idea has been widely used in language modeling. The default task for a language model is to predict the next word given the past sequence. BERT adds two other auxiliary tasks and both rely on self-generated labels.\nFig. 1. A great summary of how self-supervised learning tasks can be constructed (Image source: LeCun’s talk) Here is a nicely curated list of papers in self-supervised learning. Please check it out if you are interested in reading more in depth.\nNote that this post does not focus on either NLP / language modeling or generative modeling.\nWhy Self-Supervised Learning? Self-supervised learning empowers us to exploit a variety of labels that come with the data for free. The motivation is quite straightforward. Producing a dataset with clean labels is expensive but unlabeled data is being generated all the time. To make use of this much larger amount of unlabeled data, one way is to set the learning objectives properly so as to get supervision from the data itself.\nThe self-supervised task, also known as pretext task, guides us to a supervised loss function. However, we usually don’t care about the final performance of this invented task. Rather we are interested in the learned intermediate representation with the expectation that this representation can carry good semantic or structural meanings and can be beneficial to a variety of practical downstream tasks.\nFor example, we might rotate images at random and train a model to predict how each input image is rotated. The rotation prediction task is made-up, so the actual accuracy is unimportant, like how we treat auxiliary tasks. But we expect the model to learn high-quality latent variables for real-world tasks, such as constructing an object recognition classifier with very few labeled samples.\nBroadly speaking, all the generative models can be considered as self-supervised, but with different goals: Generative models focus on creating diverse and realistic images, while self-supervised representation learning care about producing good features generally helpful for many tasks. Generative modeling is not the focus of this post, but feel free to check my previous posts.\nImages-Based Many ideas have been proposed for self-supervised representation learning on images. A common workflow is to train a model on one or multiple pretext tasks with unlabelled images and then use one intermediate feature layer of this model to feed a multinomial logistic regression classifier on ImageNet classification. The final classification accuracy quantifies how good the learned representation is.\nRecently, some researchers proposed to train supervised learning on labelled data and self-supervised pretext tasks on unlabelled data simultaneously with shared weights, like in Zhai et al, 2019 and Sun et al, 2019.\nDistortion We expect small distortion on an image does not modify its original semantic meaning or geometric forms. Slightly distorted images are considered the same as original and thus the learned features are expected to be invariant to distortion.\nExemplar-CNN (Dosovitskiy et al., 2015) create surrogate training datasets with unlabeled image patches:\n Sample $N$ patches of size 32 × 32 pixels from different images at varying positions and scales, only from regions containing considerable gradients as those areas cover edges and tend to contain objects or parts of objects. They are \u0026ldquo;exemplary\u0026rdquo; patches. Each patch is distorted by applying a variety of random transformations (i.e., translation, rotation, scaling, etc.). All the resulting distorted patches are considered to belong to the same surrogate class. The pretext task is to discriminate between a set of surrogate classes. We can arbitrarily create as many surrogate classes as we want. Fig. 2. The original patch of a cute deer is in the top left corner. Random transformations are applied, resulting in a variety of distorted patches. All of them should be classified into the same class in the pretext task. (Image source: Dosovitskiy et al., 2015) Rotation of an entire image (Gidaris et al. 2018 is another interesting and cheap way to modify an input image while the semantic content stays unchanged. Each input image is first rotated by a multiple of $90^\\circ$ at random, corresponding to $[0^\\circ, 90^\\circ, 180^\\circ, 270^\\circ]$. The model is trained to predict which rotation has been applied, thus a 4-class classification problem.\nIn order to identify the same image with different rotations, the model has to learn to recognize high level object parts, such as heads, noses, and eyes, and the relative positions of these parts, rather than local patterns. This pretext task drives the model to learn semantic concepts of objects in this way.\nFig. 3. Illustration of self-supervised learning by rotating the entire input images. The model learns to predict which rotation is applied. (Image source: Gidaris et al. 2018) Patches The second category of self-supervised learning tasks extract multiple patches from one image and ask the model to predict the relationship between these patches.\nDoersch et al. (2015) formulates the pretext task as predicting the relative position between two random patches from one image. A model needs to understand the spatial context of objects in order to tell the relative position between parts.\nThe training patches are sampled in the following way:\n Randomly sample the first patch without any reference to image content. Considering that the first patch is placed in the middle of a 3x3 grid, and the second patch is sampled from its 8 neighboring locations around it. To avoid the model only catching low-level trivial signals, such as connecting a straight line across boundary or matching local patterns, additional noise is introduced by: Add gaps between patches Small jitters Randomly downsample some patches to as little as 100 total pixels, and then upsampling it, to build robustness to pixelation. Shift green and magenta toward gray or randomly drop 2 of 3 color channels (See \u0026ldquo;chromatic aberration\u0026rdquo; below) The model is trained to predict which one of 8 neighboring locations the second patch is selected from, a classification problem over 8 classes. Fig. 4. Illustration of self-supervised learning by predicting the relative position of two random patches. (Image source: Doersch et al., 2015) Other than trivial signals like boundary patterns or textures continuing, another interesting and a bit surprising trivial solution was found, called \u0026ldquo;chromatic aberration\u0026rdquo;. It is triggered by different focal lengths of lights at different wavelengths passing through the lens. In the process, there might exist small offsets between color channels. Hence, the model can learn to tell the relative position by simply comparing how green and magenta are separated differently in two patches. This is a trivial solution and has nothing to do with the image content. Pre-processing images by shifting green and magenta toward gray or randomly dropping 2 of 3 color channels can avoid this trivial solution.\nFig. 5. Illustration of how chromatic aberration happens. (Image source: wikipedia) Since we have already set up a 3x3 grid in each image in the above task, why not use all of 9 patches rather than only 2 to make the task more difficult? Following this idea, Noroozi \u0026amp; Favaro (2016) designed a jigsaw puzzle game as pretext task: The model is trained to place 9 shuffled patches back to the original locations.\nA convolutional network processes each patch independently with shared weights and outputs a probability vector per patch index out of a predefined set of permutations. To control the difficulty of jigsaw puzzles, the paper proposed to shuffle patches according to a predefined permutation set and configured the model to predict a probability vector over all the indices in the set.\nBecause how the input patches are shuffled does not alter the correct order to predict. A potential improvement to speed up training is to use permutation-invariant graph convolutional network (GCN) so that we don’t have to shuffle the same set of patches multiple times, same idea as in this paper.\nFig. 6. Illustration of self-supervised learning by solving jigsaw puzzle. (Image source: Noroozi \u0026 Favaro, 2016) Another idea is to consider \u0026ldquo;feature\u0026rdquo; or \u0026ldquo;visual primitives\u0026rdquo; as a scalar-value attribute that can be summed up over multiple patches and compared across different patches. Then the relationship between patches can be defined by counting features and simple arithmetic (Noroozi, et al, 2017).\nThe paper considers two transformations:\n Scaling: If an image is scaled up by 2x, the number of visual primitives should stay the same. Tiling: If an image is tiled into a 2x2 grid, the number of visual primitives is expected to be the sum, 4 times the original feature counts. The model learns a feature encoder $\\phi(.)$ using the above feature counting relationship. Given an input image $\\mathbf{x} \\in \\mathbb{R}^{m \\times n \\times 3}$, considering two types of transformation operators:\n Downsampling operator, $D: \\mathbb{R}^{m \\times n \\times 3} \\mapsto \\mathbb{R}^{\\frac{m}{2} \\times \\frac{n}{2} \\times 3}$: downsample by a factor of 2 Tiling operator $T_i: \\mathbb{R}^{m \\times n \\times 3} \\mapsto \\mathbb{R}^{\\frac{m}{2} \\times \\frac{n}{2} \\times 3}$: extract the $i$-th tile from a 2x2 grid of the image. We expect to learn:\n $$ \\phi(\\mathbf{x}) = \\phi(D \\circ \\mathbf{x}) = \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x}) $$ Thus the MSE loss is: $\\mathcal{L}_\\text{feat} = |\\phi(D \\circ \\mathbf{x}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})|^2_2$. To avoid trivial solution $\\phi(\\mathbf{x}) = \\mathbf{0}, \\forall{\\mathbf{x}}$, another loss term is added to encourage the difference between features of two different images: $\\mathcal{L}_\\text{diff} = \\max(0, c -|\\phi(D \\circ \\mathbf{y}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})|^2_2)$, where $\\mathbf{y}$ is another input image different from $\\mathbf{x}$ and $c$ is a scalar constant. The final loss is:\n $$ \\mathcal{L} = \\mathcal{L}_\\text{feat} + \\mathcal{L}_\\text{diff} = \\|\\phi(D \\circ \\mathbf{x}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})\\|^2_2 + \\max(0, M -\\|\\phi(D \\circ \\mathbf{y}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})\\|^2_2) $$ Fig. 7. Self-supervised representation learning by counting features. (Image source: Noroozi, et al, 2017) Colorization Colorization can be used as a powerful self-supervised task: a model is trained to color a grayscale input image; precisely the task is to map this image to a distribution over quantized color value outputs (Zhang et al. 2016).\nThe model outputs colors in the the CIE Lab* color space. The Lab* color is designed to approximate human vision, while, in contrast, RGB or CMYK models the color output of physical devices.\n L* component matches human perception of lightness; L* = 0 is black and L* = 100 indicates white. a* component represents green (negative) / magenta (positive) value. b* component models blue (negative) /yellow (positive) value. Due to the multimodal nature of the colorization problem, cross-entropy loss of predicted probability distribution over binned color values works better than L2 loss of the raw color values. The ab color space is quantized with bucket size 10.\nTo balance between common colors (usually low ab values, of common backgrounds like clouds, walls, and dirt) and rare colors (which are likely associated with key objects in the image), the loss function is rebalanced with a weighting term that boosts the loss of infrequent color buckets. This is just like why we need both tf and idf for scoring words in information retrieval model. The weighting term is constructed as: (1-λ) * Gaussian-kernel-smoothed empirical probability distribution + λ * a uniform distribution, where both distributions are over the quantized ab color space.\nGenerative Modeling The pretext task in generative modeling is to reconstruct the original input while learning meaningful latent representation.\nThe denoising autoencoder (Vincent, et al, 2008) learns to recover an image from a version that is partially corrupted or has random noise. The design is inspired by the fact that humans can easily recognize objects in pictures even with noise, indicating that key visual features can be extracted and separated from noise. See my old post.\nThe context encoder (Pathak, et al., 2016) is trained to fill in a missing piece in the image. Let $\\hat{M}$ be a binary mask, 0 for dropped pixels and 1 for remaining input pixels. The model is trained with a combination of the reconstruction (L2) loss and the adversarial loss. The removed regions defined by the mask could be of any shape.\n $$ \\begin{aligned} \\mathcal{L}(\\mathbf{x}) \u0026= \\mathcal{L}_\\text{recon}(\\mathbf{x}) + \\mathcal{L}_\\text{adv}(\\mathbf{x})\\\\ \\mathcal{L}_\\text{recon}(\\mathbf{x}) \u0026= \\|(1 - \\hat{M}) \\odot (\\mathbf{x} - E(\\hat{M} \\odot \\mathbf{x})) \\|_2^2 \\\\ \\mathcal{L}_\\text{adv}(\\mathbf{x}) \u0026= \\max_D \\mathbb{E}_{\\mathbf{x}} [\\log D(\\mathbf{x}) + \\log(1 - D(E(\\hat{M} \\odot \\mathbf{x})))] \\end{aligned} $$ where $E(.)$ is the encoder and $D(.)$ is the decoder.\nFig. 8. Illustration of context encoder. (Image source: Pathak, et al., 2016) When applying a mask on an image, the context encoder removes information of all the color channels in partial regions. How about only hiding a subset of channels? The split-brain autoencoder (Zhang et al., 2017) does this by predicting a subset of color channels from the rest of channels. Let the data tensor $\\mathbf{x} \\in \\mathbb{R}^{h \\times w \\times \\vert C \\vert }$ with $C$ color channels be the input for the $l$-th layer of the network. It is split into two disjoint parts, $\\mathbf{x}_1 \\in \\mathbb{R}^{h \\times w \\times \\vert C_1 \\vert}$ and $\\mathbf{x}_2 \\in \\mathbb{R}^{h \\times w \\times \\vert C_2 \\vert}$, where $C_1 , C_2 \\subseteq C$. Then two sub-networks are trained to do two complementary predictions: one network $f_1$ predicts $\\mathbf{x}_2$ from $\\mathbf{x}_1$ and the other network $f_1$ predicts $\\mathbf{x}_1$ from $\\mathbf{x}_2$. The loss is either L1 loss or cross entropy if color values are quantized.\nThe split can happen once on the RGB-D or Lab* colorspace, or happen even in every layer of a CNN network in which the number of channels can be arbitrary.\nFig. 9. Illustration of split-brain autoencoder. (Image source: Zhang et al., 2017) The generative adversarial networks (GANs) are able to learn to map from simple latent variables to arbitrarily complex data distributions. Studies have shown that the latent space of such generative models captures semantic variation in the data; e.g. when training GAN models on human faces, some latent variables are associated with facial expression, glasses, gender, etc (Radford et al., 2016).\nBidirectional GANs (Donahue, et al, 2017) introduces an additional encoder $E(.)$ to learn the mappings from the input to the latent variable $\\mathbf{z}$. The discriminator $D(.)$ predicts in the joint space of the input data and latent representation, $(\\mathbf{x}, \\mathbf{z})$, to tell apart the generated pair $(\\mathbf{x}, E(\\mathbf{x}))$ from the real one $(G(\\mathbf{z}), \\mathbf{z})$. The model is trained to optimize the objective: $\\min_{G, E} \\max_D V(D, E, G)$, where the generator $G$ and the encoder $E$ learn to generate data and latent variables that are realistic enough to confuse the discriminator and at the same time the discriminator $D$ tries to differentiate real and generated data.\n $$ V(D, E, G) = \\mathbb{E}_{\\mathbf{x} \\sim p_\\mathbf{x}} [ \\underbrace{\\mathbb{E}_{\\mathbf{z} \\sim p_E(.\\vert\\mathbf{x})}[\\log D(\\mathbf{x}, \\mathbf{z})]}_{\\log D(\\text{real})} ] + \\mathbb{E}_{\\mathbf{z} \\sim p_\\mathbf{z}} [ \\underbrace{\\mathbb{E}_{\\mathbf{x} \\sim p_G(.\\vert\\mathbf{z})}[\\log 1 - D(\\mathbf{x}, \\mathbf{z})]}_{\\log(1- D(\\text{fake}))}) ] $$ Fig. 10. Illustration of how Bidirectional GAN works. (Image source: Donahue, et al, 2017) Contrastive Learning The Contrastive Predictive Coding (CPC) (van den Oord, et al. 2018) is an approach for unsupervised learning from high-dimensional data by translating a generative modeling problem to a classification problem. The contrastive loss or InfoNCE loss in CPC, inspired by Noise Contrastive Estimation (NCE), uses cross-entropy loss to measure how well the model can classify the \u0026ldquo;future\u0026rdquo; representation amongst a set of unrelated \u0026ldquo;negative\u0026rdquo; samples. Such design is partially motivated by the fact that the unimodal loss like MSE has no enough capacity but learning a full generative model could be too expensive.\nFig. 11. Illustration of applying Contrastive Predictive Coding on the audio input. (Image source: van den Oord, et al. 2018) CPC uses an encoder to compress the input data $z_t = g_\\text{enc}(x_t)$ and an autoregressive decoder to learn the high-level context that is potentially shared across future predictions, $c_t = g_\\text{ar}(z_{\\leq t})$. The end-to-end training relies on the NCE-inspired contrastive loss.\nWhile predicting future information, CPC is optimized to maximize the the mutual information between input $x$ and context vector $c$:\n $$ I(x; c) = \\sum_{x, c} p(x, c) \\log\\frac{p(x, c)}{p(x)p(c)} = \\sum_{x, c} p(x, c)\\log\\frac{p(x|c)}{p(x)} $$ Rather than modeling the future observations $p_k(x_{t+k} \\vert c_t)$ directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between $x_{t+k}$ and $c_t$:\n $$ f_k(x_{t+k}, c_t) = \\exp(z_{t+k}^\\top W_k c_t) \\propto \\frac{p(x_{t+k}|c_t)}{p(x_{t+k})} $$ where $f_k$ can be unnormalized and a linear transformation $W_k^\\top c_t$ is used for the prediction with a different $W_k$ matrix for every step $k$.\nGiven a set of $N$ random samples $X = \\{x_1, \\dots, x_N\\}$ containing only one positive sample $x_t \\sim p(x_{t+k} \\vert c_t)$ and $N-1$ negative samples $x_{i \\neq t} \\sim p(x_{t+k})$, the cross-entropy loss for classifying the positive sample (where $\\frac{f_k}{\\sum f_k}$ is the prediction) correctly is:\n $$ \\mathcal{L}_N = - \\mathbb{E}_X \\Big[\\log \\frac{f_k(x_{t+k}, c_t)}{\\sum_{i=1}^N f_k (x_i, c_t)}\\Big] $$ Fig. 12. Illustration of applying Contrastive Predictive Coding on images. (Image source: van den Oord, et al. 2018) When using CPC on images (Henaff, et al. 2019), the predictor network should only access a masked feature set to avoid a trivial prediction. Precisely:\n Each input image is divided into a set of overlapped patches and each patch is encoded by a resnet encoder, resulting in compressed feature vector $z_{i,j}$. A masked conv net makes prediction with a mask such that the receptive field of a given output neuron can only see things above it in the image. Otherwise, the prediction problem would be trivial. The prediction can be made in both directions (top-down and bottom-up). The prediction is made for $z_{i+k, j}$ from context $c_{i,j}$: $\\hat{z}_{i+k, j} = W_k c_{i,j}$. A contrastive loss quantifies this prediction with a goal to correctly identify the target among a set of negative representation $\\{z_l\\}$ sampled from other patches in the same image and other images in the same batch:\n $$ \\mathcal{L}_\\text{CPC} = -\\sum_{i,j,k} \\log p(z_{i+k, j} \\vert \\hat{z}_{i+k, j}, \\{z_l\\}) = -\\sum_{i,j,k} \\log \\frac{\\exp(\\hat{z}_{i+k, j}^\\top z_{i+k, j})}{\\exp(\\hat{z}_{i+k, j}^\\top z_{i+k, j}) + \\sum_l \\exp(\\hat{z}_{i+k, j}^\\top z_l)} $$ For more content on contrastive learning, check out the post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;.\nVideo-Based A video contains a sequence of semantically related frames. Nearby frames are close in time and more correlated than frames further away. The order of frames describes certain rules of reasonings and physical logics; such as that object motion should be smooth and gravity is pointing down.\nA common workflow is to train a model on one or multiple pretext tasks with unlabelled videos and then feed one intermediate feature layer of this model to fine-tune a simple model on downstream tasks of action classification, segmentation or object tracking.\nTracking The movement of an object is traced by a sequence of video frames. The difference between how the same object is captured on the screen in close frames is usually not big, commonly triggered by small motion of the object or the camera. Therefore any visual representation learned for the same object across close frames should be close in the latent feature space. Motivated by this idea, Wang \u0026amp; Gupta, 2015 proposed a way of unsupervised learning of visual representation by tracking moving objects in videos.\nPrecisely patches with motion are tracked over a small time window (e.g. 30 frames). The first patch $\\mathbf{x}$ and the last patch $\\mathbf{x}^+$ are selected and used as training data points. If we train the model directly to minimize the difference between feature vectors of two patches, the model may only learn to map everything to the same value. To avoid such a trivial solution, same as above, a random third patch $\\mathbf{x}^-$ is added. The model learns the representation by enforcing the distance between two tracked patches to be closer than the distance between the first patch and a random one in the feature space, $D(\\mathbf{x}, \\mathbf{x}^-)) \u0026gt; D(\\mathbf{x}, \\mathbf{x}^+)$, where $D(.)$ is the cosine distance,\n $$ D(\\mathbf{x}_1, \\mathbf{x}_2) = 1 - \\frac{f(\\mathbf{x}_1) f(\\mathbf{x}_2)}{\\|f(\\mathbf{x}_1)\\| \\|f(\\mathbf{x}_2\\|)} $$ The loss function is:\n $$ \\mathcal{L}(\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-) = \\max\\big(0, D(\\mathbf{x}, \\mathbf{x}^+) - D(\\mathbf{x}, \\mathbf{x}^-) + M\\big) + \\text{weight decay regularization term} $$ where $M$ is a scalar constant controlling for the minimum gap between two distances; $M=0.5$ in the paper. The loss enforces $D(\\mathbf{x}, \\mathbf{x}^-) \u0026gt;= D(\\mathbf{x}, \\mathbf{x}^+) + M$ at the optimal case.\nThis form of loss function is also known as triplet loss in the face recognition task, in which the dataset contains images of multiple people from multiple camera angles. Let $\\mathbf{x}^a$ be an anchor image of a specific person, $\\mathbf{x}^p$ be a positive image of this same person from a different angle and $\\mathbf{x}^n$ be a negative image of a different person. In the embedding space, $\\mathbf{x}^a$ should be closer to $\\mathbf{x}^p$ than $\\mathbf{x}^n$:\n $$ \\mathcal{L}_\\text{triplet}(\\mathbf{x}^a, \\mathbf{x}^p, \\mathbf{x}^n) = \\max(0, \\|\\phi(\\mathbf{x}^a) - \\phi(\\mathbf{x}^p) \\|_2^2 - \\|\\phi(\\mathbf{x}^a) - \\phi(\\mathbf{x}^n) \\|_2^2 + M) $$ A slightly different form of the triplet loss, named n-pair loss is also commonly used for learning observation embedding in robotics tasks. See a later section for more related content.\nFig. 13. Overview of learning representation by tracking objects in videos. (a) Identify moving patches in short traces; (b) Feed two related patched and one random patch into a conv network with shared weights. (c) The loss function enforces the distance between related patches to be closer than the distance between random patches. (Image source: Wang \u0026 Gupta, 2015) Relevant patches are tracked and extracted through a two-step unsupervised optical flow approach:\n Obtain SURF interest points and use IDT to obtain motion of each SURF point. Given the trajectories of SURF interest points, classify these points as moving if the flow magnitude is more than 0.5 pixels. During training, given a pair of correlated patches $\\mathbf{x}$ and $\\mathbf{x}^+$, $K$ random patches $\\{\\mathbf{x}^-\\}$ are sampled in this same batch to form $K$ training triplets. After a couple of epochs, hard negative mining is applied to make the training harder and more efficient, that is, to search for random patches that maximize the loss and use them to do gradient updates.\nFrame Sequence Video frames are naturally positioned in chronological order. Researchers have proposed several self-supervised tasks, motivated by the expectation that good representation should learn the correct sequence of frames.\nOne idea is to validate frame order (Misra, et al 2016). The pretext task is to determine whether a sequence of frames from a video is placed in the correct temporal order (\u0026ldquo;temporal valid\u0026rdquo;). The model needs to track and reason about small motion of an object across frames to complete such a task.\nThe training frames are sampled from high-motion windows. Every time 5 frames are sampled $(f_a, f_b, f_c, f_d, f_e)$ and the timestamps are in order $a \u0026lt; b \u0026lt; c \u0026lt; d \u0026lt; e$. Out of 5 frames, one positive tuple $(f_b, f_c, f_d)$ and two negative tuples, $(f_b, f_a, f_d)$ and $(f_b, f_e, f_d)$ are created. The parameter $\\tau_\\max = \\vert b-d \\vert$ controls the difficulty of positive training instances (i.e. higher → harder) and the parameter $\\tau_\\min = \\min(\\vert a-b \\vert, \\vert d-e \\vert)$ controls the difficulty of negatives (i.e. lower → harder).\nThe pretext task of video frame order validation is shown to improve the performance on the downstream task of action recognition when used as a pretraining step.\nFig. 14. Overview of learning representation by validating the order of video frames. (a) the data sample process; (b) the model is a triplet siamese network, where all input frames have shared weights. (Image source: Misra, et al 2016) The task in O3N (Odd-One-Out Network; Fernando et al. 2017) is based on video frame sequence validation too. One step further from above, the task is to pick the incorrect sequence from multiple video clips.\nGiven $N+1$ input video clips, one of them has frames shuffled, thus in the wrong order, and the rest $N$ of them remain in the correct temporal order. O3N learns to predict the location of the odd video clip. In their experiments, there are 6 input clips and each contain 6 frames.\nThe arrow of time in a video contains very informative messages, on both low-level physics (e.g. gravity pulls objects down to the ground; smoke rises up; water flows downward.) and high-level event reasoning (e.g. fish swim forward; you can break an egg but cannot revert it.). Thus another idea is inspired by this to learn latent representation by predicting the arrow of time (AoT) \u0026mdash; whether video playing forwards or backwards (Wei et al., 2018).\nA classifier should capture both low-level physics and high-level semantics in order to predict the arrow of time. The proposed T-CAM (Temporal Class-Activation-Map) network accepts $T$ groups, each containing a number of frames of optical flow. The conv layer outputs from each group are concatenated and fed into binary logistic regression for predicting the arrow of time.\nFig. 15. Overview of learning representation by predicting the arrow of time. (a) Conv features of multiple groups of frame sequences are concatenated. (b) The top level contains 3 conv layers and average pooling. (Image source: Wei et al, 2018) Interestingly, there exist a couple of artificial cues in the dataset. If not handled properly, they could lead to a trivial classifier without relying on the actual video content:\n Due to the video compression, the black framing might not be completely black but instead may contain certain information on the chronological order. Hence black framing should be removed in the experiments. Large camera motion, like vertical translation or zoom-in/out, also provides strong signals for the arrow of time but independent of content. The processing stage should stabilize the camera motion. The AoT pretext task is shown to improve the performance on action classification downstream task when used as a pretraining step. Note that fine-tuning is still needed.\nVideo Colorization Vondrick et al. (2018) proposed video colorization as a self-supervised learning problem, resulting in a rich representation that can be used for video segmentation and unlabelled visual region tracking, without extra fine-tuning.\nUnlike the image-based colorization, here the task is to copy colors from a normal reference frame in color to another target frame in grayscale by leveraging the natural temporal coherency of colors across video frames (thus these two frames shouldn’t be too far apart in time). In order to copy colors consistently, the model is designed to learn to keep track of correlated pixels in different frames.\nFig. 16. Video colorization by copying colors from a reference frame to target frames in grayscale. (Image source: Vondrick et al. 2018) The idea is quite simple and smart. Let $c_i$ be the true color of the $i-th$ pixel in the reference frame and $c_j$ be the color of $j$-th pixel in the target frame. The predicted color of $j$-th color in the target $\\hat{c}_j$ is a weighted sum of colors of all the pixels in reference, where the weighting term measures the similarity:\n $$ \\hat{c}_j = \\sum_i A_{ij} c_i \\text{ where } A_{ij} = \\frac{\\exp(f_i f_j)}{\\sum_{i'} \\exp(f_{i'} f_j)} $$ where $f$ are learned embeddings for corresponding pixels; $i’$ indexes all the pixels in the reference frame. The weighting term implements an attention-based pointing mechanism, similar to matching network and pointer network. As the full similarity matrix could be really large, both frames are downsampled. The categorical cross-entropy loss between $c_j$ and $\\hat{c}_j$ is used with quantized colors, just like in Zhang et al. 2016.\nBased on how the reference frame are marked, the model can be used to complete several color-based downstream tasks such as tracking segmentation or human pose in time. No fine-tuning is needed. See Fig. 15.\nFig. 17. Use video colorization to track object segmentation and human pose in time. (Image source: Vondrick et al. (2018)) A couple common observations:\n Combining multiple pretext tasks improves performance; Deeper networks improve the quality of representation; Supervised learning baselines still beat all of them by far. Control-Based When running a RL policy in the real world, such as controlling a physical robot on visual inputs, it is non-trivial to properly track states, obtain reward signals or determine whether a goal is achieved for real. The visual data has a lot of noise that is irrelevant to the true state and thus the equivalence of states cannot be inferred from pixel-level comparison. Self-supervised representation learning has shown great potential in learning useful state embedding that can be used directly as input to a control policy.\nAll the cases discussed in this section are in robotic learning, mainly for state representation from multiple camera views and goal representation.\nMulti-View Metric Learning The concept of metric learning has been mentioned multiple times in the previous sections. A common setting is: Given a triple of samples, (anchor $s_a$, positive sample $s_p$, negative sample $s_n$), the learned representation embedding $\\phi(s)$ fulfills that $s_a$ stays close to $s_p$ but far away from $s_n$ in the latent space.\nGrasp2Vec (Jang \u0026amp; Devin et al., 2018) aims to learn an object-centric vision representation in the robot grasping task from free, unlabelled grasping activities. By object-centric, it means that, irrespective of how the environment or the robot looks like, if two images contain similar items, they should be mapped to similar representation; otherwise the embeddings should be far apart.\nFig. 18. A conceptual illustration of how grasp2vec learns an object-centric state embedding. (Image source: Jang \u0026 Devin et al., 2018) The grasping system can tell whether it moves an object but cannot tell which object it is. Cameras are set up to take images of the entire scene and the grasped object. During early training, the grasp robot is executed to grasp any object $o$ at random, producing a triple of images, $(s_\\text{pre}, s_\\text{post}, o)$:\n $o$ is an image of the grasped object held up to the camera; $s_\\text{pre}$ is an image of the scene before grasping, with the object $o$ in the tray; $s_\\text{post}$ is an image of the same scene after grasping, without the object $o$ in the tray. To learn object-centric representation, we expect the difference between embeddings of $s_\\text{pre}$ and $s_\\text{post}$ to capture the removed object $o$. The idea is quite interesting and similar to relationships that have been observed in word embedding, e.g. distance(\u0026ldquo;king\u0026rdquo;, \u0026ldquo;queen\u0026rdquo;) ≈ distance(\u0026ldquo;man\u0026rdquo;, \u0026ldquo;woman\u0026rdquo;).\nLet $\\phi_s$ and $\\phi_o$ be the embedding functions for the scene and the object respectively. The model learns the representation by minimizing the distance between $\\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post})$ and $\\phi_o(o)$ using n-pair loss:\n $$ \\begin{aligned} \\mathcal{L}_\\text{grasp2vec} \u0026= \\text{NPair}(\\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post}), \\phi_o(o)) + \\text{NPair}(\\phi_o(o), \\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post})) \\\\ \\text{where }\\text{NPair}(a, p) \u0026= \\sum_{iwhere $B$ refers to a batch of (anchor, positive) sample pairs.\nWhen framing representation learning as metric learning, n-pair loss is a common choice. Rather than processing explicit a triple of (anchor, positive, negative) samples, the n-pairs loss treats all other positive instances in one mini-batch across pairs as negatives.\nThe embedding function $\\phi_o$ works great for presenting a goal $g$ with an image. The reward function that quantifies how close the actually grasped object $o$ is close to the goal is defined as $r = \\phi_o(g) \\cdot \\phi_o(o)$. Note that computing rewards only relies on the learned latent space and doesn\u0026rsquo;t involve ground truth positions, so it can be used for training on real robots.\nFig. 19. Localization results of grasp2vec embedding. The heatmap of localizing a goal object in a pre-grasping scene is defined as $\\phi\\_o(o)^\\top \\phi\\_{s, \\text{spatial}} (s\\_\\text{pre})$, where $\\phi\\_{s, \\text{spatial}}$ is the output of the last resnet block after ReLU. The fourth column is a failure case and the last three columns take real images as goals. (Image source: Jang \u0026 Devin et al., 2018) Other than the embedding-similarity-based reward function, there are a few other tricks for training the RL policy in the grasp2vec framework:\n Posthoc labeling: Augment the dataset by labeling a randomly grasped object as a correct goal, like HER (Hindsight Experience Replay; Andrychowicz, et al., 2017). Auxiliary goal augmentation: Augment the replay buffer even further by relabeling transitions with unachieved goals; precisely, in each iteration, two goals are sampled $(g, g')$ and both are used to add new transitions into replay buffer. TCN (Time-Contrastive Networks; Sermanet, et al. 2018) learn from multi-camera view videos with the intuition that different viewpoints at the same timestep of the same scene should share the same embedding (like in FaceNet) while embedding should vary in time, even of the same camera viewpoint. Therefore embedding captures the semantic meaning of the underlying state rather than visual similarity. The TCN embedding is trained with triplet loss.\nThe training data is collected by taking videos of the same scene simultaneously but from different angles. All the videos are unlabelled.\nFig. 20. An illustration of time-contrastive approach for learning state embedding. The blue frames selected from two camera views at the same timestep are anchor and positive samples, while the red frame at a different timestep is the negative sample. TCN embedding extracts visual features that are invariant to camera configurations. It can be used to construct a reward function for imitation learning based on the euclidean distance between the demo video and the observations in the latent space.\nA further improvement over TCN is to learn embedding over multiple frames jointly rather than a single frame, resulting in mfTCN (Multi-frame Time-Contrastive Networks; Dwibedi et al., 2019). Given a set of videos from several synchronized camera viewpoints, $v_1, v_2, \\dots, v_k$, the frame at time $t$ and the previous $n-1$ frames selected with stride $s$ in each video are aggregated and mapped into one embedding vector, resulting in a lookback window of size $(n−1) \\times s + 1$. Each frame first goes through a CNN to extract low-level features and then we use 3D temporal convolutions to aggregate frames in time. The model is trained with n-pairs loss.\nFig. 21. The sampling process for training mfTCN. (Image source: Dwibedi et al., 2019) The training data is sampled as follows:\n First we construct two pairs of video clips. Each pair contains two clips from different camera views but with synchronized timesteps. These two sets of videos should be far apart in time. Sample a fixed number of frames from each video clip in the same pair simultaneously with the same stride. Frames with the same timesteps are trained as positive samples in the n-pair loss, while frames across pairs are negative samples. mfTCN embedding can capture the position and velocity of objects in the scene (e.g. in cartpole) and can also be used as inputs for policy.\nAutonomous Goal Generation RIG (Reinforcement learning with Imagined Goals; Nair et al., 2018) described a way to train a goal-conditioned policy with unsupervised representation learning. A policy learns from self-supervised practice by first imagining \u0026ldquo;fake\u0026rdquo; goals and then trying to achieve them.\nFig. 22. The workflow of RIG. (Image source: Nair et al., 2018) The task is to control a robot arm to push a small puck on a table to a desired position. The desired position, or the goal, is present in an image. During training, it learns latent embedding of both state $s$ and goal $g$ through $\\beta$-VAE encoder and the control policy operates entirely in the latent space.\nLet’s say a $\\beta$-VAE has an encoder $q_\\phi$ mapping input states to latent variable $z$ which is modeled by a Gaussian distribution and a decoder $p_\\psi$ mapping $z$ back to the states. The state encoder in RIG is set to be the mean of $\\beta$-VAE encoder.\n $$ \\begin{aligned} z \u0026\\sim q_\\phi(z \\vert s) = \\mathcal{N}(z; \\mu_\\phi(s), \\sigma^2_\\phi(s)) \\\\ \\mathcal{L}_{\\beta\\text{-VAE}} \u0026= - \\mathbb{E}_{z \\sim q_\\phi(z \\vert s)} [\\log p_\\psi (s \\vert z)] + \\beta D_\\text{KL}(q_\\phi(z \\vert s) \\| p_\\psi(s)) \\\\ e(s) \u0026\\triangleq \\mu_\\phi(s) \\end{aligned} $$ The reward is the Euclidean distance between state and goal embedding vectors: $r(s, g) = -|e(s) - e(g)|$. Similar to grasp2vec, RIG applies data augmentation as well by latent goal relabeling: precisely half of the goals are generated from the prior at random and the other half are selected using HER. Also same as grasp2vec, rewards do not depend on any ground truth states but only the learned state encoding, so it can be used for training on real robots.\nFig. 23. The algorithm of RIG. (Image source: Nair et al., 2018) The problem with RIG is a lack of object variations in the imagined goal pictures. If $\\beta$-VAE is only trained with a black puck, it would not be able to create a goal with other objects like blocks of different shapes and colors. A follow-up improvement replaces $\\beta$-VAE with a CC-VAE (Context-Conditioned VAE; Nair, et al., 2019), inspired by CVAE (Conditional VAE; Sohn, Lee \u0026amp; Yan, 2015), for goal generation.\nFig. 24. The workflow of context-conditioned RIG. (Image source: Nair, et al., 2019). A CVAE conditions on a context variable $c$. It trains an encoder $q_\\phi(z \\vert s, c)$ and a decoder $p_\\psi (s \\vert z, c)$ and note that both have access to $c$. The CVAE loss penalizes information passing from the input state $s$ through an information bottleneck but allows for unrestricted information flow from $c$ to both encoder and decoder.\n $$ \\mathcal{L}_\\text{CVAE} = - \\mathbb{E}_{z \\sim q_\\phi(z \\vert s,c)} [\\log p_\\psi (s \\vert z, c)] + \\beta D_\\text{KL}(q_\\phi(z \\vert s, c) \\| p_\\psi(s)) $$ To create plausible goals, CC-VAE conditions on a starting state $s_0$ so that the generated goal presents a consistent type of object as in $s_0$. This goal consistency is necessary; e.g. if the current scene contains a red puck but the goal has a blue block, it would confuse the policy.\nOther than the state encoder $e(s) \\triangleq \\mu_\\phi(s)$, CC-VAE trains a second convolutional encoder $e_0(.)$ to translate the starting state $s_0$ into a compact context representation $c = e_0(s_0)$. Two encoders, $e(.)$ and $e_0(.)$, are intentionally different without shared weights, as they are expected to encode different factors of image variation. In addition to the loss function of CVAE, CC-VAE adds an extra term to learn to reconstruct $c$ back to $s_0$, $\\hat{s}_0 = d_0(c)$.\n $$ \\mathcal{L}_\\text{CC-VAE} = \\mathcal{L}_\\text{CVAE} + \\log p(s_0\\vert c) $$ Fig. 25. Examples of imagined goals generated by CVAE that conditions on the context image (the first row), while VAE fails to capture the object consistency. (Image source: Nair, et al., 2019). Bisimulation Task-agnostic representation (e.g. a model that intends to represent all the dynamics in the system) may distract the RL algorithms as irrelevant information is also presented. For example, if we just train an auto-encoder to reconstruct the input image, there is no guarantee that the entire learned representation will be useful for RL. Therefore, we need to move away from reconstruction-based representation learning if we only want to learn information relevant to control, as irrelevant details are still important for reconstruction.\nRepresentation learning for control based on bisimulation does not depend on reconstruction, but aims to group states based on their behavioral similarity in MDP.\nBisimulation (Givan et al. 2003) refers to an equivalence relation between two states with similar long-term behavior. Bisimulation metrics quantify such relation so that we can aggregate states to compress a high-dimensional state space into a smaller one for more efficient computation. The bisimulation distance between two states corresponds to how behaviorally different these two states are.\nGiven a MDP $\\mathcal{M} = \\langle \\mathcal{S}, \\mathcal{A}, \\mathcal{P}, \\mathcal{R}, \\gamma \\rangle$ and a bisimulation relation $B$, two states that are equal under relation $B$ (i.e. $s_i B s_j$) should have the same immediate reward for all actions and the same transition probabilities over the next bisimilar states:\n $$ \\begin{aligned} \\mathcal{R}(s_i, a) \u0026= \\mathcal{R}(s_j, a) \\; \\forall a \\in \\mathcal{A} \\\\ \\mathcal{P}(G \\vert s_i, a) \u0026= \\mathcal{P}(G \\vert s_j, a) \\; \\forall a \\in \\mathcal{A} \\; \\forall G \\in \\mathcal{S}_B \\end{aligned} $$ where $\\mathcal{S}_B$ is a partition of the state space under the relation $B$.\nNote that $=$ is always a bisimulation relation. The most interesting one is the maximal bisimulation relation $\\sim$, which defines a partition $\\mathcal{S}_\\sim$ with fewest groups of states.\nFig. 26. DeepMDP learns a latent space model by minimizing two losses on a reward model and a dynamics model. (Image source: Gelada, et al. 2019) With a goal similar to bisimulation metric, DeepMDP (Gelada, et al. 2019) simplifies high-dimensional observations in RL tasks and learns a latent space model via minimizing two losses:\n prediction of rewards and prediction of the distribution over next latent states. $$ \\begin{aligned} \\mathcal{L}_{\\bar{\\mathcal{R}}}(s, a) = \\vert \\mathcal{R}(s, a) - \\bar{\\mathcal{R}}(\\phi(s), a) \\vert \\\\ \\mathcal{L}_{\\bar{\\mathcal{P}}}(s, a) = D(\\phi \\mathcal{P}(s, a), \\bar{\\mathcal{P}}(. \\vert \\phi(s), a)) \\end{aligned} $$ where $\\phi(s)$ is the embedding of state $s$; symbols with bar are functions (reward function $R$ and transition function $P$) in the same MDP but running in the latent low-dimensional observation space. Here the embedding representation $\\phi$ can be connected to bisimulation metrics, as the bisimulation distance is proved to be upper-bounded by the L2 distance in the latent space.\nThe function $D$ quantifies the distance between two probability distributions and should be chosen carefully. DeepMDP focuses on Wasserstein-1 metric (also known as “earth-mover distance”). The Wasserstein-1 distance between distributions $P$ and $Q$ on a metric space $(M, d)$ (i.e., $d: M \\times M \\to \\mathbb{R}$) is:\n $$ W_d (P, Q) = \\inf_{\\lambda \\in \\Pi(P, Q)} \\int_{M \\times M} d(x, y) \\lambda(x, y) \\; \\mathrm{d}x \\mathrm{d}y $$ where $\\Pi(P, Q)$ is the set of all couplings of $P$ and $Q$. $d(x, y)$ defines the cost of moving a particle from point $x$ to point $y$.\nThe Wasserstein metric has a dual form according to the Monge-Kantorovich duality:\n $$ W_d (P, Q) = \\sup_{f \\in \\mathcal{F}_d} \\vert \\mathbb{E}_{x \\sim P} f(x) - \\mathbb{E}_{y \\sim Q} f(y) \\vert $$ where $\\mathcal{F}_d$ is the set of 1-Lipschitz functions under the metric $d$ - $\\mathcal{F}_d = \\{ f: \\vert f(x) - f(y) \\vert \\leq d(x, y) \\}$.\nDeepMDP generalizes the model to the Norm Maximum Mean Discrepancy (Norm-MMD) metrics to improve the tightness of the bounds of its deep value function and, at the same time, to save computation (Wasserstein is expensive computationally). In their experiments, they found the model architecture of the transition prediction model can have a big impact on the performance. Adding these DeepMDP losses as auxiliary losses when training model-free RL agents leads to good improvement on most of the Atari games.\nDeep Bisimulatioin for Control (short for DBC; Zhang et al. 2020) learns the latent representation of observations that are good for control in RL tasks, without domain knowledge or pixel-level reconstruction.\nFig. 27. The Deep Bisimulation for Control algorithm learns a bisimulation metric representation via learning a reward model and a dynamics model. The model architecture is a siamese network. (Image source: Zhang et al. 2020) Similar to DeepMDP, DBC models the dynamics by learning a reward model and a transition model. Both models operate in the latent space, $\\phi(s)$. The optimization of embedding $\\phi$ depends on one important conclusion from Ferns, et al. 2004 (Theorem 4.5) and Ferns, et al 2011 (Theorem 2.6):\n Given $c \\in (0, 1)$ a discounting factor, $\\pi$ a policy that is being improved continuously, and $M$ the space of bounded pseudometric on the state space $\\mathcal{S}$, we can define $\\mathcal{F}: M \\mapsto M$:\n $$ \\mathcal{F}(d; \\pi)(s_i, s_j) = (1-c) \\vert \\mathcal{R}_{s_i}^\\pi - \\mathcal{R}_{s_j}^\\pi \\vert + c W_d (\\mathcal{P}_{s_i}^\\pi, \\mathcal{P}_{s_j}^\\pi) $$ Then, $\\mathcal{F}$ has a unique fixed point $\\tilde{d}$ which is a $\\pi^*$-bisimulation metric and $\\tilde{d}(s_i, s_j) = 0 \\iff s_i \\sim s_j$.\n [The proof is not trivial. I may or may not add it in the future _(:3」∠)_ \u0026hellip;]\nGiven batches of observations pairs, the training loss for $\\phi$, $J(\\phi)$, minimizes the mean square error between the on-policy bisimulation metric and Euclidean distance in the latent space:\n$$ J(\\phi) = \\Big( \\|\\phi(s_i) - \\phi(s_j)\\|_1 - \\vert \\hat{\\mathcal{R}}(\\bar{\\phi}(s_i)) - \\hat{\\mathcal{R}}(\\bar{\\phi}(s_j)) \\vert - \\gamma W_2(\\hat{\\mathcal{P}}(\\cdot \\vert \\bar{\\phi}(s_i), \\bar{\\pi}(\\bar{\\phi}(s_i))), \\hat{\\mathcal{P}}(\\cdot \\vert \\bar{\\phi}(s_j), \\bar{\\pi}(\\bar{\\phi}(s_j)))) \\Big)^2 $$ where $\\bar{\\phi}(s)$ denotes $\\phi(s)$ with stop gradient and $\\bar{\\pi}$ is the mean policy output. The learned reward model $\\hat{\\mathcal{R}}$ is deterministic and the learned forward dynamics model $\\hat{\\mathcal{P}}$ outputs a Gaussian distribution.\nDBC is based on SAC but operates on the latent space:\nFig. 28. The algorithm of Deep Bisimulation for Control. (Image source: Zhang et al. 2020) Cited as:\n@article{weng2019selfsup, title = \u0026quot;Self-Supervised Representation Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-11-10-self-supervised/\u0026quot; } References [1] Alexey Dosovitskiy, et al. \u0026ldquo;Discriminative unsupervised feature learning with exemplar convolutional neural networks.\u0026quot; IEEE transactions on pattern analysis and machine intelligence 38.9 (2015): 1734-1747.\n[2] Spyros Gidaris, Praveer Singh \u0026amp; Nikos Komodakis. \u0026ldquo;Unsupervised Representation Learning by Predicting Image Rotations\u0026rdquo; ICLR 2018.\n[3] Carl Doersch, Abhinav Gupta, and Alexei A. Efros. \u0026ldquo;Unsupervised visual representation learning by context prediction.\u0026quot; ICCV. 2015.\n[4] Mehdi Noroozi \u0026amp; Paolo Favaro. \u0026ldquo;Unsupervised learning of visual representations by solving jigsaw puzzles.\u0026quot; ECCV, 2016.\n[5] Mehdi Noroozi, Hamed Pirsiavash, and Paolo Favaro. \u0026ldquo;Representation learning by learning to count.\u0026quot; ICCV. 2017.\n[6] Richard Zhang, Phillip Isola \u0026amp; Alexei A. Efros. \u0026ldquo;Colorful image colorization.\u0026quot; ECCV, 2016.\n[7] Pascal Vincent, et al. \u0026ldquo;Extracting and composing robust features with denoising autoencoders.\u0026quot; ICML, 2008.\n[8] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. \u0026ldquo;Adversarial feature learning.\u0026quot; ICLR 2017.\n[9] Deepak Pathak, et al. \u0026ldquo;Context encoders: Feature learning by inpainting.\u0026quot; CVPR. 2016.\n[10] Richard Zhang, Phillip Isola, and Alexei A. Efros. \u0026ldquo;Split-brain autoencoders: Unsupervised learning by cross-channel prediction.\u0026quot; CVPR. 2017.\n[11] Xiaolong Wang \u0026amp; Abhinav Gupta. \u0026ldquo;Unsupervised Learning of Visual Representations using Videos.\u0026quot; ICCV. 2015.\n[12] Carl Vondrick, et al. \u0026ldquo;Tracking Emerges by Colorizing Videos\u0026rdquo; ECCV. 2018.\n[13] Ishan Misra, C. Lawrence Zitnick, and Martial Hebert. \u0026ldquo;Shuffle and learn: unsupervised learning using temporal order verification.\u0026quot; ECCV. 2016.\n[14] Basura Fernando, et al. \u0026ldquo;Self-Supervised Video Representation Learning With Odd-One-Out Networks\u0026rdquo; CVPR. 2017.\n[15] Donglai Wei, et al. \u0026ldquo;Learning and Using the Arrow of Time\u0026rdquo; CVPR. 2018.\n[16] Florian Schroff, Dmitry Kalenichenko and James Philbin. \u0026ldquo;FaceNet: A Unified Embedding for Face Recognition and Clustering\u0026rdquo; CVPR. 2015.\n[17] Pierre Sermanet, et al. \u0026ldquo;Time-Contrastive Networks: Self-Supervised Learning from Video\u0026rdquo; CVPR. 2018.\n[18] Debidatta Dwibedi, et al. \u0026ldquo;Learning actionable representations from visual observations.\u0026quot; IROS. 2018.\n[19] Eric Jang \u0026amp; Coline Devin, et al. \u0026ldquo;Grasp2Vec: Learning Object Representations from Self-Supervised Grasping\u0026rdquo; CoRL. 2018.\n[20] Ashvin Nair, et al. \u0026ldquo;Visual reinforcement learning with imagined goals\u0026rdquo; NeuriPS. 2018.\n[21] Ashvin Nair, et al. \u0026ldquo;Contextual imagined goals for self-supervised robotic learning\u0026rdquo; CoRL. 2019.\n[22] Aaron van den Oord, Yazhe Li \u0026amp; Oriol Vinyals. \u0026ldquo;Representation Learning with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1807.03748, 2018.\n[23] Olivier J. Henaff, et al. \u0026ldquo;Data-Efficient Image Recognition with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1905.09272, 2019.\n[24] Kaiming He, et al. \u0026ldquo;Momentum Contrast for Unsupervised Visual Representation Learning.\u0026quot; CVPR 2020.\n[25] Zhirong Wu, et al. \u0026ldquo;Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination.\u0026quot; CVPR 2018.\n[26] Ting Chen, et al. \u0026ldquo;A Simple Framework for Contrastive Learning of Visual Representations.\u0026quot; arXiv preprint arXiv:2002.05709, 2020.\n[27] Aravind Srinivas, Michael Laskin \u0026amp; Pieter Abbeel \u0026ldquo;CURL: Contrastive Unsupervised Representations for Reinforcement Learning.\u0026quot; arXiv preprint arXiv:2004.04136, 2020.\n[28] Carles Gelada, et al. “DeepMDP: Learning Continuous Latent Space Models for Representation Learning” ICML 2019.\n[29] Amy Zhang, et al. “Learning Invariant Representations for Reinforcement Learning without Reconstruction” arXiv preprint arXiv:2006.10742, 2020.\n[30] Xinlei Chen, et al. “Improved Baselines with Momentum Contrastive Learning” arXiv preprint arXiv:2003.04297, 2020.\n[31] Jean-Bastien Grill, et al. “Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning” arXiv preprint arXiv:2006.07733, 2020.\n[32] Abe Fetterman \u0026amp; Josh Albrecht. “Understanding self-supervised and contrastive learning with Bootstrap Your Own Latent (BYOL)” Untitled blog. Aug 24, 2020.\n","permalink":"https://lilianweng.github.io/posts/2019-11-10-self-supervised/","summary":"[Updated on 2020-01-09: add a new section on Contrastive Predictive Coding]. [Updated on 2020-04-13: add a \u0026ldquo;Momentum Contrast\u0026rdquo; section on MoCo, SimCLR and CURL.] [Updated on 2020-07-08: add a \u0026ldquo;Bisimulation\u0026rdquo; section on DeepMDP and DBC.] [Updated on 2020-09-12: add MoCo V2 and BYOL in the \u0026ldquo;Momentum Contrast\u0026rdquo; section.] [Updated on 2021-05-31: remove section on \u0026ldquo;Momentum Contrast\u0026rdquo; and add a pointer to a full post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;]","title":"Self-Supervised Representation Learning"},{"content":"Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. With black-box optimization algorithms, you can evaluate a target function $f(x): \\mathbb{R}^n \\to \\mathbb{R}$, even when you don\u0026rsquo;t know the precise analytic form of $f(x)$ and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include Simulated Annealing, Hill Climbing and Nelder-Mead method.\nEvolution Strategies (ES) is one type of black-box optimization algorithms, born in the family of Evolutionary Algorithms (EA). In this post, I would dive into a couple of classic ES methods and introduce a few applications of how ES can play a role in deep reinforcement learning.\nWhat are Evolution Strategies? Evolution strategies (ES) belong to the big family of evolutionary algorithms. The optimization targets of ES are vectors of real numbers, $x \\in \\mathbb{R}^n$.\nEvolutionary algorithms refer to a division of population-based optimization algorithms inspired by natural selection. Natural selection believes that individuals with traits beneficial to their survival can live through generations and pass down the good characteristics to the next generation. Evolution happens by the selection process gradually and the population grows better adapted to the environment.\nFig. 1. How natural selection works. (Image source: Khan Academy: Darwin, evolution, \u0026 natural selection) Evolutionary algorithms can be summarized in the following format as a general optimization solution:\nLet\u0026rsquo;s say we want to optimize a function $f(x)$ and we are not able to compute gradients directly. But we still can evaluate $f(x)$ given any $x$ and the result is deterministic. Our belief in the probability distribution over $x$ as a good solution to $f(x)$ optimization is $p_\\theta(x)$, parameterized by $\\theta$. The goal is to find an optimal configuration of $\\theta$.\n Here given a fixed format of distribution (i.e. Gaussian), the parameter $\\theta$ carries the knowledge about the best solutions and is being iteratively updated across generations.\n Starting with an initial value of $\\theta$, we can continuously update $\\theta$ by looping three steps as follows:\n Generate a population of samples $D = \\{(x_i, f(x_i)\\}$ where $x_i \\sim p_\\theta(x)$. Evaluate the \u0026ldquo;fitness\u0026rdquo; of samples in $D$. Select the best subset of individuals and use them to update $\\theta$, generally based on fitness or rank. In Genetic Algorithms (GA), another popular subcategory of EA, $x$ is a sequence of binary codes, $x \\in \\{0, 1\\}^n$. While in ES, $x$ is just a vector of real numbers, $x \\in \\mathbb{R}^n$.\nSimple Gaussian Evolution Strategies This is the most basic and canonical version of evolution strategies. It models $p_\\theta(x)$ as a $n$-dimensional isotropic Gaussian distribution, in which $\\theta$ only tracks the mean $\\mu$ and standard deviation $\\sigma$.\n $$ \\theta = (\\mu, \\sigma),\\;p_\\theta(x) \\sim \\mathcal{N}(\\mathbf{\\mu}, \\sigma^2 I) = \\mu + \\sigma \\mathcal{N}(0, I) $$ The process of Simple-Gaussian-ES, given $x \\in \\mathcal{R}^n$:\n Initialize $\\theta = \\theta^{(0)}$ and the generation counter $t=0$ Generate the offspring population of size $\\Lambda$ by sampling from the Gaussian distribution:$D^{(t+1)}=\\{ x^{(t+1)}_i \\mid x^{(t+1)}_i = \\mu^{(t)} + \\sigma^{(t)} y^{(t+1)}_i \\text{ where } y^{(t+1)}_i \\sim \\mathcal{N}(x \\vert 0, \\mathbf{I}),;i = 1, \\dots, \\Lambda\\}$. Select a top subset of $\\lambda$ samples with optimal $f(x_i)$ and this subset is called elite set. Without loss of generality, we may consider the first $k$ samples in $D^{(t+1)}$ to belong to the elite group \u0026mdash; Let\u0026rsquo;s label them as $$ D^{(t+1)}\\_\\text{elite} = \\\\{x^{(t+1)}\\_i \\mid x^{(t+1)}\\_i \\in D^{(t+1)}, i=1,\\dots, \\lambda, \\lambda\\leq \\Lambda\\\\} $$ Then we estimate the new mean and std for the next generation using the elite set: $$ \\begin{aligned} \\mu^{(t+1)} \u0026= \\text{avg}(D^{(t+1)}_\\text{elite}) = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda x_i^{(t+1)} \\\\ {\\sigma^{(t+1)}}^2 \u0026= \\text{var}(D^{(t+1)}_\\text{elite}) = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)} -\\mu^{(t)})^2 \\end{aligned} $$ Repeat steps (2)-(4) until the result is good enough ✌️ Covariance Matrix Adaptation Evolution Strategies (CMA-ES) The standard deviation $\\sigma$ accounts for the level of exploration: the larger $\\sigma$ the bigger search space we can sample our offspring population. In vanilla ES, $\\sigma^{(t+1)}$ is highly correlated with $\\sigma^{(t)}$, so the algorithm is not able to rapidly adjust the exploration space when needed (i.e. when the confidence level changes).\nCMA-ES, short for \u0026ldquo;Covariance Matrix Adaptation Evolution Strategy\u0026rdquo;, fixes the problem by tracking pairwise dependencies between the samples in the distribution with a covariance matrix $C$. The new distribution parameter becomes:\n $$ \\theta = (\\mu, \\sigma, C),\\; p_\\theta(x) \\sim \\mathcal{N}(\\mu, \\sigma^2 C) \\sim \\mu + \\sigma \\mathcal{N}(0, C) $$ where $\\sigma$ controls for the overall scale of the distribution, often known as step size.\nBefore we dig into how the parameters are updated in CMA-ES, it is better to review how the covariance matrix works in the multivariate Gaussian distribution first. As a real symmetric matrix, the covariance matrix $C$ has the following nice features (See proof \u0026amp; proof):\n It is always diagonalizable. Always positive semi-definite. All of its eigenvalues are real non-negative numbers. All of its eigenvectors are orthogonal. There is an orthonormal basis of $\\mathbb{R}^n$ consisting of its eigenvectors. Let the matrix $C$ have an orthonormal basis of eigenvectors $B = [b_1, \\dots, b_n]$, with corresponding eigenvalues $\\lambda_1^2, \\dots, \\lambda_n^2$. Let $D=\\text{diag}(\\lambda_1, \\dots, \\lambda_n)$.\n $$ C = B^\\top D^2 B = \\begin{bmatrix} \\mid \u0026 \\mid \u0026 \u0026 \\mid \\\\ b_1 \u0026 b_2 \u0026 \\dots \u0026 b_n\\\\ \\mid \u0026 \\mid \u0026 \u0026 \\mid \\\\ \\end{bmatrix} \\begin{bmatrix} \\lambda_1^2 \u0026 0 \u0026 \\dots \u0026 0 \\\\ 0 \u0026 \\lambda_2^2 \u0026 \\dots \u0026 0 \\\\ \\vdots \u0026 \\dots \u0026 \\ddots \u0026 \\vdots \\\\ 0 \u0026 \\dots \u0026 0 \u0026 \\lambda_n^2 \\end{bmatrix} \\begin{bmatrix} - \u0026 b_1 \u0026 - \\\\ - \u0026 b_2 \u0026 - \\\\ \u0026 \\dots \u0026 \\\\ - \u0026 b_n \u0026 - \\\\ \\end{bmatrix} $$ The square root of $C$ is:\n $$ C^{\\frac{1}{2}} = B^\\top D B $$ Symbol Meaning $x_i^{(t)} \\in \\mathbb{R}^n$ the $i$-th samples at the generation (t) $y_i^{(t)} \\in \\mathbb{R}^n$ $x_i^{(t)} = \\mu^{(t-1)} + \\sigma^{(t-1)} y_i^{(t)} $ $\\mu^{(t)}$ mean of the generation (t) $\\sigma^{(t)}$ step size $C^{(t)}$ covariance matrix $B^{(t)}$ a matrix of $C$\u0026rsquo;s eigenvectors as row vectors $D^{(t)}$ a diagonal matrix with $C$\u0026rsquo;s eigenvalues on the diagnose. $p_\\sigma^{(t)}$ evaluation path for $\\sigma$ at the generation (t) $p_c^{(t)}$ evaluation path for $C$ at the generation (t) $\\alpha_\\mu$ learning rate for $\\mu$\u0026rsquo;s update $\\alpha_\\sigma$ learning rate for $p_\\sigma$ $d_\\sigma$ damping factor for $\\sigma$\u0026rsquo;s update $\\alpha_{cp}$ learning rate for $p_c$ $\\alpha_{c\\lambda}$ learning rate for $C$\u0026rsquo;s rank-min(λ, n) update $\\alpha_{c1}$ learning rate for $C$\u0026rsquo;s rank-1 update Updating the Mean $$ \\mu^{(t+1)} = \\mu^{(t)} + \\alpha_\\mu \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)} - \\mu^{(t)}) $$ CMA-ES has a learning rate $\\alpha_\\mu \\leq 1$ to control how fast the mean $\\mu$ should be updated. Usually it is set to 1 and thus the equation becomes the same as in vanilla ES, $\\mu^{(t+1)} = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)}$.\nControlling the Step Size The sampling process can be decoupled from the mean and standard deviation:\n $$ x^{(t+1)}_i = \\mu^{(t)} + \\sigma^{(t)} y^{(t+1)}_i \\text{, where } y^{(t+1)}_i = \\frac{x_i^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, C) $$ The parameter $\\sigma$ controls the overall scale of the distribution. It is separated from the covariance matrix so that we can change steps faster than the full covariance. A larger step size leads to faster parameter update. In order to evaluate whether the current step size is proper, CMA-ES constructs an evolution path $p_\\sigma$ by summing up a consecutive sequence of moving steps, $\\frac{1}{\\lambda}\\sum_{i}^\\lambda y_i^{(j)}, j=1, \\dots, t$. By comparing this path length with its expected length under random selection (meaning single steps are uncorrelated), we are able to adjust $\\sigma$ accordingly (See Fig. 2).\nFig. 2. Three scenarios of how single steps are correlated in different ways and their impacts on step size update. (Image source: additional annotations on Fig 5 in CMA-ES tutorial paper) Each time the evolution path is updated with the average of moving step $y_i$ in the same generation.\n $$ \\begin{aligned} \u0026\\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y_i^{(t+1)} = \\frac{1}{\\lambda} \\frac{\\sum_{i=1}^\\lambda x_i^{(t+1)} - \\lambda \\mu^{(t)}}{\\sigma^{(t)}} = \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026\\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y_i^{(t+1)} \\sim \\frac{1}{\\lambda}\\mathcal{N}(0, \\lambda C^{(t)}) \\sim \\frac{1}{\\sqrt{\\lambda}}{C^{(t)}}^{\\frac{1}{2}}\\mathcal{N}(0, I) \\\\ \u0026\\text{Thus } \\sqrt{\\lambda}\\;{C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, I) \\end{aligned} $$ By multiplying with $C^{-\\frac{1}{2}}$, the evolution path is transformed to be independent of its direction. The term ${C^{(t)}}^{-\\frac{1}{2}} = {B^{(t)}}^\\top {D^{(t)}}^{-\\frac{1}{2}} {B^{(t)}}$ transformation works as follows:\n ${B^{(t)}}$ contains row vectors of $C$\u0026rsquo;s eigenvectors. It projects the original space onto the perpendicular principal axes. Then ${D^{(t)}}^{-\\frac{1}{2}} = \\text{diag}(\\frac{1}{\\lambda_1}, \\dots, \\frac{1}{\\lambda_n})$ scales the length of principal axes to be equal. ${B^{(t)}}^\\top$ transforms the space back to the original coordinate system. In order to assign higher weights to recent generations, we use polyak averaging to update the evolution path with learning rate $\\alpha_\\sigma$. Meanwhile, the weights are balanced so that $p_\\sigma$ is conjugate, $\\sim \\mathcal{N}(0, I)$ both before and after one update.\n $$ \\begin{aligned} p_\\sigma^{(t+1)} \u0026 = (1 - \\alpha_\\sigma) p_\\sigma^{(t)} + \\sqrt{1 - (1 - \\alpha_\\sigma)^2}\\;\\sqrt{\\lambda}\\; {C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026 = (1 - \\alpha_\\sigma) p_\\sigma^{(t)} + \\sqrt{c_\\sigma (2 - \\alpha_\\sigma)\\lambda}\\;{C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\end{aligned} $$ The expected length of $p_\\sigma$ under random selection is $\\mathbb{E}|\\mathcal{N}(0,I)|$, that is the expectation of the L2-norm of a $\\mathcal{N}(0,I)$ random variable. Following the idea in Fig. 2, we adjust the step size according to the ratio of $|p_\\sigma^{(t+1)}| / \\mathbb{E}|\\mathcal{N}(0,I)|$:\n $$ \\begin{aligned} \\ln\\sigma^{(t+1)} \u0026= \\ln\\sigma^{(t)} + \\frac{\\alpha_\\sigma}{d_\\sigma} \\Big(\\frac{\\|p_\\sigma^{(t+1)}\\|}{\\mathbb{E}\\|\\mathcal{N}(0,I)\\|} - 1\\Big) \\\\ \\sigma^{(t+1)} \u0026= \\sigma^{(t)} \\exp\\Big(\\frac{\\alpha_\\sigma}{d_\\sigma} \\Big(\\frac{\\|p_\\sigma^{(t+1)}\\|}{\\mathbb{E}\\|\\mathcal{N}(0,I)\\|} - 1\\Big)\\Big) \\end{aligned} $$ where $d_\\sigma \\approx 1$ is a damping parameter, scaling how fast $\\ln\\sigma$ should be changed.\nAdapting the Covariance Matrix For the covariance matrix, it can be estimated from scratch using $y_i$ of elite samples (recall that $y_i \\sim \\mathcal{N}(0, C)$):\n $$ C_\\lambda^{(t+1)} = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top = \\frac{1}{\\lambda {\\sigma^{(t)}}^2} \\sum_{i=1}^\\lambda (x_i^{(t+1)} - \\mu^{(t)})(x_i^{(t+1)} - \\mu^{(t)})^\\top $$ The above estimation is only reliable when the selected population is large enough. However, we do want to run fast iteration with a small population of samples in each generation. That\u0026rsquo;s why CMA-ES invented a more reliable but also more complicated way to update $C$. It involves two independent routes,\n Rank-min(λ, n) update: uses the history of $\\{C_\\lambda\\}$, each estimated from scratch in one generation. Rank-one update: estimates the moving steps $y_i$ and the sign information from the history. The first route considers the estimation of $C$ from the entire history of $\\{C_\\lambda\\}$. For example, if we have experienced a large number of generations, $C^{(t+1)} \\approx \\text{avg}(C_\\lambda^{(i)}; i=1,\\dots,t)$ would be a good estimator. Similar to $p_\\sigma$, we also use polyak averaging with a learning rate to incorporate the history:\n $$ C^{(t+1)} = (1 - \\alpha_{c\\lambda}) C^{(t)} + \\alpha_{c\\lambda} C_\\lambda^{(t+1)} = (1 - \\alpha_{c\\lambda}) C^{(t)} + \\alpha_{c\\lambda} \\frac{1}{\\lambda} \\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top $$ A common choice for the learning rate is $\\alpha_{c\\lambda} \\approx \\min(1, \\lambda/n^2)$.\nThe second route tries to solve the issue that $y_i{y_i}^\\top = (-y_i)(-y_i)^\\top$ loses the sign information. Similar to how we adjust the step size $\\sigma$, an evolution path $p_c$ is used to track the sign information and it is constructed in a way that $p_c$ is conjugate, $\\sim \\mathcal{N}(0, C)$ both before and after a new generation.\nWe may consider $p_c$ as another way to compute $\\text{avg}_i(y_i)$ (notice that both $\\sim \\mathcal{N}(0, C)$) while the entire history is used and the sign information is maintained. Note that we\u0026rsquo;ve known $\\sqrt{k}\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, C)$ in the last section,\n $$ \\begin{aligned} p_c^{(t+1)} \u0026= (1-\\alpha_{cp}) p_c^{(t)} + \\sqrt{1 - (1-\\alpha_{cp})^2}\\;\\sqrt{\\lambda}\\;\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026= (1-\\alpha_{cp}) p_c^{(t)} + \\sqrt{\\alpha_{cp}(2 - \\alpha_{cp})\\lambda}\\;\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\end{aligned} $$ Then the covariance matrix is updated according to $p_c$:\n $$ C^{(t+1)} = (1-\\alpha_{c1}) C^{(t)} + \\alpha_{c1}\\;p_c^{(t+1)} {p_c^{(t+1)}}^\\top $$ The rank-one update approach is claimed to generate a significant improvement over the rank-min(λ, n)-update when $k$ is small, because the signs of moving steps and correlations between consecutive steps are all utilized and passed down through generations.\nEventually we combine two approaches together,\n $$ C^{(t+1)} = (1 - \\alpha_{c\\lambda} - \\alpha_{c1}) C^{(t)} + \\alpha_{c1}\\;\\underbrace{p_c^{(t+1)} {p_c^{(t+1)}}^\\top}_\\textrm{rank-one update} + \\alpha_{c\\lambda} \\underbrace{\\frac{1}{\\lambda} \\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top}_\\textrm{rank-min(lambda, n) update} $$ In all my examples above, each elite sample is considered to contribute an equal amount of weights, $1/\\lambda$. The process can be easily extended to the case where selected samples are assigned with different weights, $w_1, \\dots, w_\\lambda$, according to their performances. See more detail in tutorial.\nFig. 3. Illustration of how CMA-ES works on a 2D optimization problem (the lighter color the better). Black dots are samples in one generation. The samples are more spread out initially but when the model has higher confidence in finding a good solution in the late stage, the samples become very concentrated over the global optimum. (Image source: Wikipedia CMA-ES) Natural Evolution Strategies Natural Evolution Strategies (NES; Wierstra, et al, 2008) optimizes in a search distribution of parameters and moves the distribution in the direction of high fitness indicated by the natural gradient.\nNatural Gradients Given an objective function $\\mathcal{J}(\\theta)$ parameterized by $\\theta$, let\u0026rsquo;s say our goal is to find the optimal $\\theta$ to maximize the objective function value. A plain gradient finds the steepest direction within a small Euclidean distance from the current $\\theta$; the distance restriction is applied on the parameter space. In other words, we compute the plain gradient with respect to a small change of the absolute value of $\\theta$. The optimal step is:\n $$ d^{*} = \\operatorname*{argmax}_{\\|d\\| = \\epsilon} \\mathcal{J}(\\theta + d)\\text{, where }\\epsilon \\to 0 $$ Differently, natural gradient works with a probability distribution space parameterized by $\\theta$, $p_\\theta(x)$ (referred to as \u0026ldquo;search distribution\u0026rdquo; in NES paper). It looks for the steepest direction within a small step in the distribution space where the distance is measured by KL divergence. With this constraint we ensure that each update is moving along the distributional manifold with constant speed, without being slowed down by its curvature.\n $$ d^{*}_\\text{N} = \\operatorname*{argmax}_{\\text{KL}[p_\\theta \\| p_{\\theta+d}] = \\epsilon} \\mathcal{J}(\\theta + d) $$ Estimation using Fisher Information Matrix But, how to compute $\\text{KL}[p_\\theta | p_{\\theta+\\Delta\\theta}]$ precisely? By running Taylor expansion of $\\log p_{\\theta + d}$ at $\\theta$, we get:\n $$ \\begin{aligned} \u0026 \\text{KL}[p_\\theta \\| p_{\\theta+d}] \\\\ \u0026= \\mathbb{E}_{x \\sim p_\\theta} [\\log p_\\theta(x) - \\log p_{\\theta+d}(x)] \u0026 \\\\ \u0026\\approx \\mathbb{E}_{x \\sim p_\\theta} [ \\log p_\\theta(x) -( \\log p_{\\theta}(x) + \\nabla_\\theta \\log p_{\\theta}(x) d + \\frac{1}{2}d^\\top \\nabla^2_\\theta \\log p_{\\theta}(x) d)] \u0026 \\scriptstyle{\\text{; Taylor expand }\\log p_{\\theta+d}} \\\\ \u0026\\approx - \\mathbb{E}_x [\\nabla_\\theta \\log p_{\\theta}(x)] d - \\frac{1}{2}d^\\top \\mathbb{E}_x [\\nabla^2_\\theta \\log p_{\\theta}(x)] d \u0026 \\end{aligned} $$ where\n $$ \\begin{aligned} \\mathbb{E}_x [\\nabla_\\theta \\log p_{\\theta}] d \u0026= \\int_{x\\sim p_\\theta} p_\\theta(x) \\nabla_\\theta \\log p_\\theta(x) \u0026 \\\\ \u0026= \\int_{x\\sim p_\\theta} p_\\theta(x) \\frac{1}{p_\\theta(x)} \\nabla_\\theta p_\\theta(x) \u0026 \\\\ \u0026= \\nabla_\\theta \\Big( \\int_{x} p_\\theta(x) \\Big) \u0026 \\scriptstyle{\\textrm{; note that }p_\\theta(x)\\textrm{ is probability distribution.}} \\\\ \u0026= \\nabla_\\theta (1) = 0 \\end{aligned} $$ Finally we have,\n $$ \\text{KL}[p_\\theta \\| p_{\\theta+d}] = - \\frac{1}{2}d^\\top \\mathbf{F}_\\theta d \\text{, where }\\mathbf{F}_\\theta = \\mathbb{E}_x [(\\nabla_\\theta \\log p_{\\theta}) (\\nabla_\\theta \\log p_{\\theta})^\\top] $$ where $\\mathbf{F}_\\theta$ is called the Fisher Information Matrix and it is the covariance matrix of $\\nabla_\\theta \\log p_\\theta$ since $\\mathbb{E}[\\nabla_\\theta \\log p_\\theta] = 0$.\nThe solution to the following optimization problem:\n $$ \\max \\mathcal{J}(\\theta + d) \\approx \\max \\big( \\mathcal{J}(\\theta) + {\\nabla_\\theta\\mathcal{J}(\\theta)}^\\top d \\big)\\;\\text{ s.t. }\\text{KL}[p_\\theta \\| p_{\\theta+d}] - \\epsilon = 0 $$ can be found using a Lagrangian multiplier,\n $$ \\begin{aligned} \\mathcal{L}(\\theta, d, \\beta) \u0026= \\mathcal{J}(\\theta) + \\nabla_\\theta\\mathcal{J}(\\theta)^\\top d - \\beta (\\frac{1}{2}d^\\top \\mathbf{F}_\\theta d + \\epsilon) = 0 \\text{ s.t. } \\beta 0 \\\\ \\nabla_d \\mathcal{L}(\\theta, d, \\beta) \u0026= \\nabla_\\theta\\mathcal{J}(\\theta) - \\beta\\mathbf{F}_\\theta d = 0 \\\\ \\text{Thus } d_\\text{N}^* \u0026= \\nabla_\\theta^\\text{N} \\mathcal{J}(\\theta) = \\mathbf{F}_\\theta^{-1} \\nabla_\\theta\\mathcal{J}(\\theta) \\end{aligned} $$ where $d_\\text{N}^*$ only extracts the direction of the optimal moving step on $\\theta$, ignoring the scalar $\\beta^{-1}$.\nFig. 4. The natural gradient samples (black solid arrows) in the right are the plain gradient samples (black solid arrows) in the left multiplied by the inverse of their covariance. In this way, a gradient direction with high uncertainty (indicated by high covariance with other samples) are penalized with a small weight. The aggregated natural gradient (red dash arrow) is therefore more trustworthy than the natural gradient (green solid arrow). (Image source: additional annotations on Fig 2 in NES paper) NES Algorithm The fitness associated with one sample is labeled as $f(x)$ and the search distribution over $x$ is parameterized by $\\theta$. NES is expected to optimize the parameter $\\theta$ to achieve maximum expected fitness:\n $$ \\mathcal{J}(\\theta) = \\mathbb{E}_{x\\sim p_\\theta(x)} [f(x)] = \\int_x f(x) p_\\theta(x) dx $$ Using the same log-likelihood trick in REINFORCE:\n $$ \\begin{aligned} \\nabla_\\theta\\mathcal{J}(\\theta) \u0026= \\nabla_\\theta \\int_x f(x) p_\\theta(x) dx \\\\ \u0026= \\int_x f(x) \\frac{p_\\theta(x)}{p_\\theta(x)}\\nabla_\\theta p_\\theta(x) dx \\\\ \u0026 = \\int_x f(x) p_\\theta(x) \\nabla_\\theta \\log p_\\theta(x) dx \\\\ \u0026 = \\mathbb{E}_{x \\sim p_\\theta} [f(x) \\nabla_\\theta \\log p_\\theta(x)] \\end{aligned} $$ Besides natural gradients, NES adopts a couple of important heuristics to make the algorithm performance more robust.\n NES applies rank-based fitness shaping, that is to use the rank under monotonically increasing fitness values instead of using $f(x)$ directly. Or it can be a function of the rank (“utility function”), which is considered as a free parameter of NES. NES adopts adaptation sampling to adjust hyperparameters at run time. When changing $\\theta \\to \\theta’$, samples drawn from $p_\\theta$ are compared with samples from $p_{\\theta’}$ using [Mann-Whitney U-test(https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test)]; if there shows a positive or negative sign, the target hyperparameter decreases or increases by a multiplication constant. Note the score of a sample $x’_i \\sim p_{\\theta’}(x)$ has importance sampling weights applied $w_i’ = p_\\theta(x) / p_{\\theta’}(x)$. Applications: ES in Deep Reinforcement Learning OpenAI ES for RL The concept of using evolutionary algorithms in reinforcement learning can be traced back long ago, but only constrained to tabular RL due to computational limitations.\nInspired by NES, researchers at OpenAI (Salimans, et al. 2017) proposed to use NES as a gradient-free black-box optimizer to find optimal policy parameters $\\theta$ that maximizes the return function $F(\\theta)$. The key is to add Gaussian noise $\\epsilon$ on the model parameter $\\theta$ and then use the log-likelihood trick to write it as the gradient of the Gaussian pdf. Eventually only the noise term is left as a weighting scalar for measured performance.\nLet’s say the current parameter value is $\\hat{\\theta}$ (the added hat is to distinguish the value from the random variable $\\theta$). The search distribution of $\\theta$ is designed to be an isotropic multivariate Gaussian with a mean $\\hat{\\theta}$ and a fixed covariance matrix $\\sigma^2 I$,\n $$ \\theta \\sim \\mathcal{N}(\\hat{\\theta}, \\sigma^2 I) \\text{ equivalent to } \\theta = \\hat{\\theta} + \\sigma\\epsilon, \\epsilon \\sim \\mathcal{N}(0, I) $$ The gradient for $\\theta$ update is:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta \\mathbb{E}_{\\theta\\sim\\mathcal{N}(\\hat{\\theta}, \\sigma^2 I)} F(\\theta) \\\\ \u0026= \\nabla_\\theta \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} F(\\hat{\\theta} + \\sigma\\epsilon) \\\\ \u0026= \\nabla_\\theta \\int_{\\epsilon} p(\\epsilon) F(\\hat{\\theta} + \\sigma\\epsilon) d\\epsilon \u0026 \\scriptstyle{\\text{; Gaussian }p(\\epsilon)=(2\\pi)^{-\\frac{n}{2}} \\exp(-\\frac{1}{2}\\epsilon^\\top\\epsilon)} \\\\ \u0026= \\int_{\\epsilon} p(\\epsilon) \\nabla_\\epsilon \\log p(\\epsilon) \\nabla_\\theta \\epsilon\\;F(\\hat{\\theta} + \\sigma\\epsilon) d\\epsilon \u0026 \\scriptstyle{\\text{; log-likelihood trick}}\\\\ \u0026= \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\nabla_\\epsilon \\big(-\\frac{1}{2}\\epsilon^\\top\\epsilon\\big) \\nabla_\\theta \\big(\\frac{\\theta - \\hat{\\theta}}{\\sigma}\\big) F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\\\ \u0026= \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ (-\\epsilon) (\\frac{1}{\\sigma}) F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\\\ \u0026= \\frac{1}{\\sigma}\\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\epsilon F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\scriptstyle{\\text{; negative sign can be absorbed.}} \\end{aligned} $$ In one generation, we can sample many $epsilon_i, i=1,\\dots,n$ and evaluate the fitness in parallel. One beautiful design is that no large model parameter needs to be shared. By only communicating the random seeds between workers, it is enough for the master node to do parameter update. This approach is later extended to adaptively learn a loss function; see my previous post on Evolved Policy Gradient.\nFig. 5. The algorithm for training a RL policy using evolution strategies. (Image source: ES-for-RL paper) To make the performance more robust, OpenAI ES adopts virtual batch normalization (BN with mini-batch used for calculating statistics fixed), mirror sampling (sampling a pair of $(-\\epsilon, \\epsilon)$ for evaluation), and fitness shaping.\nExploration with ES Exploration (vs exploitation) is an important topic in RL. The optimization direction in the ES algorithm above is only extracted from the cumulative return $F(\\theta)$. Without explicit exploration, the agent might get trapped in a local optimum.\nNovelty-Search ES (NS-ES; Conti et al, 2018) encourages exploration by updating the parameter in the direction to maximize the novelty score. The novelty score depends on a domain-specific behavior characterization function $b(\\pi_\\theta)$. The choice of $b(\\pi_\\theta)$ is specific to the task and seems to be a bit arbitrary; for example, in the Humanoid locomotion task in the paper, $b(\\pi_\\theta)$ is the final $(x,y)$ location of the agent.\n Every policy\u0026rsquo;s $b(\\pi_\\theta)$ is pushed to an archive set $\\mathcal{A}$. Novelty of a policy $\\pi_\\theta$ is measured as the k-nearest neighbor score between $b(\\pi_\\theta)$ and all other entries in $\\mathcal{A}$. (The use case of the archive set sounds quite similar to episodic memory.) $$ N(\\theta, \\mathcal{A}) = \\frac{1}{\\lambda} \\sum_{i=1}^\\lambda \\| b(\\pi_\\theta), b^\\text{knn}_i \\|_2 \\text{, where }b^\\text{knn}_i \\in \\text{kNN}(b(\\pi_\\theta), \\mathcal{A}) $$ The ES optimization step relies on the novelty score instead of fitness:\n $$ \\nabla_\\theta \\mathbb{E}_{\\theta\\sim\\mathcal{N}(\\hat{\\theta}, \\sigma^2 I)} N(\\theta, \\mathcal{A}) = \\frac{1}{\\sigma}\\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\epsilon N(\\hat{\\theta} + \\sigma\\epsilon, \\mathcal{A}) ] $$ NS-ES maintains a group of $M$ independently trained agents (\u0026ldquo;meta-population\u0026rdquo;), $\\mathcal{M} = \\{\\theta_1, \\dots, \\theta_M \\}$ and picks one to advance proportional to the novelty score. Eventually we select the best policy. This process is equivalent to ensembling; also see the same idea in SVPG.\n $$ \\begin{aligned} m \u0026\\leftarrow \\text{pick } i=1,\\dots,M\\text{ according to probability}\\frac{N(\\theta_i, \\mathcal{A})}{\\sum_{j=1}^M N(\\theta_j, \\mathcal{A})} \\\\ \\theta_m^{(t+1)} \u0026\\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) \\text{ where }\\epsilon_i \\sim \\mathcal{N}(0, I) \\end{aligned} $$ where $N$ is the number of Gaussian perturbation noise vectors and $\\alpha$ is the learning rate.\nNS-ES completely discards the reward function and only optimizes for novelty to avoid deceptive local optima. To incorporate the fitness back into the formula, another two variations are proposed.\nNSR-ES:\n $$ \\theta_m^{(t+1)} \\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i \\frac{N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) + F(\\theta^{(t)}_m + \\epsilon_i)}{2} $$ NSRAdapt-ES (NSRA-ES): the adaptive weighting parameter $w = 1.0$ initially. We start decreasing $w$ if performance stays flat for a number of generations. Then when the performance starts to increase, we stop decreasing $w$ but increase it instead. In this way, fitness is preferred when the performance stops growing but novelty is preferred otherwise.\n $$ \\theta_m^{(t+1)} \\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i \\big((1-w) N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) + w F(\\theta^{(t)}_m + \\epsilon_i)\\big) $$ Fig. 6. (Left) The environment is Humanoid locomotion with a three-sided wall which plays a role as a deceptive trap to create local optimum. (Right) Experiments compare ES baseline and other variations that encourage exploration. (Image source: NS-ES paper) CEM-RL Fig. 7. Architectures of the (a) CEM-RL and (b) ERL algorithms (Image source: CEM-RL paper) The CEM-RL method (Pourchot \u0026amp; Sigaud, 2019) combines Cross Entropy Method (CEM) with either DDPG or TD3. CEM here works pretty much the same as the simple Gaussian ES described above and therefore the same function can be replaced using CMA-ES. CEM-RL is built on the framework of Evolutionary Reinforcement Learning (ERL; Khadka \u0026amp; Tumer, 2018) in which the standard EA algorithm selects and evolves a population of actors and the rollout experience generated in the process is then added into reply buffer for training both RL-actor and RL-critic networks.\nWorkflow:\n The mean actor of the CEM population is $\\pi_\\mu$ is initialized with a random actor network. The critic network $Q$ is initialized too, which will be updated by DDPG/TD3. Repeat until happy: a. Sample a population of actors $\\sim \\mathcal{N}(\\pi_\\mu, \\Sigma)$. b. Half of the population is evaluated. Their fitness scores are used as the cumulative reward $R$ and added into replay buffer. c. The other half are updated together with the critic. d. The new $\\pi_mu$ and $\\Sigma$ is computed using top performing elite samples. CMA-ES can be used for parameter update too. Extension: EA in Deep Learning (This section is not on evolution strategies, but still an interesting and relevant reading.)\nThe Evolutionary Algorithms have been applied on many deep learning problems. POET (Wang et al, 2019) is a framework based on EA and attempts to generate a variety of different tasks while the problems themselves are being solved. POET has been introduced in my last post on meta-RL. Evolutionary Reinforcement Learning (ERL) is another example; See Fig. 7 (b).\nBelow I would like to introduce two applications in more detail, Population-Based Training (PBT) and Weight-Agnostic Neural Networks (WANN).\nHyperparameter Tuning: PBT Fig. 8. Paradigms of comparing different ways of hyperparameter tuning. (Image source: PBT paper) Population-Based Training (Jaderberg, et al, 2017), short for PBT applies EA on the problem of hyperparameter tuning. It jointly trains a population of models and corresponding hyperparameters for optimal performance.\nPBT starts with a set of random candidates, each containing a pair of model weights initialization and hyperparameters, $\\{(\\theta_i, h_i)\\mid i=1, \\dots, N\\}$. Every sample is trained in parallel and asynchronously evaluates its own performance periodically. Whenever a member deems ready (i.e. after taking enough gradient update steps, or when the performance is good enough), it has a chance to be updated by comparing with the whole population:\n exploit(): When this model is under-performing, the weights could be replaced with a better performing model. explore(): If the model weights are overwritten, explore step perturbs the hyperparameters with random noise. In this process, only promising model and hyperparameter pairs can survive and keep on evolving, achieving better utilization of computational resources.\nFig. 9. The algorithm of population-based training. (Image source: PBT paper) Network Topology Optimization: WANN Weight Agnostic Neural Networks (short for WANN; Gaier \u0026amp; Ha 2019) experiments with searching for the smallest network topologies that can achieve the optimal performance without training the network weights. By not considering the best configuration of network weights, WANN puts much more emphasis on the architecture itself, making the focus different from NAS. WANN is heavily inspired by a classic genetic algorithm to evolve network topologies, called NEAT (\u0026ldquo;Neuroevolution of Augmenting Topologies\u0026rdquo;; Stanley \u0026amp; Miikkulainen 2002).\nThe workflow of WANN looks pretty much the same as standard GA:\n Initialize: Create a population of minimal networks. Evaluation: Test with a range of shared weight values. Rank and Selection: Rank by performance and complexity. Mutation: Create new population by varying best networks. Fig. 10. mutation operations for searching for new network topologies in WANN (Image source: WANN paper) At the \u0026ldquo;evaluation\u0026rdquo; stage, all the network weights are set to be the same. In this way, WANN is actually searching for network that can be described with a minimal description length. In the \u0026ldquo;selection\u0026rdquo; stage, both the network connection and the model performance are considered.\nFig. 11. Performance of WANN found network topologies on different RL tasks are compared with baseline FF networks commonly used in the literature. \"Tuned Shared Weight\" only requires adjusting one weight value. (Image source: WANN paper) As shown in Fig. 11, WANN results are evaluated with both random weights and shared weights (single weight). It is interesting that even when enforcing weight-sharing on all weights and tuning this single parameter, WANN can discover topologies that achieve non-trivial good performance.\n Cited as:\n@article{weng2019ES, title = \u0026quot;Evolution Strategies\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-09-05-evolution-strategies/\u0026quot; } References [1] Nikolaus Hansen. \u0026ldquo;The CMA Evolution Strategy: A Tutorial\u0026rdquo; arXiv preprint arXiv:1604.00772 (2016).\n[2] Marc Toussaint. Slides: \u0026ldquo;Introduction to Optimization\u0026rdquo;\n[3] David Ha. \u0026ldquo;A Visual Guide to Evolution Strategies\u0026rdquo; blog.otoro.net. Oct 2017.\n[4] Daan Wierstra, et al. \u0026ldquo;Natural evolution strategies.\u0026quot; IEEE World Congress on Computational Intelligence, 2008.\n[5] Agustinus Kristiadi. \u0026ldquo;Natural Gradient Descent\u0026rdquo; Mar 2018.\n[6] Razvan Pascanu \u0026amp; Yoshua Bengio. \u0026ldquo;Revisiting Natural Gradient for Deep Networks.\u0026quot; arXiv preprint arXiv:1301.3584 (2013).\n[7] Tim Salimans, et al. \u0026ldquo;Evolution strategies as a scalable alternative to reinforcement learning.\u0026quot; arXiv preprint arXiv:1703.03864 (2017).\n[8] Edoardo Conti, et al. \u0026ldquo;Improving exploration in evolution strategies for deep reinforcement learning via a population of novelty-seeking agents.\u0026quot; NIPS. 2018.\n[9] Aloïs Pourchot \u0026amp; Olivier Sigaud. \u0026ldquo;CEM-RL: Combining evolutionary and gradient-based methods for policy search.\u0026quot; ICLR 2019.\n[10] Shauharda Khadka \u0026amp; Kagan Tumer. \u0026ldquo;Evolution-guided policy gradient in reinforcement learning.\u0026quot; NIPS 2018.\n[11] Max Jaderberg, et al. \u0026ldquo;Population based training of neural networks.\u0026quot; arXiv preprint arXiv:1711.09846 (2017).\n[12] Adam Gaier \u0026amp; David Ha. \u0026ldquo;Weight Agnostic Neural Networks.\u0026quot; arXiv preprint arXiv:1906.04358 (2019).\n","permalink":"https://lilianweng.github.io/posts/2019-09-05-evolution-strategies/","summary":"Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. With black-box optimization algorithms, you can evaluate a target function $f(x): \\mathbb{R}^n \\to \\mathbb{R}$, even when you don\u0026rsquo;t know the precise analytic form of $f(x)$ and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include Simulated Annealing, Hill Climbing and Nelder-Mead method.\nEvolution Strategies (ES) is one type of black-box optimization algorithms, born in the family of Evolutionary Algorithms (EA).","title":"Evolution Strategies"},{"content":"In my earlier post on meta-learning, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to \u0026ldquo;meta-learn\u0026rdquo; Reinforcement Learning (RL) tasks by developing an agent that can solve unseen tasks fast and efficiently.\nTo recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a mini learning session, happens at test with limited exposure to the new configurations. Even without any explicit fine-tuning (no gradient backpropagation on trainable variables), the meta-learning model autonomously adjusts internal hidden states to learn.\nTraining RL algorithms can be notoriously difficult sometimes. If the meta-learning agent could become so smart that the distribution of solvable unseen tasks grows extremely broad, we are on track towards general purpose methods \u0026mdash; essentially building a \u0026ldquo;brain\u0026rdquo; which would solve all kinds of RL problems without much human interference or manual feature engineering. Sounds amazing, right? 💖\nOn the Origin of Meta-RL Back in 2001 I encountered a paper written in 2001 by Hochreiter et al. when reading Wang et al., 2016. Although the idea was proposed for supervised learning, there are so many resemblances to the current approach to meta-RL.\nFig. 1. The meta-learning system consists of the supervisory and the subordinate systems. The subordinate system is a recurrent neural network that takes as input both the observation at the current time step, $x\\_t$ and the label at the last time step, $y\\_{t-1}$. (Image source: Hochreiter et al., 2001) Hochreiter\u0026rsquo;s meta-learning model is a recurrent network with LSTM cell. LSTM is a good choice because it can internalize a history of inputs and tune its own weights effectively through BPTT. The training data contains $K$ sequences and each sequence is consist of $N$ samples generated by a target function $f_k(.), k=1, \\dots, K$,\n $$ \\{\\text{input: }(\\mathbf{x}^k_i, \\mathbf{y}^k_{i-1}) \\to \\text{label: }\\mathbf{y}^k_i\\}_{i=1}^N \\text{ where }\\mathbf{y}^k_i = f_k(\\mathbf{x}^k_i) $$ Noted that the last label $\\mathbf{y}^k_{i-1}$ is also provided as an auxiliary input so that the function can learn the presented mapping.\nIn the experiment of decoding two-dimensional quadratic functions, $a x_1^2 + b x_2^2 + c x_1 x_2 + d x_1 + e x_2 + f$, with coefficients $a$-$f$ are randomly sampled from [-1, 1], this meta-learning system was able to approximate the function after seeing only ~35 examples.\nProposal in 2016 In the modern days of DL, Wang et al. (2016) and Duan et al. (2017) simultaneously proposed the very similar idea of Meta-RL (it is called RL^2 in the second paper). A meta-RL model is trained over a distribution of MDPs, and at test time, it is able to learn to solve a new task quickly. The goal of meta-RL is ambitious, taking one step further towards general algorithms.\nDefine Meta-RL Meta Reinforcement Learning, in short, is to do meta-learning in the field of reinforcement learning. Usually the train and test tasks are different but drawn from the same family of problems; i.e., experiments in the papers included multi-armed bandit with different reward probabilities, mazes with different layouts, same robots but with different physical parameters in simulator, and many others.\nFormulation Let\u0026rsquo;s say we have a distribution of tasks, each formularized as an MDP (Markov Decision Process), $M_i \\in \\mathcal{M}$. An MDP is determined by a 4-tuple, $M_i= \\langle \\mathcal{S}, \\mathcal{A}, P_i, R_i \\rangle$:\n Symbol Meaning $\\mathcal{S}$ A set of states. $\\mathcal{A}$ A set of actions. $P_i: \\mathcal{S} \\times \\mathcal{A} \\times \\mathcal{S} \\to \\mathbb{R}_{+}$ Transition probability function. $R_i: \\mathcal{S} \\times \\mathcal{A} \\to \\mathbb{R}$ Reward function. (RL^2 paper adds an extra parameter, horizon $T$, into the MDP tuple to emphasize that each MDP should have a finite horizon.)\nNote that common state $\\mathcal{S}$ and action space $\\mathcal{A}$ are used above, so that a (stochastic) policy: $\\pi_\\theta: \\mathcal{S} \\times \\mathcal{A} \\to \\mathbb{R}_{+}$ would get inputs compatible across different tasks. The test tasks are sampled from the same distribution $\\mathcal{M}$ or slightly modified version.\nFig. 2. Illustration of meta-RL, containing two optimization loops. The outer loop samples a new environment in every iteration and adjusts parameters that determine the agent's behavior. In the inner loop, the agent interacts with the environment and optimizes for the maximal reward. (Image source: Botvinick, et al. 2019) Main Differences from RL The overall configure of meta-RL is very similar to an ordinary RL algorithm, except that the last reward $r_{t-1}$ and the last action $a_{t-1}$ are also incorporated into the policy observation in addition to the current state $s_t$.\n In RL: $\\pi_\\theta(s_t) \\to$ a distribution over $\\mathcal{A}$ In meta-RL: $\\pi_\\theta(a_{t-1}, r_{t-1}, s_t) \\to$ a distribution over $\\mathcal{A}$ The intention of this design is to feed a history into the model so that the policy can internalize the dynamics between states, rewards, and actions in the current MDP and adjust its strategy accordingly. This is well aligned with the setup in Hochreiter\u0026rsquo;s system. Both meta-RL and RL^2 implemented an LSTM policy and the LSTM\u0026rsquo;s hidden states serve as a memory for tracking characteristics of the trajectories. Because the policy is recurrent, there is no need to feed the last state as inputs explicitly.\nThe training procedure works as follows:\n Sample a new MDP, $M_i \\sim \\mathcal{M}$; Reset the hidden state of the model; Collect multiple trajectories and update the model weights; Repeat from step 1. Fig. 3. In the meta-RL paper, different actor-critic architectures all use a recurrent model. Last reward and last action are additional inputs. The observation is fed into the LSTM either as a one-hot vector or as an embedding vector after passed through an encoder model. (Image source: Wang et al., 2016) Fig. 4. As described in the RL^2 paper, illustration of the procedure of the model interacting with a series of MDPs in training time . (Image source: Duan et al., 2017) Key Components There are three key components in Meta-RL:\n ⭐ A Model with Memory A recurrent neural network maintains a hidden state. Thus, it could acquire and memorize the knowledge about the current task by updating the hidden state during rollouts. Without memory, meta-RL would not work.\n ⭐ Meta-learning Algorithm A meta-learning algorithm refers to how we can update the model weights to optimize for the purpose of solving an unseen task fast at test time. In both Meta-RL and RL^2 papers, the meta-learning algorithm is the ordinary gradient descent update of LSTM with hidden state reset between a switch of MDPs.\n ⭐ A Distribution of MDPs While the agent is exposed to a variety of environments and tasks during training, it has to learn how to adapt to different MDPs.\n According to Botvinick et al. (2019), one source of slowness in RL training is weak inductive bias ( = \u0026ldquo;a set of assumptions that the learner uses to predict outputs given inputs that it has not encountered\u0026rdquo;). As a general ML rule, a learning algorithm with weak inductive bias will be able to master a wider range of variance, but usually, will be less sample-efficient. Therefore, to narrow down the hypotheses with stronger inductive biases help improve the learning speed.\nIn meta-RL, we impose certain types of inductive biases from the task distribution and store them in memory. Which inductive bias to adopt at test time depends on the algorithm. Together, these three key components depict a compelling view of meta-RL: Adjusting the weights of a recurrent network is slow but it allows the model to work out a new task fast with its own RL algorithm implemented in its internal activity dynamics.\nMeta-RL interestingly and not very surprisingly matches the ideas in the AI-GAs (\u0026ldquo;AI-Generating Algorithms\u0026rdquo;) paper by Jeff Clune (2019). He proposed that one efficient way towards building general AI is to make learning as automatic as possible. The AI-GAs approach involves three pillars: (1) meta-learning architectures, (2) meta-learning algorithms, and (3) automatically generated environments for effective learning.\n The topic of designing good recurrent network architectures is a bit too broad to be discussed here, so I will skip it. Next, let\u0026rsquo;s look further into another two components: meta-learning algorithms in the context of meta-RL and how to acquire a variety of training MDPs.\nMeta-Learning Algorithms for Meta-RL My previous post on meta-learning has covered several classic meta-learning algorithms. Here I\u0026rsquo;m gonna include more related to RL.\nOptimizing Model Weights for Meta-learning Both MAML (Finn, et al. 2017) and Reptile (Nichol et al., 2018) are methods on updating model parameters in order to achieve good generalization performance on new tasks. See an earlier post section on MAML and Reptile.\nMeta-learning Hyperparameters The return function in an RL problem, $G_t^{(n)}$ or $G_t^\\lambda$, involves a few hyperparameters that are often set heuristically, like the discount factor $\\gamma$ and the bootstrapping parameter $\\lambda$. Meta-gradient RL (Xu et al., 2018) considers them as meta-parameters, $\\eta=\\{\\gamma, \\lambda \\}$, that can be tuned and learned online while an agent is interacting with the environment. Therefore, the return becomes a function of $\\eta$ and dynamically adapts itself to a specific task over time.\n $$ \\begin{aligned} G_\\eta^{(n)}(\\tau_t) \u0026= R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{n-1}R_{t+n} + \\gamma^n v_\\theta(s_{t+n}) \u0026 \\scriptstyle{\\text{; n-step return}} \\\\ G_\\eta^{\\lambda}(\\tau_t) \u0026= (1-\\lambda) \\sum_{n=1}^\\infty \\lambda^{n-1} G_\\eta^{(n)} \u0026 \\scriptstyle{\\text{; λ-return, mixture of n-step returns}} \\end{aligned} $$ During training, we would like to update the policy parameters with gradients as a function of all the information in hand, $\\theta' = \\theta + f(\\tau, \\theta, \\eta)$, where $\\theta$ are the current model weights, $\\tau$ is a sequence of trajectories, and $\\eta$ are the meta-parameters.\nMeanwhile, let\u0026rsquo;s say we have a meta-objective function $J(\\tau, \\theta, \\eta)$ as a performance measure. The training process follows the principle of online cross-validation, using a sequence of consecutive experiences:\n Starting with parameter $\\theta$, the policy $\\pi_\\theta$ is updated on the first batch of samples $\\tau$, resulting in $\\theta'$. Then we continue running the policy $\\pi_{\\theta'}$ to collect a new set of experiences $\\tau'$, just following $\\tau$ consecutively in time. The performance is measured as $J(\\tau', \\theta', \\bar{\\eta})$ with a fixed meta-parameter $\\bar{\\eta}$. The gradient of meta-objective $J(\\tau', \\theta', \\bar{\\eta})$ w.r.t. $\\eta$ is used to update $\\eta$: $$ \\begin{aligned} \\Delta \\eta \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\eta} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{d\\theta'}{d\\eta} \u0026 \\scriptstyle{\\text{ ; single variable chain rule.}} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{\\partial (\\theta + f(\\tau, \\theta, \\eta))}{\\partial\\eta} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\Big(\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\theta}\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta}\\frac{d\\eta}{d\\eta} \\Big) \u0026 \\scriptstyle{\\text{; multivariable chain rule.}}\\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\Big( \\color{red}{\\big(\\mathbf{I} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\theta}\\big)}\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta}\\Big) \u0026 \\scriptstyle{\\text{; secondary gradient term in red.}} \\end{aligned} $$ where $\\beta$ is the learning rate for $\\eta$.\nThe meta-gradient RL algorithm simplifies the computation by setting the secondary gradient term to zero, $\\mathbf{I} + \\partial g(\\tau, \\theta, \\eta)/\\partial\\theta = 0$ \u0026mdash; this choice prefers the immediate effect of the meta-parameters $\\eta$ on the parameters $\\theta$. Eventually we get:\n $$ \\Delta \\eta = -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta} $$ Experiments in the paper adopted the meta-objective function same as $TD(\\lambda)$ algorithm, minimizing the error between the approximated value function $v_\\theta(s)$ and the $\\lambda$-return:\n $$ \\begin{aligned} J(\\tau, \\theta, \\eta) \u0026= (G^\\lambda_\\eta(\\tau) - v_\\theta(s))^2 \\\\ J(\\tau', \\theta', \\bar{\\eta}) \u0026= (G^\\lambda_{\\bar{\\eta}}(\\tau') - v_{\\theta'}(s'))^2 \\end{aligned} $$ Meta-learning the Loss Function In policy gradient algorithms, the expected total reward is maximized by updating the policy parameters $\\theta$ in the direction of estimated gradient (Schulman et al., 2016),\n $$ g = \\mathbb{E}[\\sum_{t=0}^\\infty \\Psi_t \\nabla_\\theta \\log \\pi_\\theta (a_t \\mid s_t)] $$ where the candidates for $\\Psi_t$ include the trajectory return $G_t$, the Q value $Q(s_t, a_t)$, or the advantage value $A(s_t, a_t)$. The corresponding surrogate loss function for the policy gradient can be reverse-engineered:\n $$ L_\\text{pg} = \\mathbb{E}[\\sum_{t=0}^\\infty \\Psi_t \\log \\pi_\\theta (a_t \\mid s_t)] $$ This loss function is a measure over a history of trajectories, $(s_0, a_0, r_0, \\dots, s_t, a_t, r_t, \\dots)$. Evolved Policy Gradient (EPG; Houthooft, et al, 2018) takes a step further by defining the policy gradient loss function as a temporal convolution (1-D convolution) over the agent\u0026rsquo;s past experience, $L_\\phi$. The parameters $\\phi$ of the loss function network are evolved in a way that an agent can achieve higher returns.\nSimilar to many meta-learning algorithms, EPG has two optimization loops:\n In the internal loop, an agent learns to improve its policy $\\pi_\\theta$. In the outer loop, the model updates the parameters $\\phi$ of the loss function $L_\\phi$. Because there is no explicit way to write down a differentiable equation between the return and the loss, EPG turned to Evolutionary Strategies (ES). A general idea is to train a population of $N$ agents, each of them is trained with the loss function $L_{\\phi + \\sigma \\epsilon_i}$ parameterized with $\\phi$ added with a small Gaussian noise $\\epsilon_i \\sim \\mathcal{N}(0, \\mathbf{I})$ of standard deviation $\\sigma$. During the inner loop\u0026rsquo;s training, EPG tracks a history of experience and updates the policy parameters according to the loss function $L_{\\phi + \\sigma\\epsilon_i}$ for each agent:\n $$ \\theta_i \\leftarrow \\theta - \\alpha_\\text{in} \\nabla_\\theta L_{\\phi + \\sigma \\epsilon_i} (\\pi_\\theta, \\tau_{t-K, \\dots, t}) $$ where $\\alpha_\\text{in}$ is the learning rate of the inner loop and $\\tau_{t-K, \\dots, t}$ is a sequence of $M$ transitions up to the current time step $t$.\nOnce the inner loop policy is mature enough, the policy is evaluated by the mean return $\\bar{G}_{\\phi+\\sigma\\epsilon_i}$ over multiple randomly sampled trajectories. Eventually, we are able to estimate the gradient of $\\phi$ according to NES numerically (Salimans et al, 2017). While repeating this process, both the policy parameters $\\theta$ and the loss function weights $\\phi$ are being updated simultaneously to achieve higher returns.\n $$ \\phi \\leftarrow \\phi + \\alpha_\\text{out} \\frac{1}{\\sigma N} \\sum_{i=1}^N \\epsilon_i G_{\\phi+\\sigma\\epsilon_i} $$ where $\\alpha_\\text{out}$ is the learning rate of the outer loop.\nIn practice, the loss $L_\\phi$ is bootstrapped with an ordinary policy gradient (such as REINFORCE or PPO) surrogate loss $L_\\text{pg}$, $\\hat{L} = (1-\\alpha) L_\\phi + \\alpha L_\\text{pg}$. The weight $\\alpha$ is annealing from 1 to 0 gradually during training. At test time, the loss function parameter $\\phi$ stays fixed and the loss value is computed over a history of experience to update the policy parameters $\\theta$.\nMeta-learning the Exploration Strategies The exploitation vs exploration dilemma is a critical problem in RL. Common ways to do exploration include $\\epsilon$-greedy, random noise on actions, or stochastic policy with built-in randomness on the action space.\nMAESN (Gupta et al, 2018) is an algorithm to learn structured action noise from prior experience for better and more effective exploration. Simply adding random noise on actions cannot capture task-dependent or time-correlated exploration strategies. MAESN changes the policy to condition on a per-task random variable $z_i \\sim \\mathcal{N}(\\mu_i, \\sigma_i)$, for $i$-th task $M_i$, so we would have a policy $a \\sim \\pi_\\theta(a\\mid s, z_i)$. The latent variable $z_i$ is sampled once and fixed during one episode. Intuitively, the latent variable determines one type of behavior (or skills) that should be explored more at the beginning of a rollout and the agent would adjust its actions accordingly. Both the policy parameters and latent space are optimized to maximize the total task rewards. In the meantime, the policy learns to make use of the latent variables for exploration.\nIn addition, the loss function includes a KL divergence between the learned latent variable and a unit Gaussian prior, $D_\\text{KL}(\\mathcal{N}(\\mu_i, \\sigma_i)|\\mathcal{N}(0, \\mathbf{I}))$. On one hand, it restricts the learned latent space not too far from a common prior. On the other hand, it creates the variational evidence lower bound (ELBO) for the reward function. Interestingly the paper found that $(\\mu_i, \\sigma_i)$ for each task are usually close to the prior at convergence.\nFig. 5. The policy is conditioned on a latent variable variable $z\\_i \\sim \\mathcal{N}(\\mu, \\sigma)$ that is sampled once every episode. Each task has different hyperparameters for the latent variable distribution, $(\\mu\\_i, \\sigma\\_i)$ and they are optimized in the outer loop. (Image source: Gupta et al, 2018) Episodic Control A major criticism of RL is on its sample inefficiency. A large number of samples and small learning steps are required for incremental parameter adjustment in RL in order to maximize generalization and avoid catastrophic forgetting of earlier learning (Botvinick et al., 2019).\nEpisodic control (Lengyel \u0026amp; Dayan, 2008) is proposed as a solution to avoid forgetting and improve generalization while training at a faster speed. It is partially inspired by hypotheses on instance-based hippocampal learning.\nAn episodic memory keeps explicit records of past events and uses these records directly as point of reference for making new decisions (i.e. just like metric-based meta-learning). In MFEC (Model-Free Episodic Control; Blundell et al., 2016), the memory is modeled as a big table, storing the state-action pair $(s, a)$ as key and the corresponding Q-value $Q_\\text{EC}(s, a)$ as value. When receiving a new observation $s$, the Q value is estimated in an non-parametric way as the average Q-value of top $k$ most similar samples:\n $$ \\hat{Q}_\\text{EC}(s, a) = \\begin{cases} Q_\\text{EC}(s, a) \u0026 \\text{if } (s,a) \\in Q_\\text{EC}, \\\\ \\frac{1}{k} \\sum_{i=1}^k Q(s^{(i)}, a) \u0026 \\text{otherwise} \\end{cases} $$ where $s^{(i)}, i=1, \\dots, k$ are top $k$ states with smallest distances to the state $s$. Then the action that yields the highest estimated Q value is selected. Then the memory table is updated according to the return received at $s_t$:\n $$ Q_\\text{EC}(s, a) \\leftarrow \\begin{cases} \\max\\{Q_\\text{EC}(s_t, a_t), G_t\\} \u0026 \\text{if } (s,a) \\in Q_\\text{EC}, \\\\ G_t \u0026 \\text{otherwise} \\end{cases} $$ As a tabular RL method, MFEC suffers from large memory consumption and a lack of ways to generalize among similar states. The first one can be fixed with an LRU cache. Inspired by metric-based meta-learning, especially Matching Networks (Vinyals et al., 2016), the generalization problem is improved in a follow-up algorithm, NEC (Neural Episodic Control; Pritzel et al., 2016).\nThe episodic memory in NEC is a Differentiable Neural Dictionary (DND), where the key is a convolutional embedding vector of input image pixels and the value stores estimated Q value. Given an inquiry key, the output is a weighted sum of values of top similar keys, where the weight is a normalized kernel measure between the query key and the selected key in the dictionary. This sounds like a hard attention machanism.\nFig. 6 Illustrations of episodic memory module in NEC and two operations on a differentiable neural dictionary. (Image source: Pritzel et al., 2016) Further, Episodic LSTM (Ritter et al., 2018) enhances the basic LSTM architecture with a DND episodic memory, which stores task context embeddings as keys and the LSTM cell states as values. The stored hidden states are retrieved and added directly to the current cell state through the same gating mechanism within LSTM:\nFig. 7. Illustration of the episodic LSTM architecture. The additional structure of episodic memory is in bold. (Image source: Ritter et al., 2018) $$ \\begin{aligned} \\mathbf{c}_t \u0026= \\mathbf{i}_t \\circ \\mathbf{c}_\\text{in} + \\mathbf{f}_t \\circ \\mathbf{c}_{t-1} + \\color{green}{\\mathbf{r}_t \\circ \\mathbf{c}_\\text{ep}} \u0026\\\\ \\mathbf{i}_t \u0026= \\sigma(\\mathbf{W}_{i} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_i) \u0026 \\scriptstyle{\\text{; input gate}} \\\\ \\mathbf{f}_t \u0026= \\sigma(\\mathbf{W}_{f} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_f) \u0026 \\scriptstyle{\\text{; forget gate}} \\\\ \\color{green}{\\mathbf{r}_t} \u0026 \\color{green}{=} \\color{green}{\\sigma(\\mathbf{W}_{r} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_r)} \u0026 \\scriptstyle{\\text{; reinstatement gate}} \\end{aligned} $$ where $\\mathbf{c}_t$ and $\\mathbf{h}_t$ are hidden and cell state at time $t$; $\\mathbf{i}_t$, $\\mathbf{f}_t$ and $\\mathbf{r}_t$ are input, forget and reinstatement gates, respectively; $\\mathbf{c}_\\text{ep}$ is the retrieved cell state from episodic memory. The newly added episodic memory components are marked in green.\nThis architecture provides a shortcut to the prior experience through context-based retrieval. Meanwhile, explicitly saving the task-dependent experience in an external memory avoids forgetting. In the paper, all the experiments have manually designed context vectors. How to construct an effective and efficient format of task context embeddings for more free-formed tasks would be an interesting topic.\nOverall the capacity of episodic control is limited by the complexity of the environment. It is very rare for an agent to repeatedly visit exactly the same states in a real-world task, so properly encoding the states is critical. The learned embedding space compresses the observation data into a lower dimension space and, in the meantime, two states being close in this space are expected to demand similar strategies.\nTraining Task Acquisition Among three key components, how to design a proper distribution of tasks is the less studied and probably the most specific one to meta-RL itself. As described above, each task is a MDP: $M_i = \\langle \\mathcal{S}, \\mathcal{A}, P_i, R_i \\rangle \\in \\mathcal{M}$. We can build a distribution of MDPs by modifying:\n The reward configuration: Among different tasks, same behavior might get rewarded differently according to $R_i$. Or, the environment: The transition function $P_i$ can be reshaped by initializing the environment with varying shifts between states. Task Generation by Domain Randomization Randomizing parameters in a simulator is an easy way to obtain tasks with modified transition functions. If interested in learning further, check my last post on domain randomization.\nEvolutionary Algorithm on Environment Generation Evolutionary algorithm is a gradient-free heuristic-based optimization method, inspired by natural selection. A population of solutions follows a loop of evaluation, selection, reproduction, and mutation. Eventually, good solutions survive and thus get selected.\nPOET (Wang et al, 2019), a framework based on the evolutionary algorithm, attempts to generate tasks while the problems themselves are being solved. The implementation of POET is only specifically designed for a simple 2D bipedal walker environment but points out an interesting direction. It is noteworthy that the evolutionary algorithm has had some compelling applications in Deep Learning like EPG and PBT (Population-Based Training; Jaderberg et al, 2017).\nFig. 8. An example bipedal walking environment (top) and an overview of POET (bottom). (Image source: POET blog post) The 2D bipedal walking environment is evolving: from a simple flat surface to a much more difficult trail with potential gaps, stumps, and rough terrains. POET pairs the generation of environmental challenges and the optimization of agents together so as to (a) select agents that can resolve current challenges and (b) evolve environments to be solvable. The algorithm maintains a list of environment-agent pairs and repeats the following:\n Mutation: Generate new environments from currently active environments. Note that here types of mutation operations are created just for bipedal walker and a new environment would demand a new set of configurations. Optimization: Train paired agents within their respective environments. Selection: Periodically attempt to transfer current agents from one environment to another. Copy and update the best performing agent for every environment. The intuition is that skills learned in one environment might be helpful for a different environment. The procedure above is quite similar to PBT, but PBT mutates and evolves hyperparameters instead. To some extent, POET is doing domain randomization, as all the gaps, stumps and terrain roughness are controlled by some randomization probability parameters. Different from DR, the agents are not exposed to a fully randomized difficult environment all at once, but instead they are learning gradually with a curriculum configured by the evolutionary algorithm.\nLearning with Random Rewards An MDP without a reward function $R$ is known as a Controlled Markov process (CMP). Given a predefined CMP, $\\langle \\mathcal{S}, \\mathcal{A}, P\\rangle$, we can acquire a variety of tasks by generating a collection of reward functions $\\mathcal{R}$ that encourage the training of an effective meta-learning policy.\nGupta et al. (2018) proposed two unsupervised approaches for growing the task distribution in the context of CMP. Assuming there is an underlying latent variable $z \\sim p(z)$ associated with every task, it parameterizes/determines a reward function: $r_z(s) = \\log D(z|s)$, where a \u0026ldquo;discriminator\u0026rdquo; function $D(.)$ is used to extract the latent variable from the state. The paper described two ways to construct a discriminator function:\n Sample random weights $\\phi_\\text{rand}$ of the discriminator, $D_{\\phi_\\text{rand}}(z \\mid s)$. Learn a discriminator function to encourage diversity-driven exploration. This method is introduced in more details in another sister paper \u0026ldquo;DIAYN\u0026rdquo; (Eysenbach et al., 2018). DIAYN, short for \u0026ldquo;Diversity is all you need\u0026rdquo;, is a framework to encourage a policy to learn useful skills without a reward function. It explicitly models the latent variable $z$ as a skill embedding and makes the policy conditioned on $z$ in addition to state $s$, $\\pi_\\theta(a \\mid s, z)$. (Ok, this part is same as MAESN unsurprisingly, as the papers are from the same group.) The design of DIAYN is motivated by a few hypotheses:\n Skills should be diverse and lead to visitations of different states. → maximize the mutual information between states and skills, $I(S; Z)$ Skills should be distinguishable by states, not actions. → minimize the mutual information between actions and skills, conditioned on states $I(A; Z \\mid S)$ The objective function to maximize is as follows, where the policy entropy is also added to encourage diversity:\n $$ \\begin{aligned} \\mathcal{F}(\\theta) \u0026= I(S; Z) + H[A \\mid S] - I(A; Z \\mid S) \u0026 \\\\ \u0026= (H(Z) - H(Z \\mid S)) + H[A \\mid S] - (H[A\\mid S] - H[A\\mid S, Z]) \u0026 \\\\ \u0026= H[A\\mid S, Z] \\color{green}{- H(Z \\mid S) + H(Z)} \u0026 \\\\ \u0026= H[A\\mid S, Z] + \\mathbb{E}_{z\\sim p(z), s\\sim\\rho(s)}[\\log p(z \\mid s)] - \\mathbb{E}_{z\\sim p(z)}[\\log p(z)] \u0026 \\scriptstyle{\\text{; can infer skills from states \u0026 p(z) is diverse.}} \\\\ \u0026\\ge H[A\\mid S, Z] + \\mathbb{E}_{z\\sim p(z), s\\sim\\rho(s)}[\\color{red}{\\log D_\\phi(z \\mid s) - \\log p(z)}] \u0026 \\scriptstyle{\\text{; according to Jensen's inequality; \"pseudo-reward\" in red.}} \\end{aligned} $$ where $I(.)$ is mutual information and $H[.]$ is entropy measure. We cannot integrate all states to compute $p(z \\mid s)$, so approximate it with $D_\\phi(z \\mid s)$ \u0026mdash; that is the diversity-driven discriminator function.\nFig. 9. DIAYN Algorithm. (Image source: Eysenbach et al., 2019) Once the discriminator function is learned, sampling a new MDP for training is strainght-forward: First, sample a latent variable, $z \\sim p(z)$ and construct a reward function $r_z(s) = \\log(D(z \\vert s))$. Pairing the reward function with a predefined CMP creates a new MDP.\n Cited as:\n@article{weng2019metaRL, title = \u0026quot;Meta Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-06-23-meta-rl/\u0026quot; } References [1] Richard S. Sutton. \u0026ldquo;The Bitter Lesson.\u0026quot; March 13, 2019.\n[2] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. \u0026ldquo;Learning to learn using gradient descent.\u0026quot; Intl. Conf. on Artificial Neural Networks. 2001.\n[3] Jane X Wang, et al. \u0026ldquo;Learning to reinforcement learn.\u0026quot; arXiv preprint arXiv:1611.05763 (2016).\n[4] Yan Duan, et al. \u0026ldquo;RL $^ 2$: Fast Reinforcement Learning via Slow Reinforcement Learning.\u0026quot; ICLR 2017.\n[5] Matthew Botvinick, et al. \u0026ldquo;Reinforcement Learning, Fast and Slow\u0026rdquo; Cell Review, Volume 23, Issue 5, P408-422, May 01, 2019.\n[6] Jeff Clune. \u0026ldquo;AI-GAs: AI-generating algorithms, an alternate paradigm for producing general artificial intelligence\u0026rdquo; arXiv preprint arXiv:1905.10985 (2019).\n[7] Zhongwen Xu, et al. \u0026ldquo;Meta-Gradient Reinforcement Learning\u0026rdquo; NIPS 2018.\n[8] Rein Houthooft, et al. \u0026ldquo;Evolved Policy Gradients.\u0026quot; NIPS 2018.\n[9] Tim Salimans, et al. \u0026ldquo;Evolution strategies as a scalable alternative to reinforcement learning.\u0026quot; arXiv preprint arXiv:1703.03864 (2017).\n[10] Abhishek Gupta, et al. \u0026ldquo;Meta-Reinforcement Learning of Structured Exploration Strategies.\u0026quot; NIPS 2018.\n[11] Alexander Pritzel, et al. \u0026ldquo;Neural episodic control.\u0026quot; Proc. Intl. Conf. on Machine Learning, Volume 70, 2017.\n[12] Charles Blundell, et al. \u0026ldquo;Model-free episodic control.\u0026quot; arXiv preprint arXiv:1606.04460 (2016).\n[13] Samuel Ritter, et al. \u0026ldquo;Been there, done that: Meta-learning with episodic recall.\u0026quot; ICML, 2018.\n[14] Rui Wang et al. \u0026ldquo;Paired Open-Ended Trailblazer (POET): Endlessly Generating Increasingly Complex and Diverse Learning Environments and Their Solutions\u0026rdquo; arXiv preprint arXiv:1901.01753 (2019).\n[15] Uber Engineering Blog: \u0026ldquo;POET: Endlessly Generating Increasingly Complex and Diverse Learning Environments and their Solutions through the Paired Open-Ended Trailblazer.\u0026quot; Jan 8, 2019.\n[16] Abhishek Gupta, et al.\u0026ldquo;Unsupervised meta-learning for Reinforcement Learning\u0026rdquo; arXiv preprint arXiv:1806.04640 (2018).\n[17] Eysenbach, Benjamin, et al. \u0026ldquo;Diversity is all you need: Learning skills without a reward function.\u0026quot; ICLR 2019.\n[18] Max Jaderberg, et al. \u0026ldquo;Population Based Training of Neural Networks.\u0026quot; arXiv preprint arXiv:1711.09846 (2017).\n","permalink":"https://lilianweng.github.io/posts/2019-06-23-meta-rl/","summary":"In my earlier post on meta-learning, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to \u0026ldquo;meta-learn\u0026rdquo; Reinforcement Learning (RL) tasks by developing an agent that can solve unseen tasks fast and efficiently.\nTo recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a mini learning session, happens at test with limited exposure to the new configurations.","title":"Meta Reinforcement Learning"},{"content":"In Robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots. The gap is triggered by an inconsistency between physical parameters (i.e. friction, kp, damping, mass, density) and, more fatally, the incorrect physical modeling (i.e. collision between soft surfaces).\nTo close the sim2real gap, we need to improve the simulator and make it closer to reality. A couple of approaches:\n System identification System identification is to build a mathematical model for a physical system; in the context of RL, the mathematical model is the simulator. To make the simulator more realistic, careful calibration is necessary. Unfortunately, calibration is expensive. Furthermore, many physical parameters of the same machine might vary significantly due to temperature, humidity, positioning or its wear-and-tear in time. Domain adaptation Domain adaptation (DA) refers to a set of transfer learning techniques developed to update the data distribution in sim to match the real one through a mapping or regularization enforced by the task model. Many DA models, especially for image classification or end-to-end image-based RL task, are built on adversarial loss or GAN. Domain randomization With domain randomization (DR), we are able to create a variety of simulated environments with randomized properties and train a model that works across all of them. Likely this model can adapt to the real-world environment, as the real system is expected to be one sample in that rich distribution of training variations. Both DA and DR are unsupervised. Compared to DA which requires a decent amount of real data samples to capture the distribution, DR may need only a little or no real data. DR is the focus of this post.\nFig. 1. Conceptual illustrations of three approaches for sim2real transfer. What is Domain Randomization? To make the definition more general, let us call the environment that we have full access to (i.e. simulator) source domain and the environment that we would like to transfer the model to target domain (i.e. physical world). Training happens in the source domain. We can control a set of $N$ randomization parameters in the source domain $e_\\xi$ with a configuration $\\xi$, sampled from a randomization space, $\\xi \\in \\Xi \\subset \\mathbb{R}^N$.\nDuring policy training, episodes are collected from source domain with randomization applied. Thus the policy is exposed to a variety of environments and learns to generalize. The policy parameter $\\theta$ is trained to maximize the expected reward $R(.)$ average across a distribution of configurations:\n $$ \\theta^* = \\arg\\max_\\theta \\mathbb{E}_{\\xi \\sim \\Xi} [\\mathbb{E}_{\\pi_\\theta, \\tau \\sim e_\\xi} [R(\\tau)]] $$ where $\\tau_\\xi$ is a trajectory collected in source domain randomized with $\\xi$. In a way, \u0026ldquo;discrepancies between the source and target domains are modeled as variability in the source domain.\u0026quot; (quote from Peng et al. 2018).\nUniform Domain Randomization In the original form of DR (Tobin et al, 2017; Sadeghi et al. 2016), each randomization parameter $\\xi_i$ is bounded by an interval, $\\xi_i \\in [\\xi_i^\\text{low}, \\xi_i^\\text{high}], i=1,\\dots,N$ and each parameter is uniformly sampled within the range.\nThe randomization parameters can control appearances of the scene, including but not limited to the followings (see Fig. 2). A model trained on simulated and randomized images is able to transfer to real non-randomized images.\n Position, shape, and color of objects, Material texture, Lighting condition, Random noise added to images, Position, orientation, and field of view of the camera in the simulator. Fig. 2. Images captured in the training environment are randomized. (Image source: Tobin et al, 2017) Physical dynamics in the simulator can also be randomized (Peng et al. 2018). Studies have showed that a recurrent policy can adapt to different physical dynamics including the partially observable reality. A set of physical dynamics features include but are not limited to:\n Mass and dimensions of objects, Mass and dimensions of robot bodies, Damping, kp, friction of the joints, Gains for the PID controller (P term), Joint limit, Action delay, Observation noise. With visual and dynamics DR, at OpenAI Robotics, we were able to learn a policy that works on real dexterous robot hand (OpenAI, 2018). Our manipulation task is to teach the robot hand to rotate an object continously to achieve 50 successive random target orientations. The sim2real gap in this task is very large, due to (a) a high number of simultaneous contacts between the robot and the object and (b) imperfect simulation of object collision and other motions. At first, the policy could barely survive for more than 5 seconds without dropping the object. But with the help of DR, the policy evolved to work surprisingly well in reality eventually.\n Why does Domain Randomization Work? Now you may ask, why does domain randomization work so well? The idea sounds really simple. Here are two non-exclusive explanations I found most convincing.\nDR as Optimization One idea (Vuong, et al, 2019) is to view learning randomization parameters in DR as a bilevel optimization. Assuming we have access to the real environment $e_\\text{real}$ and the randomization config is sampled from a distribution parameterized by $\\phi$, $\\xi \\sim P_\\phi(\\xi)$, we would like to learn a distribution on which a policy $\\pi_\\theta$ is trained on can achieve maximal performance in $e_\\text{real}$:\n $$ \\begin{aligned} \u0026\\phi^* = \\arg\\min_{\\phi} \\mathcal{L}(\\pi_{\\theta^*(\\phi)}; e_\\text{real}) \\\\ \\text{where } \u0026\\theta^*(\\phi) = \\arg\\min_\\theta \\mathbb{E}_{\\xi \\sim P_\\phi(\\xi)}[\\mathcal{L}(\\pi_\\theta; e_\\xi)] \\end{aligned} $$ where $\\mathcal{L}(\\pi; e)$ is the loss function of policy $\\pi$ evaluated in the environment $e$.\nAlthough randomization ranges are hand-picked in uniform DR, it often involves domain knowledge and a couple rounds of trial-and-error adjustment based on the transfer performance. Essentially this is a manual optimization process on tuning $\\phi$ for the optimal $\\mathcal{L}(\\pi_{\\theta^*(\\phi)}; e_\\text{real})$.\nGuided domain randomization in the next section is largely inspired by this view, aiming to do bilevel optimization and learn the best parameter distribution automatically.\nDR as Meta-Learning In our learning dexterity project (OpenAI, 2018), we trained an LSTM policy to generalize across different environmental dynamics. We observed that once a robot achieved the first rotation, the time it needed for the following successes was much shorter. Also, a FF policy without memory was found not able to transfer to a physical robot. Both are evidence of the policy dynamically learning and adapting to a new environment.\nIn some ways, domain randomization composes a collection of different tasks. Memory in the recurrent network empowers the policy to achieve meta-learning across tasks and further work on a real-world setting.\nGuided Domain Randomization The vanilla DR assumes no access to the real data, and thus the randomization config is sampled as broadly and uniformly as possible in sim, hoping that the real environment could be covered under this broad distribution. It is reasonable to think of a more sophisticated strategy \u0026mdash; replacing uniform sampling with guidance from task performance, real data, or simulator.\nOne motivation for guided DR is to save computation resources by avoiding training models in unrealistic environments. Another is to avoid infeasible solutions that might arise from overly wide randomization distributions and thus might hinder successful policy learning.\nOptimization for Task Performance Say we train a family of policies with different randomization parameters $\\xi \\sim P_\\phi(\\xi)$, where $P_\\xi$ is the distribution for $\\xi$ parameterized by $\\phi$. Later we decide to try every one of them on the downstream task in the target domain (i.e. control a robot in reality or evaluate on a validation set) to collect feedback. This feedback tells us how good a configuration $\\xi$ is and provides signals for optimizing $\\phi$.\nInspired by NAS, AutoAugment (Cubuk, et al. 2018) frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem. Note that AutoAugment is not proposed for sim2real transfer, but falls in the bucket of DR guided by task performance. Individual augmentation configuration is tested on the evaluation set and the performance improvement is used as a reward to train a PPO policy. This policy outputs different augmentation strategies for different datasets; for example, for CIFAR-10 AutoAugment mostly picks color-based transformations, while ImageNet prefers geometric based.\nRuiz (2019) considered the task feedback as reward in RL problem and proposed a RL-based method, named \u0026ldquo;learning to simulate\u0026rdquo;, for adjusting $\\xi$. A policy is trained to predict $\\xi$ using performance metrics on the validation data of the main task as rewards, which is modeled as a multivariate Gaussian. Overall the idea is similar to AutoAugment, applying NAS on data generation. According to their experiments, even if the main task model is not converged, it still can provide a reasonable signal to the data generation policy.\nFig. 3. An overview of the \"learning to simulate\" approach. (Image source: Ruiz (2019)) Evolutionary algorithm is another way to go, where the feedback is treated as fitness for guiding evolution (Yu et al, 2019). In this study, they used CMA-ES (covariance matrix adaptation evolution strategy) while fitness is the performance of a $\\xi$-conditional policy in target environment. In the appendix, they compared CMA-ES with other ways of modeling the dynamics of $\\xi$, including Bayesian optimization or a neural network. The main claim was those methods are not as stable or sample efficient as CMA-ES. Interestly, when modeling $P(\\xi)$ as a neural network, LSTM is found to notably outperform FF.\nSome believe that sim2real gap is a combination of appearance gap and content gap; i.e. most GAN-inspired DA models focus on appearance gap. Meta-Sim (Kar, et al. 2019) aims to close the content gap by generating task-specific synthetic datasets. Meta-Sim uses self-driving car training as an example and thus the scene could be very complicated. In this case, the synthetic scenes are parameterized by a hierarchy of objects with properties (i.e., location, color) as well as relationships between objects. The hierarchy is specified by a probabilistic scene grammar akin to structure domain randomization (SDR; Prakash et al., 2018) and it is assumed to be known beforehand. A model $G$ is trained to augment the distribution of scene properties $s$ by following:\n Learn the prior first: pre-train $G$ to learn the identity function $G(s) = s$. Minimize MMD loss between the real and sim data distributions. This involves backpropagation through non-differentiable renderer. The paper computes it numerically by perturbing the attributes of $G(s)$. Minimize REINFORCE task loss when trained on synthetic data but evaluated on real data. Again, very similar to AutoAugment. Unfortunately, this family of methods are not suitable for sim2real case. Either an RL policy or an EA model requires a large number of real samples. And it is really expensive to include real-time feedback collection on a physical robot into the training loop. Whether you want to trade less computation resource for real data collection would depend on your task.\nMatch Real Data Distribution Using real data to guide domain randomization feels a lot like doing system identification or DA. The core idea behind DA is to improve the synthetic data to match the real data distribution. In the case of real-data-guided DR, we would like to learn the randomization parameters $\\xi$ that bring the state distribution in simulator close to the state distribution in the real world.\nThe SimOpt model (Chebotar et al, 2019) is trained under an initial randomization distribution $P_\\phi(\\xi)$ first, getting a policy $\\pi_{\\theta, P_\\phi}$. Then this policy is deployed on both simulator and physical robot to collect trajectories $\\tau_\\xi$ and $\\tau_\\text{real}$ respectively. The optimization objective is to minimize the discrepancy between sim and real trajectories:\n $$ \\phi^* = \\arg\\min_{\\phi}\\mathbb{E}_{\\xi \\sim P_\\phi(\\xi)} [\\mathbb{E}_{\\pi_{\\theta, P_\\phi}} [D(\\tau_\\text{sim}, \\tau_\\text{real})]] $$ where $D(.)$ is a trajectory-based discrepancy measure. Like the \u0026ldquo;Learning to simulate\u0026rdquo; paper, SimOpt also has to solve the tricky problem of how to propagate gradient through non-differentiable simulator. It used a method called relative entropy policy search, see paper for more details.\nFig. 4. An overview of the SimOpt framework. (Image source: Chebotar et al, 2019) RCAN (James et al., 2019), short for \u0026ldquo;Randomized-to-Canonical Adaptation Networks\u0026rdquo;, is a nice combination of DA and DR for end-to-end RL tasks. An image-conditional GAN (cGAN) is trained in sim to translate a domain-randomized image into a non-randomized version (aka \u0026ldquo;canonical version\u0026rdquo;). Later the same model is used to translate real images into corresponding simulated version so that the agent would consume consistent observation as what it has encountered in training. Still, the underlying assumption is that the distribution of domain-randomized sim images is broad enough to cover real-world samples.\nFig. 5. RCAN is an image-conditional generator that can convert a domain-randomized or real image into its corresponding non-randomized simulator version. (Image source: James et al., 2019) The RL model is trained end-to-end in a simulator to do vision-based robot arm grasping. Randomization is applied at each timestep, including the position of tray divider, objects to grasp, random textures, as well as the position, direction, and color of the lighting. The canonical version is the default simulator look. RCAN is trying to learn a generator\n$G$: randomized image $\\to$ {canonical image, segmentation, depth}\nwhere segmentation masks and depth images are used as auxiliary tasks. RCAN had a better zero-shot transfer compared to uniform DR, although both were shown to be worse than the model trained on only real images. Conceptually, RCAN operates in a reverse direction of GraspGAN which translates synthetic images into real ones by domain adaptation.\nGuided by Data in Simulator Network-driven domain randomization (Zakharov et al., 2019), also known as DeceptionNet, is motivated by learning which randomizations are actually useful to bridge the domain gap for image classification tasks.\nRandomization is applied through a set of deception modules with encoder-decoder architecture. The deception modules are specifically designed to transform images; such as change backgrounds, add distortion, change lightings, etc. The other recognition network handles the main task by running classification on transformed images.\nThe training involves two steps:\n With the recognition network fixed, maximize the difference between the prediction and the labels by applying reversed gradients during backpropagation. So that the deception module can learn the most confusing tricks. With the deception modules fixed, train the recognition network with input images altered. Fig. 6. How DeceptionNet works. (Image source: Zakharov et al., 2019) The feedback for training deception modules is provided by the downstream classifier. But rather than trying to maximize the task performance like the section above, the randomization modules aim to create harder cases. One big disadvantage is you need to manually design different deception modules for different datasets or tasks, making it not easily scalable. Given the fact that it is zero-shot, the results are still worse than SOTA DA methods on MNIST and LineMOD.\nSimilarly, Active domain randomization (ADR; Mehta et al., 2019) also relies on sim data to create harder training samples. ADR searches for the most informative environment variations within the given randomization ranges, where the informativeness is measured as the discrepancies of policy rollouts in randomized and reference (original, non-randomized) environment instances. Sounds a bit like SimOpt? Well, noted that SimOpt measures the discrepancy between sim and real rollouts, while ADR measures between randomized and non-randomized sim, avoiding the expensive real data collection part.\nFig. 7. How active domain randomization (ADR) works. (Image source: Mehta et al., 2019) Precisely the training happens as follows:\n Given a policy, run it on both reference and randomized envs and collect two sets of trajectories respectively. Train a discriminator model to tell whether a rollout trajectory is randomized apart from reference run. The predicted $\\log p$ (probability of being randomized) is used as reward. The more different randomized and reference rollouts, the easier the prediction, the higher the reward. The intuition is that if an environment is easy, the same policy agent can produce similar trajectories as in the reference one. Then the model should reward and explore hard environments by encouraging different behaviors. The reward by discriminator is fed into Stein Variational Policy Gradient (SVPG) particles, outputting a diverse set of randomization configurations. The idea of ADR is very appealing with two small concerns. The similarity between trajectories might not be a good way to measure the env difficulty when running a stochastic policy. The sim2real results look unfortunately not as exciting, but the paper pointed out the win being ADR explores a smaller range of randomization parameters.\n Cited as:\n@article{weng2019DR, title = \u0026quot;Domain Randomization for Sim2Real Transfer\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-05-05-domain-randomization/\u0026quot; } Overall, after reading this post, I hope you like domain randomization as much as I do :).\nReferences [1] Josh Tobin, et al. \u0026ldquo;Domain randomization for transferring deep neural networks from simulation to the real world.\u0026quot; IROS, 2017.\n[2] Fereshteh Sadeghi and Sergey Levine. \u0026ldquo;CAD2RL: Real single-image flight without a single real image.\u0026quot; arXiv:1611.04201 (2016).\n[3] Xue Bin Peng, et al. \u0026ldquo;Sim-to-real transfer of robotic control with dynamics randomization.\u0026quot; ICRA, 2018.\n[4] Nataniel Ruiz, et al. \u0026ldquo;Learning to Simulate.\u0026quot; ICLR 2019\n[5] OpenAI. \u0026ldquo;Learning Dexterous In-Hand Manipulation.\u0026quot; arXiv:1808.00177 (2018).\n[6] OpenAI Blog. \u0026ldquo;Learning dexterity\u0026rdquo; July 30, 2018.\n[7] Quan Vuong, et al. \u0026ldquo;How to pick the domain randomization parameters for sim-to-real transfer of reinforcement learning policies?.\u0026quot; arXiv:1903.11774 (2019).\n[8] Ekin D. Cubuk, et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv:1805.09501 (2018).\n[9] Wenhao Yu et al. \u0026ldquo;Policy Transfer with Strategy Optimization.\u0026quot; ICLR 2019\n[10] Yevgen Chebotar et al. \u0026ldquo;Closing the Sim-to-Real Loop: Adapting Simulation Randomization with Real World Experience.\u0026quot; Arxiv: 1810.05687 (2019).\n[11] Stephen James et al. \u0026ldquo;Sim-to-real via sim-to-sim: Data-efficient robotic grasping via randomized-to-canonical adaptation networks\u0026rdquo; CVPR 2019.\n[12] Bhairav Mehta et al. \u0026ldquo;Active Domain Randomization\u0026rdquo; arXiv:1904.04762\n[13] Sergey Zakharov,et al. \u0026ldquo;DeceptionNet: Network-Driven Domain Randomization.\u0026quot; arXiv:1904.02750 (2019).\n[14] Amlan Kar, et al. \u0026ldquo;Meta-Sim: Learning to Generate Synthetic Datasets.\u0026quot; arXiv:1904.11621 (2019).\n[15] Aayush Prakash, et al. \u0026ldquo;Structured Domain Randomization: Bridging the Reality Gap by Context-Aware Synthetic Data.\u0026quot; arXiv:1810.10093 (2018).\n","permalink":"https://lilianweng.github.io/posts/2019-05-05-domain-randomization/","summary":"In Robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots.","title":"Domain Randomization for Sim2Real Transfer"},{"content":"[Updated on 2019-05-27: add the section on Lottery Ticket Hypothesis.]\nIf you are like me, entering into the field of deep learning with experience in traditional machine learning, you may often ponder over this question: Since a typical deep neural network has so many parameters and training error can easily be perfect, it should surely suffer from substantial overfitting. How could it be ever generalized to out-of-sample data points?\nThe effort in understanding why deep neural networks can generalize somehow reminds me of this interesting paper on System Biology \u0026mdash; \u0026ldquo;Can a biologist fix a radio?\u0026quot; (Lazebnik, 2002). If a biologist intends to fix a radio machine like how she works on a biological system, life could be hard. Because the full mechanism of the radio system is not revealed, poking small local functionalities might give some hints but it can hardly present all the interactions within the system, let alone the entire working flow. No matter whether you think it is relevant to DL, it is a very fun read.\nI would like to discuss a couple of papers on generalizability and complexity measurement of deep learning models in the post. Hopefully, it could shed light on your thinking path towards the understanding of why DNN can generalize.\nClassic Theorems on Compression and Model Selection Let\u0026rsquo;s say we have a classification problem and a dataset, we can develop many models to solve it, from fitting a simple linear regression to memorizing the full dataset in disk space. Which one is better? If we only care about the accuracy over training data (especially given that testing data is likely unknown), the memorization approach seems to be the best \u0026mdash; well, it doesn\u0026rsquo;t sound right.\nThere are many classic theorems to guide us when deciding what types of properties a good model should possess in such scenarios.\nOccam\u0026rsquo;s Razor Occam\u0026rsquo;s Razor is an informal principle for problem-solving, proposed by William of Ockham in the 14th century:\n \u0026ldquo;Simpler solutions are more likely to be correct than complex ones.\u0026rdquo;\n The statement is extremely powerful when we are facing multiple candidates of underlying theories to explain the world and have to pick one. Too many unnecessary assumptions might seem to be plausible for one problem, but harder to be generalized to other complications or to eventually lead to the basic principles of the universe.\nThink of this, it took people hundreds of years to figure out that the sky is blue in the daytime but reddish at sunset are because of the same reason (Rayleigh scattering), although two phenomena look very different. People must have proposed many other explanations for them separately but the unified and simple version won eventually.\nMinimum Description Length principle The principle of Occam\u0026rsquo;s Razor can be similarly applied to machine learning models. A formalized version of such concept is called the Minimum Description Length (MDL) principle, used for comparing competing models / explanations given data observed.\n \u0026ldquo;Comprehension is compression.\u0026rdquo;\n The fundamental idea in MDL is to view learning as data compression. By compressing the data, we need to discover regularity or patterns in the data with the high potentiality to generalize to unseen samples. Information bottleneck theory believes that a deep neural network is trained first to represent the data by minimizing the generalization error and then learn to compress this representation by trimming noise.\nMeanwhile, MDL considers the model description as part of the compression delivery, so the model cannot be arbitrarily large.\nA two-part version of MDL principle states that: Let $\\mathcal{H}^{(1)}, \\mathcal{H}^{(2)}, \\dots$ be a list of models that can explain the dataset $\\mathcal{D}$. The best hypothesis among them should be the one that minimizes the sum:\n $$ \\mathcal{H}^\\text{best} = \\arg\\min_\\mathcal{H} [L(\\mathcal{H}) + L(\\mathcal{D}\\vert\\mathcal{H})] $$ $L(\\mathcal{H})$ is the length of the description of model $\\mathcal{H}$ in bits. $L(\\mathcal{D}\\vert\\mathcal{H})$ is the length of the description of the data $\\mathcal{D}$ in bits when encoded with $\\mathcal{H}$. In simple words, the best model is the smallest model containing the encoded data and the model itself. Following this criterion, the memorization approach I proposed at the beginning of the section sounds horrible no matter how good accuracy it can achieve on the training data.\nPeople might argue Occam\u0026rsquo;s Razor is wrong, as given the real world can be arbitrarily complicated, why do we have to find simple models? One interesting view by MDL is to consider models as \u0026ldquo;languages\u0026rdquo; instead of fundamental generative theorems. We would like to find good compression strategies to describe regularity in a small set of samples, and they do not have to be the \u0026ldquo;real\u0026rdquo; generative model for explaining the phenomenon. Models can be wrong but still useful (i.e., think of any Bayesian prior).\nKolmogorov Complexity Kolmogorov Complexity relies on the concept of modern computers to define the algorithmic (descriptive) complexity of an object: It is the length of the shortest binary computer program that describes the object. Following MDL, a computer is essentially the most general form of data decompressor.\nThe formal definition of Kolmogorov Complexity states that: Given a universal computer $\\mathcal{U}$ and a program $p$, let\u0026rsquo;s denote $\\mathcal{U}(p)$ as the output of the computer processing the program and $L(p)$ as the descriptive length of the program. Then Kolmogorov Complexity $K_\\mathcal{U}$ of a string $s$ with respect to a universal computer $\\mathcal{U}$ is:\n $$ K_\\mathcal{U}(s) = \\min_{p: \\mathcal{U}(p)=s} L(p) $$ Note that a universal computer is one that can mimic the actions of any other computers. All modern computers are universal as they can all be reduced to Turing machines. The definition is universal no matter which computers we are using, because another universal computer can always be programmed to clone the behavior of $\\mathcal{U}$, while encoding this clone program is just a constant.\nThere are a lot of connections between Kolmogorov Complexity and Shannon Information Theory, as both are tied to universal coding. It is an amazing fact that the expected Kolmogorov Complexity of a random variable is approximately equal to its Shannon entropy (see Sec 2.3 of the report). More on this topic is out of the scope here, but there are many interesting readings online. Help yourself :)\nSolomonoff\u0026rsquo;s Inference Theory Another mathematical formalization of Occam\u0026rsquo;s Razor is Solomonoff\u0026rsquo;s theory of universal inductive inference (Solomonoff, 1964). The principle is to favor models that correspond to the \u0026ldquo;shortest program\u0026rdquo; to produce the training data, based on its Kolmogorov complexity\nExpressive Power of DL Models Deep neural networks have an extremely large number of parameters compared to the traditional statistical models. If we use MDL to measure the complexity of a deep neural network and consider the number of parameters as the model description length, it would look awful. The model description $L(\\mathcal{H})$ can easily grow out of control.\nHowever, having numerous parameters is necessary for a neural network to obtain high expressivity power. Because of its great capability to capture any flexible data representation, deep neural networks have achieved great success in many applications.\nUniversal Approximation Theorem The Universal Approximation Theorem states that a feedforward network with: 1) a linear output layer, 2) at least one hidden layer containing a finite number of neurons and 3) some activation function can approximate any continuous functions on a compact subset of $\\mathbb{R}^n$ to arbitrary accuracy. The theorem was first proved for sigmoid activation function (Cybenko, 1989). Later it was shown that the universal approximation property is not specific to the choice of activation (Hornik, 1991) but the multilayer feedforward architecture.\nAlthough a feedforward network with a single layer is sufficient to represent any function, the width has to be exponentially large. The universal approximation theorem does not guarantee whether the model can be learned or generalized properly. Often, adding more layers helps to reduce the number of hidden neurons needed in a shallow network.\nTo take advantage of the universal approximation theorem, we can always find a neural network to represent the target function with error under any desired threshold, but we need to pay the price \u0026mdash; the network might grow super large.\nProof: Finite Sample Expressivity of Two-layer NN The Universal Approximation Theorem we have discussed so far does not consider a finite sample set. Zhang, et al. (2017) provided a neat proof on the finite-sample expressivity of two-layer neural networks.\nA neural network $C$ can represent any function given a sample size $n$ in $d$ dimensions if: For every finite sample set $S \\subseteq \\mathbb{R}^d$ with $\\vert S \\vert = n$ and every function defined on this sample set: $f: S \\mapsto \\mathbb{R}$, we can find a set of weight configuration for $C$ so that $C(\\boldsymbol{x}) = f(\\boldsymbol{x}), \\forall \\boldsymbol{x} \\in S$.\nThe paper proposed a theorem:\n There exists a two-layer neural network with ReLU activations and $2n + d$ weights that can represent any function on a sample of size $n$ in $d$ dimensions.\n Proof. First we would like to construct a two-layer neural network $C: \\mathbb{R}^d \\mapsto \\mathbb{R}$. The input is a $d$-dimensional vector, $\\boldsymbol{x} \\in \\mathbb{R}^d$. The hidden layer has $h$ hidden units, associated with a weight matrix $\\mathbf{W} \\in \\mathbb{R}^{d\\times h}$, a bias vector $-\\mathbf{b} \\in \\mathbb{R}^h$ and ReLU activation function. The second layer outputs a scalar value with weight vector $\\boldsymbol{v} \\in \\mathbb{R}^h$ and zero biases.\nThe output of network $C$ for a input vector $\\boldsymbol{x}$ can be represented as follows:\n $$ C(\\boldsymbol{x}) = \\boldsymbol{v} \\max\\{ \\boldsymbol{x}\\mathbf{W} - \\boldsymbol{b}, 0\\}^\\top = \\sum_{i=1}^h v_i \\max\\{\\boldsymbol{x}\\boldsymbol{W}_{(:,i)} - b_i, 0\\} $$ where $\\boldsymbol{W}_{(:,i)}$ is the $i$-th column in the $d \\times h$ matrix.\nGiven a sample set $S = \\{\\boldsymbol{x}_1, \\dots, \\boldsymbol{x}_n\\}$ and target values $\\boldsymbol{y} = \\{y_1, \\dots, y_n \\}$, we would like to find proper weights $\\mathbf{W} \\in \\mathbb{R}^{d\\times h}$, $\\boldsymbol{b}, \\boldsymbol{v} \\in \\mathbb{R}^h$ so that $C(\\boldsymbol{x}_i) = y_i, \\forall i=1,\\dots,n$.\nLet\u0026rsquo;s combine all sample points into one batch as one input matrix $\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$. If set $h=n$, $\\mathbf{X}\\mathbf{W} - \\boldsymbol{b}$ would be a square matrix of size $n \\times n$.\n $$ \\mathbf{M}_\\text{ReLU} = \\max\\{\\mathbf{X}\\mathbf{W} - \\boldsymbol{b}, 0 \\} = \\begin{bmatrix} \\boldsymbol{x}_1\\mathbf{W} - \\boldsymbol{b} \\\\ \\dots \\\\ \\boldsymbol{x}_n\\mathbf{W} - \\boldsymbol{b} \\\\ \\end{bmatrix} = [\\boldsymbol{x}_i\\boldsymbol{W}_{(:,j)} - b_j]_{i \\times j} $$ We can simplify $\\mathbf{W}$ to have the same column vectors across all the columns:\n $$ \\mathbf{W}_{(:,j)} = \\boldsymbol{w} \\in \\mathbb{R}^{d}, \\forall j = 1, \\dots, n $$ Let $a_i = \\boldsymbol{x}_i \\boldsymbol{w}$, we would like to find a suitable $\\boldsymbol{w}$ and $\\boldsymbol{b}$ such that $b_1 \u0026lt; a_1 \u0026lt; b_2 \u0026lt; a_2 \u0026lt; \\dots \u0026lt; b_n \u0026lt; a_n$. This is always achievable because we try to solve $n+d$ unknown variables with $n$ constraints and $\\boldsymbol{x}_i$ are independent (i.e. pick a random $\\boldsymbol{w}$, sort $\\boldsymbol{x}_i \\boldsymbol{w}$ and then set $b_j$\u0026rsquo;s as values in between). Then $\\mathbf{M}_\\text{ReLU}$ becomes a lower triangular matrix:\n $$ \\mathbf{M}_\\text{ReLU} = [a_i - b_j]_{i \\times j} = \\begin{bmatrix} a_1 - b_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \\\\ \\vdots \u0026 \\ddots \u0026 \u0026 \u0026 \\vdots \\\\ a_i - b_1 \u0026 \\dots \u0026 a_i - b_i \u0026 \\dots \u0026 0\\\\ \\vdots \u0026 \u0026 \u0026 \\ddots \u0026 \\vdots \\\\ a_n - b_1 \u0026 a_n - b_2 \u0026 \\dots \u0026 \\dots \u0026 a_n - b_n \\\\ \\end{bmatrix} $$ It is a nonsingular square matrix as $\\det(\\mathbf{M}_\\text{ReLU}) \\neq 0$, so we can always find suitable $\\boldsymbol{v}$ to solve $\\boldsymbol{v}\\mathbf{M}_\\text{ReLU}=\\boldsymbol{y}$ (In other words, the column space of $\\mathbf{M}_\\text{ReLU}$ is all of $\\mathbb{R}^n$ and we can find a linear combination of column vectors to obtain any $\\boldsymbol{y}$).\nDeep NN can Learn Random Noise As we know two-layer neural networks are universal approximators, it is less surprising to see that they are able to learn unstructured random noise perfectly, as shown in Zhang, et al. (2017). If labels of image classification dataset are randomly shuffled, the high expressivity power of deep neural networks can still empower them to achieve near-zero training loss. These results do not change with regularization terms added.\nFig. 1. Fit models on CIFAR10 with random labels or random pixels: (a) learning curves; (b-c) label corruption ratio is the percentage of randomly shuffled labels. (Image source: Zhang et al. 2017) Are Deep Learning Models Dramatically Overfitted? Deep learning models are heavily over-parameterized and can often get to perfect results on training data. In the traditional view, like bias-variance trade-offs, this could be a disaster that nothing may generalize to the unseen test data. However, as is often the case, such \u0026ldquo;overfitted\u0026rdquo; (training error = 0) deep learning models still present a decent performance on out-of-sample test data. Hmm … interesting and why?\nModern Risk Curve for Deep Learning The traditional machine learning uses the following U-shape risk curve to measure the bias-variance trade-offs and quantify how generalizable a model is. If I get asked how to tell whether a model is overfitted, this would be the first thing popping into my mind.\nAs the model turns larger (more parameters added), the training error decreases to close to zero, but the test error (generalization error) starts to increase once the model complexity grows to pass the threshold between \u0026ldquo;underfitting\u0026rdquo; and \u0026ldquo;overfitting\u0026rdquo;. In a way, this is well aligned with Occam\u0026rsquo;s Razor.\nFig. 2. U-shaped bias-variance risk curve. (Image source: (left) paper (right) fig. 6 of this post) Unfortunately this does not apply to deep learning models. Belkin et al. (2018) reconciled the traditional bias-variance trade-offs and proposed a new double-U-shaped risk curve for deep neural networks. Once the number of network parameters is high enough, the risk curve enters another regime.\nFig. 3. A new double-U-shaped bias-variance risk curve for deep neural networks. (Image source: original paper) The paper claimed that it is likely due to two reasons:\n The number of parameters is not a good measure of inductive bias, defined as the set of assumptions of a learning algorithm used to predict for unknown samples. See more discussion on DL model complexity in later sections. Equipped with a larger model, we might be able to discover larger function classes and further find interpolating functions that have smaller norm and are thus \u0026ldquo;simpler\u0026rdquo;. The double-U-shaped risk curve was observed empirically, as shown in the paper. However I was struggling quite a bit to reproduce the results. There are some signs of life, but in order to generate a pretty smooth curve similar to the theorem, many details in the experiment have to be taken care of.\nFig. 4. Training and evaluation errors of a one hidden layer fc network of different numbers of hidden units, trained on 4000 data points sampled from MNIST. (Image source: original paper) Regularization is not the Key to Generalization Regularization is a common way to control overfitting and improve model generalization performance. Interestingly some research (Zhang, et al. 2017) has shown that explicit regularization (i.e. data augmentation, weight decay and dropout) is neither necessary or sufficient for reducing generalization error.\nTaking the Inception model trained on CIFAR10 as an example (see Fig. 5), regularization techniques help with out-of-sample generalization but not much. No single regularization seems to be critical independent of other terms. Thus, it is unlikely that regularizers are the fundamental reason for generalization.\nFig. 5. The accuracy of Inception model trained on CIFAR10 with different combinations of taking on or off data augmentation and weight decay. (Image source: Table 1 in the original paper) Intrinsic Dimension The number of parameters is not correlated with model overfitting in the field of deep learning, suggesting that parameter counting cannot indicate the true complexity of deep neural networks.\nApart from parameter counting, researchers have proposed many ways to quantify the complexity of these models, such as the number of degrees of freedom of models (Gao \u0026amp; Jojic, 2016), or prequential code (Blier \u0026amp; Ollivier, 2018).\nI would like to discuss a recent method on this matter, named intrinsic dimension (Li et al, 2018). Intrinsic dimension is intuitive, easy to measure, while still revealing many interesting properties of models of different sizes.\nConsidering a neural network with a great number of parameters, forming a high-dimensional parameter space, the learning happens on this high-dimensional objective landscape. The shape of the parameter space manifold is critical. For example, a smoother manifold is beneficial for optimization by providing more predictive gradients and allowing for larger learning rates\u0026mdash;this was claimed to be the reason why batch normalization has succeeded in stabilizing training (Santurkar, et al, 2019).\nEven though the parameter space is huge, fortunately we don\u0026rsquo;t have to worry too much about the optimization process getting stuck in local optima, as it has been shown that local optimal points in the objective landscape almost always lay in saddle-points rather than valleys. In other words, there is always a subset of dimensions containing paths to leave local optima and keep on exploring.\nFig. 6. Illustrations of various types of critical points on the parameter optimization landscape. (Image source: here) One intuition behind the measurement of intrinsic dimension is that, since the parameter space has such high dimensionality, it is probably not necessary to exploit all the dimensions to learn efficiently. If we only travel through a slice of objective landscape and still can learn a good solution, the complexity of the resulting model is likely lower than what it appears to be by parameter-counting. This is essentially what intrinsic dimension tries to assess.\nSay a model has $D$ dimensions and its parameters are denoted as $\\theta^{(D)}$. For learning, a smaller $d$-dimensional subspace is randomly sampled, $\\theta^{(d)}$, where $d \u0026lt; D$. During one optimization update, rather than taking a gradient step according to all $D$ dimensions, only the smaller subspace $\\theta^{(d)}$ is used and remapped to update model parameters.\nFig. 7. Illustration of parameter vectors for direct optimization when $D=3$. (Image source: original paper) The gradient update formula looks like the follows:\n $$ \\theta^{(D)} = \\theta_0^{(D)} + \\mathbf{P} \\theta^{(d)} $$ where $\\theta_0^{(D)}$ are the initialization values and $\\mathbf{P}$ is a $D \\times d$ projection matrix that is randomly sampled before training. Both $\\theta_0^{(D)}$ and $\\mathbf{P}$ are not trainable and fixed during training. $\\theta^{(d)}$ is initialized as all zeros.\nBy searching through the value of $d = 1, 2, \\dots, D$, the corresponding $d$ when the solution emerges is defined as the intrinsic dimension.\nIt turns out many problems have much smaller intrinsic dimensions than the number of parameters. For example, on CIFAR10 image classification, a fully-connected network with 650k+ parameters has only 9k intrinsic dimension and a convolutional network containing 62k parameters has an even lower intrinsic dimension of 2.9k.\nFig. 8. The measured intrinsic dimensions $d$ for various models achieving 90% of the best performance. (Image source: original paper) The measurement of intrinsic dimensions suggests that deep learning models are significantly simpler than what they might appear to be.\nHeterogeneous Layer Robustness Zhang et al. (2019) investigated the role of parameters in different layers. The fundamental question raised by the paper is: \u0026ldquo;are all layers created equal?\u0026quot; The short answer is: No. The model is more sensitive to changes in some layers but not others.\nThe paper proposed two types of operations that can be applied to parameters of the $\\ell$-th layer, $\\ell = 1, \\dots, L$, at time $t$, $\\theta^{(\\ell)}_t$ to test their impacts on model robustness:\n Re-initialization: Reset the parameters to the initial values, $\\theta^{(\\ell)}_t \\leftarrow \\theta^{(\\ell)}_0$. The performance of a network in which layer $\\ell$ was re-initialized is referred to as the re-initialization robustness of layer $\\ell$.\n Re-randomization: Re-sampling the layer\u0026rsquo;s parameters at random, $\\theta^{(\\ell)}_t \\leftarrow \\tilde{\\theta}^{(\\ell)} \\sim \\mathcal{P}^{(\\ell)}$. The corresponding network performance is called the re-randomization robustness of layer $\\ell$.\n Layers can be categorized into two categories with the help of these two operations:\n Robust Layers: The network has no or only negligible performance degradation after re-initializing or re-randomizing the layer. Critical Layers: Otherwise. Similar patterns are observed on fully-connected and convolutional networks. Re-randomizing any of the layers completely destroys the model performance, as the prediction drops to random guessing immediately. More interestingly and surprisingly, when applying re-initialization, only the first or the first few layers (those closest to the input layer) are critical, while re-initializing higher levels causes only negligible decrease in performance.\nFig. 9. (a) A fc network trained on MNIST. Each row corresponds to one layer in the network. The first column is re-randomization robustness of each layer and the rest of the columns indicate re-initialization robustness at different training time. (b) VGG11 model (conv net) trained on CIFAR 10. Similar representation as in (a) but rows and columns are transposed. (Image source: original paper) ResNet is able to use shortcuts between non-adjacent layers to re-distribute the sensitive layers across the networks rather than just at the bottom. With the help of residual block architecture, the network can evenly be robust to re-randomization. Only the first layer of each residual block is still sensitive to both re-initialization and re-randomization. If we consider each residual block as a local sub-network, the robustness pattern resembles the fc and conv nets above.\nFig. 10. Re-randomization (first row) and re-initialization (the reset rows) robustness of layers in ResNet-50 model trained on CIFAR10. (Image source: original paper) Based on the fact that many top layers in deep neural networks are not critical to the model performance after re-initialization, the paper loosely concluded that:\n \u0026ldquo;Over-capacitated deep networks trained with stochastic gradient have low-complexity due to self-restricting the number of critical layers.\u0026rdquo;\n We can consider re-initialization as a way to reduce the effective number of parameters, and thus the observation is aligned with what intrinsic dimension has demonstrated.\nThe Lottery Ticket Hypothesis The lottery ticket hypothesis (Frankle \u0026amp; Carbin, 2019) is another intriguing and inspiring discovery, supporting that only a subset of network parameters have impact on the model performance and thus the network is not overfitted. The lottery ticket hypothesis states that a randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset are \u0026ldquo;winning tickets\u0026rdquo; which can achieve the optimal performance when trained in isolation.\nThe idea is motivated by network pruning techniques \u0026mdash; removing unnecessary weights (i.e. tiny weights that are almost negligible) without harming the model performance. Although the final network size can be reduced dramatically, it is hard to train such a pruned network architecture successfully from scratch. It feels like in order to successfully train a neural network, we need a large number of parameters, but we don\u0026rsquo;t need that many parameters to keep the accuracy high once the model is trained. Why is that?\nThe lottery ticket hypothesis did the following experiments:\n Randomly initialize a dense feed-forward network with initialization values $\\theta_0$; Train the network for multiple iterations to achieve a good performance with parameter config $\\theta$; Run pruning on $\\theta$ and creating a mask $m$. The \u0026ldquo;winning ticket\u0026rdquo; initialization config is $m \\odot \\theta_0$. Only training the small \u0026ldquo;winning ticket\u0026rdquo; subset of parameters with the initial values as found in step 1, the model is able to achieve the same level of accuracy as in step 2. It turns out a large parameter space is not needed in the final solution representation, but needed for training as it provides a big pool of initialization configs of many much smaller subnetworks.\nThe lottery ticket hypothesis opens a new perspective about interpreting and dissecting deep neural network results. Many interesting following-up works are on the way.\nExperiments After seeing all the interesting findings above, it should be pretty fun to reproduce them. Some results are easily to reproduce than others. Details are described below. My code is available on github lilianweng/generalization-experiment.\nNew Risk Curve for DL Models\nThis is the trickiest one to reproduce. The authors did give me a lot of good advice and I appreciate it a lot. Here are a couple of noticeable settings in their experiments:\n There are no regularization terms like weight decay, dropout. In Fig 3, the training set contains 4k samples. It is only sampled once and fixed for all the models. The evaluation uses the full MNIST test set. Each network is trained for a long time to achieve near-zero training risk. The learning rate is adjusted differently for models of different sizes. To make the model less sensitive to the initialization in the under-parameterization region, their experiments adopted a \u0026ldquo;weight reuse\u0026rdquo; scheme: the parameters obtained from training a smaller neural network are used as initialization for training larger networks. I did not train or tune each model long enough to get perfect training performance, but evaluation error indeed shows a special twist around the interpolation threshold, different from training error. For example, for MNIST, the threshold is the number of training samples times the number of classes (10), that is 40000.\nThe x-axis is the number of model parameters: (28 * 28 + 1) * num. units + num. units * 10, in logarithm.\nLayers are not Created Equal\nThis one is fairly easy to reproduce. See my implementation here.\nIn the first experiment, I used a three-layer fc networks with 256 units in each layer. Layer 0 is the input layer while layer 3 is the output. The network is trained on MNIST for 100 epochs.\nIn the second experiment, I used a four-layer fc networks with 128 units in each layer. Other settings are the same as experiment 1.\nIntrinsic Dimension Measurement\nTo correctly map the $d$-dimensional subspace to the full parameter space, the projection matrix $\\mathbf{P}$ should have orthogonal columns. Because the production $\\mathbf{P}\\theta^{(d)}$ is the sum of columns of $\\mathbf{P}$ scaled by corresponding scalar values in the $d$-dim vector, $\\sum_{i=1}^d \\theta^{(d)}_i \\mathbf{P}^\\top_{(:,i)}$, it is better to fully utilize the subspace with orthogonal columns in $\\mathbf{P}$.\nMy implementation follows a naive approach by sampling a large matrix with independent entries from a standard normal distribution. The columns are expected to be independent in a high dimension space and thus to be orthogonal. This works when the dimension is not too large. When exploring with a large $d$, there are methods for creating sparse projection matrices, which is what the intrinsic dimension paper suggested.\nHere are experiment runs on two networks: (left) a two-layer fc network with 64 units in each layer and (right) a one-layer fc network with 128 hidden units, trained on 10% of MNIST. For every $d$, the model is trained for 100 epochs. See the code here.\n Cited as:\n@article{weng2019overfit, title = \u0026quot;Are Deep Neural Networks Dramatically Overfitted?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-03-14-overfit/\u0026quot; } References [1] Wikipedia page on Occam\u0026rsquo;s Razor.\n[2] Occam\u0026rsquo;s Razor on Principia Cybernetica Web.\n[3] Peter Grunwald. \u0026ldquo;A Tutorial Introduction to the Minimum Description Length Principle\u0026rdquo;. 2004.\n[4] Ian Goodfellow, et al. Deep Learning. 2016. Sec 6.4.1.\n[5] Zhang, Chiyuan, et al. \u0026ldquo;Understanding deep learning requires rethinking generalization.\u0026quot; ICLR 2017.\n[6] Shibani Santurkar, et al. \u0026ldquo;How does batch normalization help optimization?.\u0026quot; NIPS 2018.\n[7] Mikhail Belkin, et al. \u0026ldquo;Reconciling modern machine learning and the bias-variance trade-off.\u0026quot; arXiv:1812.11118, 2018.\n[8] Chiyuan Zhang, et al. \u0026ldquo;Are All Layers Created Equal?\u0026quot; arXiv:1902.01996, 2019.\n[9] Chunyuan Li, et al. \u0026ldquo;Measuring the intrinsic dimension of objective landscapes.\u0026quot; ICLR 2018.\n[10] Jonathan Frankle and Michael Carbin. \u0026ldquo;The lottery ticket hypothesis: Finding sparse, trainable neural networks.\u0026quot; ICLR 2019.\n","permalink":"https://lilianweng.github.io/posts/2019-03-14-overfit/","summary":"[Updated on 2019-05-27: add the section on Lottery Ticket Hypothesis.]\nIf you are like me, entering into the field of deep learning with experience in traditional machine learning, you may often ponder over this question: Since a typical deep neural network has so many parameters and training error can easily be perfect, it should surely suffer from substantial overfitting. How could it be ever generalized to out-of-sample data points?\nThe effort in understanding why deep neural networks can generalize somehow reminds me of this interesting paper on System Biology \u0026mdash; \u0026ldquo;Can a biologist fix a radio?","title":"Are Deep Neural Networks Dramatically Overfitted?"},{"content":"[Updated on 2019-02-14: add ULMFiT and GPT-2.] [Updated on 2020-02-29: add ALBERT.] [Updated on 2020-10-25: add RoBERTa.] [Updated on 2020-12-13: add T5.] [Updated on 2020-12-30: add GPT-3.] [Updated on 2021-11-13: add XLNet, BART and ELECTRA; Also updated the Summary section.]\nFig. 0. I guess they are Elmo \u0026 Bert? (Image source: here) We have seen amazing progress in NLP in 2018. Large-scale pre-trained language modes like OpenAI GPT and BERT have achieved great performance on a variety of language tasks using generic model architectures. The idea is similar to how ImageNet classification pre-training helps many vision tasks (*). Even better than vision classification pre-training, this simple and powerful approach in NLP does not require labeled data for pre-training, allowing us to experiment with increased training scale, up to our very limit.\n(*) He et al. (2018) found that pre-training might not be necessary for image segmentation task.\nIn my previous NLP post on word embedding, the introduced embeddings are not context-specific \u0026mdash; they are learned based on word concurrency but not sequential context. So in two sentences, \u0026ldquo;I am eating an apple\u0026rdquo; and \u0026ldquo;I have an Apple phone\u0026rdquo;, two \u0026ldquo;apple\u0026rdquo; words refer to very different things but they would still share the same word embedding vector.\nDespite this, early adoption of word embeddings in problem-solving is to use them as additional features for an existing task-specific model and in a way the improvement is bounded.\nIn this post, we will discuss how various approaches were proposed to make embeddings dependent on context, and to make them easier and cheaper to be applied to downstream tasks in general form.\nCoVe CoVe (McCann et al. 2017), short for Contextual Word Vectors, is a type of word embeddings learned by an encoder in an attentional seq-to-seq machine translation model. Different from traditional word embeddings introduced here, CoVe word representations are functions of the entire input sentence.\nNMT Recap Here the Neural Machine Translation (NMT) model is composed of a standard, two-layer, bidirectional LSTM encoder and an attentional two-layer unidirectional LSTM decoder. It is pre-trained on the English-German translation task. The encoder learns and optimizes the embedding vectors of English words in order to translate them to German. With the intuition that the encoder should capture high-level semantic and syntactic meanings before transforming words into another language, the encoder output is used to provide contextualized word embeddings for various downstream language tasks.\nFig. 1. The NMT base model used in CoVe. A sequence of $n$ words in source language (English): $x = [x_1, \\dots, x_n]$. A sequence of $m$ words in target language (German): $y = [y_1, \\dots, y_m]$. The GloVe vectors of source words: $\\text{GloVe}(x)$. Randomly initialized embedding vectors of target words: $z = [z_1, \\dots, z_m]$. The biLSTM encoder outputs a sequence of hidden states: $h = [h_1, \\dots, h_n] = \\text{biLSTM}(\\text{GloVe}(x))$ and $h_t = [\\overrightarrow{h}_t; \\overleftarrow{h}_t]$ where the forward LSTM computes $\\overrightarrow{h}_t = \\text{LSTM}(x_t, \\overrightarrow{h}_{t-1})$ and the backward computation gives us $\\overleftarrow{h}_t = \\text{LSTM}(x_t, \\overleftarrow{h}_{t-1})$. The attentional decoder outputs a distribution over words: $p(y_t \\mid H, y_1, \\dots, y_{t-1})$ where $H$ is a stack of hidden states $\\{h\\}$ along the time dimension: $$ \\begin{aligned} \\text{decoder hidden state: } s_t \u0026= \\text{LSTM}([z_{t-1}; \\tilde{h}_{t-1}], s_{t-1}) \\\\ \\text{attention weights: } \\alpha_t \u0026= \\text{softmax}(H(W_1 s_t + b_1)) \\\\ \\text{context-adjusted hidden state: } \\tilde{h}_t \u0026= \\tanh(W_2[H^\\top\\alpha_t;s_t] + b_2) \\\\ \\text{decoder output: } p(y_t\\mid H, y_1, \\dots, y_{t-1}) \u0026= \\text{softmax}(W_\\text{out} \\tilde{h}_t + b_\\text{out}) \\end{aligned} $$ Use CoVe in Downstream Tasks The hidden states of NMT encoder are defined as context vectors for other language tasks:\n $$ \\text{CoVe}(x) = \\text{biLSTM}(\\text{GloVe}(x)) $$ The paper proposed to use the concatenation of GloVe and CoVe for question-answering and classification tasks. GloVe learns from the ratios of global word co-occurrences, so it has no sentence context, while CoVe is generated by processing text sequences is able to capture the contextual information.\n $$ v = [\\text{GloVe}(x); \\text{CoVe}(x)] $$ Given a downstream task, we first generate the concatenation of GloVe + CoVe vectors of input words and then feed them into the task-specific models as additional features.\nFig. 2. The CoVe embeddings are generated by an encoder trained for machine translation task. The encoder can be plugged into any downstream task-specific model. (Image source: original paper) Summary: The limitation of CoVe is obvious: (1) pre-training is bounded by available datasets on the supervised translation task; (2) the contribution of CoVe to the final performance is constrained by the task-specific model architecture.\nIn the following sections, we will see that ELMo overcomes issue (1) by unsupervised pre-training and OpenAI GPT \u0026amp; BERT further overcome both problems by unsupervised pre-training + using generative model architecture for different downstream tasks.\nELMo ELMo, short for Embeddings from Language Model (Peters, et al, 2018) learns contextualized word representation by pre-training a language model in an unsupervised way.\nBidirectional Language Model The bidirectional Language Model (biLM) is the foundation for ELMo. While the input is a sequence of $n$ tokens, $(x_1, \\dots, x_n)$, the language model learns to predict the probability of next token given the history.\nIn the forward pass, the history contains words before the target token,\n $$ p(x_1, \\dots, x_n) = \\prod_{i=1}^n p(x_i \\mid x_1, \\dots, x_{i-1}) $$ In the backward pass, the history contains words after the target token,\n $$ p(x_1, \\dots, x_n) = \\prod_{i=1}^n p(x_i \\mid x_{i+1}, \\dots, x_n) $$ The predictions in both directions are modeled by multi-layer LSTMs with hidden states $\\overrightarrow{\\mathbf{h}}_{i,\\ell}$ and $\\overleftarrow{\\mathbf{h}}_{i,\\ell}$ for input token $x_i$ at the layer level $\\ell=1,\\dots,L$. The final layer’s hidden state $\\mathbf{h}_{i,L} = [\\overrightarrow{\\mathbf{h}}_{i,L}; \\overleftarrow{\\mathbf{h}}_{i,L}]$ is used to output the probabilities over tokens after softmax normalization. They share the embedding layer and the softmax layer, parameterized by $\\Theta_e$ and $\\Theta_s$ respectively.\nFig. 3. The biLSTM base model of ELMo. (Image source: recreated based on the figure in [\"Neural Networks, Types, and Functional Programming\"](http://colah.github.io/posts/2015-09-NN-Types-FP/) by Christopher Olah.) The model is trained to minimize the negative log likelihood (= maximize the log likelihood for true words) in both directions:\n $$ \\begin{aligned} \\mathcal{L} = - \\sum_{i=1}^n \\Big( \\log p(x_i \\mid x_1, \\dots, x_{i-1}; \\Theta_e, \\overrightarrow{\\Theta}_\\text{LSTM}, \\Theta_s) + \\\\ \\log p(x_i \\mid x_{i+1}, \\dots, x_n; \\Theta_e, \\overleftarrow{\\Theta}_\\text{LSTM}, \\Theta_s) \\Big) \\end{aligned} $$ ELMo Representations On top of a $L$-layer biLM, ELMo stacks all the hidden states across layers together by learning a task-specific linear combination. The hidden state representation for the token $x_i$ contains $2L+1$ vectors:\n $$ R_i = \\{ \\mathbf{h}_{i,\\ell} \\mid \\ell = 0, \\dots, L \\} $$ where $\\mathbf{h}_{0, \\ell}$ is the embedding layer output and $\\mathbf{h}_{i, \\ell} = [\\overrightarrow{\\mathbf{h}}_{i,\\ell}; \\overleftarrow{\\mathbf{h}}_{i,\\ell}]$.\nThe weights, $\\mathbf{s}^\\text{task}$, in the linear combination are learned for each end task and normalized by softmax. The scaling factor $\\gamma^\\text{task}$ is used to correct the misalignment between the distribution of biLM hidden states and the distribution of task specific representations.\n $$ v_i = f(R_i; \\Theta^\\text{task}) = \\gamma^\\text{task} \\sum_{\\ell=0}^L s^\\text{task}_i \\mathbf{h}_{i,\\ell} $$ To evaluate what kind of information is captured by hidden states across different layers, ELMo is applied on semantic-intensive and syntax-intensive tasks respectively using representations in different layers of biLM:\n Semantic task: The word sense disambiguation (WSD) task emphasizes the meaning of a word given a context. The biLM top layer is better at this task than the first layer. Syntax task: The part-of-speech (POS) tagging task aims to infer the grammatical role of a word in one sentence. A higher accuracy can be achieved by using the biLM first layer than the top layer. The comparison study indicates that syntactic information is better represented at lower layers while semantic information is captured by higher layers. Because different layers tend to carry different type of information, stacking them together helps.\nUse ELMo in Downstream Tasks Similar to how CoVe can help different downstream tasks, ELMo embedding vectors are included in the input or lower levels of task-specific models. Moreover, for some tasks (i.e., SNLI and SQuAD, but not SRL), adding them into the output level helps too.\nThe improvements brought up by ELMo are largest for tasks with a small supervised dataset. With ELMo, we can also achieve similar performance with much less labeled data.\nSummary: The language model pre-training is unsupervised and theoretically the pre-training can be scaled up as much as possible since the unlabeled text corpora are abundant. However, it still has the dependency on task-customized models and thus the improvement is only incremental, while searching for a good model architecture for every task remains non-trivial.\nCross-View Training In ELMo the unsupervised pre-training and task-specific learning happen for two independent models in two separate training stages. Cross-View Training (abbr. CVT; Clark et al., 2018) combines them into one unified semi-supervised learning procedure where the representation of a biLSTM encoder is improved by both supervised learning with labeled data and unsupervised learning with unlabeled data on auxiliary tasks.\nModel Architecture The model consists of a two-layer bidirectional LSTM encoder and a primary prediction module. During training, the model is fed with labeled and unlabeled data batches alternatively.\n On labeled examples, all the model parameters are updated by standard supervised learning. The loss is the standard cross entropy. On unlabeled examples, the primary prediction module still can produce a \u0026ldquo;soft\u0026rdquo; target, even though we cannot know exactly how accurate they are. In a couple of auxiliary tasks, the predictor only sees and processes a restricted view of the input, such as only using encoder hidden state representation in one direction. The auxiliary task outputs are expected to match the primary prediction target for a full view of input. In this way, the encoder is forced to distill the knowledge of the full context into partial representation. At this stage, the biLSTM encoder is backpropagated but the primary prediction module is fixed. The loss is to minimize the distance between auxiliary and primary predictions. Fig. 4. The overview of semi-supervised language model cross-view training. (Image source: original paper) Multi-Task Learning When training for multiple tasks simultaneously, CVT adds several extra primary prediction models for additional tasks. They all share the same sentence representation encoder. During supervised training, once one task is randomly selected, parameters in its corresponding predictor and the representation encoder are updated. With unlabeled data samples, the encoder is optimized jointly across all the tasks by minimizing the differences between auxiliary outputs and primary prediction for every task.\nThe multi-task learning encourages better generality of representation and in the meantime produces a nice side-product: all-tasks-labeled examples from unlabeled data. They are precious data labels considering that cross-task labels are useful but fairly rare.\nUse CVT in Downstream Tasks Theoretically the primary prediction module can take any form, generic or task-specific design. The examples presented in the CVT paper include both cases.\nIn sequential tagging tasks (classification for every token) like NER or POS tagging, the predictor module contains two fully connected layers and a softmax layer on the output to produce a probability distribution over class labels. For each token $\\mathbf{x}_i$, we take the corresponding hidden states in two layers, $\\mathbf{h}_1^{(i)}$ and $\\mathbf{h}_2^{(i)}$:\n $$ \\begin{aligned} p_\\theta(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}(\\mathbf{h}^{(i)}) \\\\ \u0026= \\text{NN}([\\mathbf{h}_1^{(i)}; \\mathbf{h}_2^{(i)}]) \\\\ \u0026= \\text{softmax} \\big( \\mathbf{W}\\cdot\\text{ReLU}(\\mathbf{W'}\\cdot[\\mathbf{h}_1^{(i)}; \\mathbf{h}_2^{(i)}]) + \\mathbf{b} \\big) \\end{aligned} $$ The auxiliary tasks are only fed with forward or backward LSTM state in the first layer. Because they only observe partial context, either on the left or right, they have to learn like a language model, trying to predict the next token given the context. The fwd and bwd auxiliary tasks only take one direction. The future and past tasks take one step further in forward and backward direction, respectively.\n $$ \\begin{aligned} p_\\theta^\\text{fwd}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{fwd}(\\overrightarrow{\\mathbf{h}}^{(i)}) \\\\ p_\\theta^\\text{bwd}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{bwd}(\\overleftarrow{\\mathbf{h}}^{(i)}) \\\\ p_\\theta^\\text{future}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{future}(\\overrightarrow{\\mathbf{h}}^{(i-1)}) \\\\ p_\\theta^\\text{past}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{past}(\\overleftarrow{\\mathbf{h}}^{(i+1)}) \\end{aligned} $$ Fig. 5. The sequential tagging task depends on four auxiliary prediction models, their inputs only involving hidden states in one direction: forward, backward, future and past. (Image source: original paper) Note that if the primary prediction module has dropout, the dropout layer works as usual when training with labeled data, but it is not applied when generating \u0026ldquo;soft\u0026rdquo; target for auxiliary tasks during training with unlabeled data.\nIn the machine translation task, the primary prediction module is replaced with a standard unidirectional LSTM decoder with attention. There are two auxiliary tasks: (1) apply dropout on the attention weight vector by randomly zeroing out some values; (2) predict the future word in the target sequence. The primary prediction for auxiliary tasks to match is the best predicted target sequence produced by running the fixed primary decoder on the input sequence with beam search.\nULMFiT The idea of using generative pretrained LM + task-specific fine-tuning was first explored in ULMFiT (Howard \u0026amp; Ruder, 2018), directly motivated by the success of using ImageNet pre-training for computer vision tasks. The base model is AWD-LSTM.\nULMFiT follows three steps to achieve good transfer learning results on downstream language classification tasks:\n General LM pre-training: on Wikipedia text.\n Target task LM fine-tuning: ULMFiT proposed two training techniques for stabilizing the fine-tuning process. See below.\n Discriminative fine-tuning is motivated by the fact that different layers of LM capture different types of information (see discussion above). ULMFiT proposed to tune each layer with different learning rates, $\\{\\eta^1, \\dots, \\eta^\\ell, \\dots, \\eta^L\\}$, where $\\eta$ is the base learning rate for the first layer, $\\eta^\\ell$ is for the $\\ell$-th layer and there are $L$ layers in total.\n Slanted triangular learning rates (STLR) refer to a special learning rate scheduling that first linearly increases the learning rate and then linearly decays it. The increase stage is short so that the model can converge to a parameter space suitable for the task fast, while the decay period is long allowing for better fine-tuning.\n Target task classifier fine-tuning: The pretrained LM is augmented with two standard feed-forward layers and a softmax normalization at the end to predict a target label distribution. Concat pooling extracts max-polling and mean-pooling over the history of hidden states and concatenates them with the final hidden state.\n Gradual unfreezing helps to avoid catastrophic forgetting by gradually unfreezing the model layers starting from the last one. First the last layer is unfrozen and fine-tuned for one epoch. Then the next lower layer is unfrozen. This process is repeated until all the layers are tuned.\n Fig. 6. Three training stages of ULMFiT. (Image source: original paper) GPT Following the similar idea of ELMo, OpenAI GPT, short for Generative Pre-training Transformer (Radford et al., 2018), expands the unsupervised language model to a much larger scale by training on a giant collection of free text corpora. Despite of the similarity, GPT has two major differences from ELMo.\n The model architectures are different: ELMo uses a shallow concatenation of independently trained left-to-right and right-to-left multi-layer LSTMs, while GPT is a multi-layer transformer decoder. The use of contextualized embeddings in downstream tasks are different: ELMo feeds embeddings into models customized for specific tasks as additional features, while GPT fine-tunes the same base model for all end tasks. Transformer Decoder as Language Model Compared to the original transformer architecture, the transformer decoder model discards the encoder part, so there is only one single input sentence rather than two separate source and target sequences.\nThis model applies multiple transformer blocks over the embeddings of input sequences. Each block contains a masked multi-headed self-attention layer and a pointwise feed-forward layer. The final output produces a distribution over target tokens after softmax normalization.\nFig. 7. The transformer decoder model architecture in OpenAI GPT. The loss is the negative log-likelihood, same as ELMo, but without backward computation. Let’s say, the context window of the size $k$ is located before the target word and the loss would look like:\n $$ \\mathcal{L}_\\text{LM} = -\\sum_{i} \\log p(x_i\\mid x_{i-k}, \\dots, x_{i-1}) $$ Byte Pair Encoding Byte Pair Encoding (BPE) is used to encode the input sequences. BPE was originally proposed as a data compression algorithm in 1990s and then was adopted to solve the open-vocabulary issue in machine translation, as we can easily run into rare and unknown words when translating into a new language. Motivated by the intuition that rare and unknown words can often be decomposed into multiple subwords, BPE finds the best word segmentation by iteratively and greedily merging frequent pairs of characters.\nSupervised Fine-Tuning The most substantial upgrade that OpenAI GPT proposed is to get rid of the task-specific model and use the pre-trained language model directly!\nLet’s take classification as an example. Say, in the labeled dataset, each input has $n$ tokens, $\\mathbf{x} = (x_1, \\dots, x_n)$, and one label $y$. GPT first processes the input sequence $\\mathbf{x}$ through the pre-trained transformer decoder and the last layer output for the last token $x_n$ is $\\mathbf{h}_L^{(n)}$. Then with only one new trainable weight matrix $\\mathbf{W}_y$, it can predict a distribution over class labels.\n $$ P(y\\mid x_1, \\dots, x_n) = \\text{softmax}(\\mathbf{h}_L^{(n)}\\mathbf{W}_y) $$ The loss is to minimize the negative log-likelihood for true labels. In addition, adding the LM loss as an auxiliary loss is found to be beneficial, because:\n (1) it helps accelerate convergence during training and (2) it is expected to improve the generalization of the supervised model. $$ \\begin{aligned} \\mathcal{L}_\\text{cls} \u0026= \\sum_{(\\mathbf{x}, y) \\in \\mathcal{D}} \\log P(y\\mid x_1, \\dots, x_n) = \\sum_{(\\mathbf{x}, y) \\in \\mathcal{D}} \\log \\text{softmax}(\\mathbf{h}_L^{(n)}(\\mathbf{x})\\mathbf{W}_y) \\\\ \\mathcal{L}_\\text{LM} \u0026= -\\sum_{i} \\log p(x_i\\mid x_{i-k}, \\dots, x_{i-1}) \\\\ \\mathcal{L} \u0026= \\mathcal{L}_\\text{cls} + \\lambda \\mathcal{L}_\\text{LM} \\end{aligned} $$ With similar designs, no customized model structure is needed for other end tasks (see Fig. 7). If the task input contains multiple sentences, a special delimiter token ($) is added between each pair of sentences. The embedding for this delimiter token is a new parameter we need to learn, but it should be pretty minimal.\nFor the sentence similarity task, because the ordering does not matter, both orderings are included. For the multiple choice task, the context is paired with every answer candidate.\nFig. 8. Training objects in slightly modified GPT transformer models for downstream tasks. (Image source: original paper) Summary: It is super neat and encouraging to see that such a general framework is capable to beat SOTA on most language tasks at that time (June 2018). At the first stage, generative pre-training of a language model can absorb as much free text as possible. Then at the second stage, the model is fine-tuned on specific tasks with a small labeled dataset and a minimal set of new parameters to learn.\nOne limitation of GPT is its uni-directional nature \u0026mdash; the model is only trained to predict the future left-to-right context.\nBERT BERT, short for Bidirectional Encoder Representations from Transformers (Devlin, et al., 2019) is a direct descendant to GPT: train a large language model on free text and then fine-tune on specific tasks without customized network architectures.\nCompared to GPT, the largest difference and improvement of BERT is to make training bi-directional. The model learns to predict both context on the left and right. The paper according to the ablation study claimed that:\n \u0026ldquo;bidirectional nature of our model is the single most important new contribution\u0026rdquo;\n Pre-training Tasks The model architecture of BERT is a multi-layer bidirectional Transformer encoder.\nFig. 9. Recap of Transformer Encoder model architecture. (Image source: Transformer paper) To encourage the bi-directional prediction and sentence-level understanding, BERT is trained with two tasks instead of the basic language task (that is, to predict the next token given context).\n*Task 1: Mask language model (MLM)\n From Wikipedia: \u0026ldquo;A cloze test (also cloze deletion test) is an exercise, test, or assessment consisting of a portion of language with certain items, words, or signs removed (cloze text), where the participant is asked to replace the missing language item. … The exercise was first described by W.L. Taylor in 1953.\u0026rdquo;\n It is unsurprising to believe that a representation that learns the context around a word rather than just after the word is able to better capture its meaning, both syntactically and semantically. BERT encourages the model to do so by training on the \u0026ldquo;mask language model\u0026rdquo; task:\n Randomly mask 15% of tokens in each sequence. Because if we only replace masked tokens with a special placeholder [MASK], the special token would never be encountered during fine-tuning. Hence, BERT employed several heuristic tricks: (a) with 80% probability, replace the chosen words with [MASK]; (b) with 10% probability, replace with a random word; (c) with 10% probability, keep it the same. The model only predicts the missing words, but it has no information on which words have been replaced or which words should be predicted. The output size is only 15% of the input size. Task 2: Next sentence prediction\nMotivated by the fact that many downstream tasks involve the understanding of relationships between sentences (i.e., QA, NLI), BERT added another auxiliary task on training a binary classifier for telling whether one sentence is the next sentence of the other:\n Sample sentence pairs (A, B) so that: (a) 50% of the time, B follows A; (b) 50% of the time, B does not follow A. The model processes both sentences and output a binary label indicating whether B is the next sentence of A. The training data for both auxiliary tasks above can be trivially generated from any monolingual corpus. Hence the scale of training is unbounded. The training loss is the sum of the mean masked LM likelihood and mean next sentence prediction likelihood.\nFig. 10. Comparison of BERT, OpenAI GPT and ELMo model architectures. (Image source: original paper) Input Embedding The input embedding is the sum of three parts:\n WordPiece tokenization embeddings: The WordPiece model was originally proposed for Japanese or Korean segmentation problem. Instead of using naturally split English word, they can be further divided into smaller sub-word units so that it is more effective to handle rare or unknown words. Please read linked papers for the optimal way to split words if interested. Segment embeddings: If the input contains two sentences, they have sentence A embeddings and sentence B embeddings respectively and they are separated by a special character [SEP]; Only sentence A embeddings are used if the input only contains one sentence. Position embeddings: Positional embeddings are learned rather than hard-coded. Fig. 11. BERT input representation. (Image source: original paper) Note that the first token is always forced to be [CLS] \u0026mdash; a placeholder that will be used later for prediction in downstream tasks.\nUse BERT in Downstream Tasks BERT fine-tuning requires only a few new parameters added, just like OpenAI GPT.\nFor classification tasks, we get the prediction by taking the final hidden state of the special first token [CLS], $\\mathbf{h}^\\text{[CLS]}_L$, and multiplying it with a small weight matrix, $\\text{softmax}(\\mathbf{h}^\\text{[CLS]}_L \\mathbf{W}_\\text{cls})$.\nFor QA tasks like SQuAD, we need to predict the text span in the given paragraph for an given question. BERT predicts two probability distributions of every token, being the start and the end of the text span. Only two new small matrices, $\\mathbf{W}_\\text{s}$ and $\\mathbf{W}_\\text{e}$, are newly learned during fine-tuning and $\\text{softmax}(\\mathbf{h}^\\text{(i)}_L \\mathbf{W}_\\text{s})$ and $\\text{softmax}(\\mathbf{h}^\\text{(i)}_L \\mathbf{W}_\\text{e})$ define two probability distributions.\nOverall the add-on part for end task fine-tuning is very minimal \u0026mdash; one or two weight matrices to convert the Transform hidden states to an interpretable format. Check the paper for implementation details for other cases.\nFig. 12. Training objects in slightly modified BERT models for downstream tasks. (Image source: original paper) A summary table compares differences between fine-tuning of OpenAI GPT and BERT.\n| | OpenAI GPT | BERT | | Special char | [SEP] and [CLS] are only introduced at fine-tuning stage. | [SEP] and [CLS] and sentence A/B embeddings are learned at the pre-training stage. | | Training process | 1M steps, batch size 32k words. | 1M steps, batch size 128k words. | | Fine-tuning | lr = 5e-5 for all fine-tuning tasks. | Use task-specific lr for fine-tuning. |\nALBERT ALBERT (Lan, et al. 2019), short for A Lite BERT, is a light-weighted version of BERT model. An ALBERT model can be trained 1.7x faster with 18x fewer parameters, compared to a BERT model of similar configuration. ALBERT incorporates three changes as follows: the first two help reduce parameters and memory consumption and hence speed up the training speed, while the third one proposes a more chanllenging training task to replace the next sentence prediction (NSP) objective.\nFactorized Embedding Parameterization In BERT, the WordPiece tokenization embedding size $E$ is configured to be the same as the hidden state size $H$. That is saying, if we want to increase the model size (larger $H$), we need to learn a larger tokenization embedding too, which is expensive because it depends on the vocabulary size ($V$).\nConceptually, because the tokenization embedding is expected to learn context-independent representation and the hidden states are context-dependent, it makes sense to separate the size of the hidden layers from the size of vocabulary embedding. Using factorized embedding parameterization, the large vocabulary embedding matrix of size $V \\times H$ is decomposed into two small matrices of size $V \\times E$ and $E \\times H$. Given $H \\gt E$ or even $H \\gg E$, factorization can result in significant parameter reduction.\nCross-layer Parameter Sharing Parameter sharing across layers can happen in many ways: (a) only share feed-forward part; (b) only share attention parameters; or (c) share all the parameters. This technique reduces the number of parameters by a ton and does not damage the performance too much.\nSentence-Order Prediction (SOP) Interestingly, the next sentence prediction (NSP) task of BERT turned out to be too easy. ALBERT instead adopted a sentence-order prediction (SOP) self-supervised loss,\n Positive sample: two consecutive segments from the same document. Negative sample: same as above, but the segment order is switched. For the NSP task, the model can make reasonable predictions if it is able to detect topics when A and B are from different contexts. In comparison, SOP is harder as it requires the model to fully understand the coherence and ordering between segments.\nGPT-2 The OpenAI GPT-2 language model is a direct successor to GPT. GPT-2 has 1.5B parameters, 10x more than the original GPT, and it achieves SOTA results on 7 out of 8 tested language modeling datasets in a zero-shot transfer setting without any task-specific fine-tuning. The pre-training dataset contains 8 million Web pages collected by crawling qualified outbound links from Reddit. Large improvements by OpenAI GPT-2 are specially noticeable on small datasets and datasets used for measuring long-term dependency.\nZero-Shot Transfer The pre-training task for GPT-2 is solely language modeling. All the downstream language tasks are framed as predicting conditional probabilities and there is no task-specific fine-tuning.\n Text generation is straightforward using LM. Machine translation task, for example, English to Chinese, is induced by conditioning LM on pairs of \u0026ldquo;English sentence = Chinese sentence\u0026rdquo; and \u0026ldquo;the target English sentence =\u0026rdquo; at the end. For example, the conditional probability to predict might look like: P(? | I like green apples. = 我喜欢绿苹果。 A cat meows at him. = 一只猫对他喵。It is raining cats and dogs. =\u0026quot;) QA task is formatted similar to translation with pairs of questions and answers in the context. Summarization task is induced by adding TL;DR: after the articles in the context. BPE on Byte Sequences Same as the original GPT, GPT-2 uses BPE but on UTF-8 byte sequences. Each byte can represent 256 different values in 8 bits, while UTF-8 can use up to 4 bytes for one character, supporting up to $2^{31}$ characters in total. Therefore, with byte sequence representation we only need a vocabulary of size 256 and do not need to worry about pre-processing, tokenization, etc. Despite of the benefit, current byte-level LMs still have non-negligible performance gap with the SOTA word-level LMs.\nBPE merges frequently co-occurred byte pairs in a greedy manner. To prevent it from generating multiple versions of common words (i.e. dog., dog! and dog? for the word dog), GPT-2 prevents BPE from merging characters across categories (thus dog would not be merged with punctuations like ., ! and ?). This tricks help increase the quality of the final byte segmentation.\nUsing the byte sequence representation, GPT-2 is able to assign a probability to any Unicode string, regardless of any pre-processing steps.\nModel Modifications Compared to GPT, other than having many more transformer layers and parameters, GPT-2 incorporates only a few architecture modifications:\n Layer normalization was moved to the input of each sub-block, similar to a residual unit of type \u0026ldquo;building block\u0026rdquo; (differently from the original type \u0026ldquo;bottleneck\u0026rdquo;, it has batch normalization applied before weight layers). An additional layer normalization was added after the final self-attention block. A modified initialization was constructed as a function of the model depth. The weights of residual layers were initially scaled by a factor of $1/ \\sqrt{N}$ where N is the number of residual layers. Use larger vocabulary size and context size. RoBERTa RoBERTa (short for Robustly optimized BERT approach; Liu, et al. 2019) refers to a new receipt for training BERT to achieve better results, as they found that the original BERT model is significantly undertrained. The receipt contains the following learnings:\n Train for longer with bigger batch size. Remove the next sentence prediction (NSP) task. Use longer sequences in training data format. The paper found that using individual sentences as inputs hurts downstream performance. Instead we should use multiple sentences sampled contiguously to form longer segments. Change the masking pattern dynamically. The original BERT applies masking once during the data preprocessing stage, resulting in a static mask across training epochs. RoBERTa applies masks in 10 different ways across 40 epochs. RoBERTa also added a new dataset CommonCrawl News and further confirmed that pretraining with more data helps improve the performance on downstream tasks. It was trained with the BPE on byte sequences, same as in GPT-2. They also found that choices of hyperparameters have a big impact on the model performance.\nT5 The language model T5 is short for \u0026ldquo;Text-to-Text Transfer Transformer\u0026rdquo; (Raffel et al., 2020). The encoder-decoder implementation follows the original Transformer architecture: tokens → embedding → encoder → decoder → output. T5 adopts the framework “Natural Language Decathlon” (McCann et al., 2018), where many common NLP tasks are translated into question-answering over a context. Instead of an explicit QA format, T5 uses short task prefixes to distinguish task intentions and separately fine-tunes the model on every individual task. The text-to-text framework enables easier transfer learning evaluation with the same model on a diverse set of tasks.\nFig. 13. A diagram of T5 task evaluation. The text-to-text framework casts every task into a generic form: feeding input text to predict some target text. (Image source: Raffel et al., 2020) The model is trained on Web corpus extracted from Apr 2019 with various filters applied. The model is fine-tuned for each downstream task separately via \u0026ldquo;adapter layers\u0026rdquo; (add an extra layer for training) or \u0026ldquo;gradual unfreezing\u0026rdquo; (see ULMFiT). Both fine-tuning approaches only update partial parameters while keeping the majority of the model parameters unchanged. T5-11B achieved SOTA results on many NLP tasks.\nAs the authors mentioned in the paper \u0026ldquo;\u0026hellip;our goal is not to propose new methods but instead to provide a comprehensive perspective on where the field stands\u0026rdquo;, the T5 long paper described a lot of training setup and evaluation processes in detail, a good read for people who are interested in training a LM from scratch.\nGPT-3 GPT-3 (Brown et al., 2020) has the same architecture as GPT-2 but contains 175B parameters, 10x larger than GPT-2 (1.5B). In addition, GPT-3 uses alternating dense and locally banded sparse attention patterns, same as in sparse transformer. In order to fit such a huge model across multiple GPUs, GPT-3 is trained with partitions along both width and depth dimension. The training data is a filtered version of Common Crawl mixed with a few other high-quality curated datasets. To avoid the contamination that downstream tasks might appear in the training data, the authors attempted to remove all the overlaps with all the studied benchmark dataset from the training dataset. Unfortunately the filtering process is not perfect due to a bug.\nFig. 14. Training datasets for GPT-3. Note that the occurrence of each dataset during training is not proportional to the dataset size. (Table source: Brown et al., 2020) For all the downstream evaluation, GPT-3 is tested in the few-shot setting without any gradient-based fine-tuning. Here the few-shot examples are provided as part of the prompt. GPT-3 achieves strong performance on many NLP datasets, comparable with fine-tuned BERT models.\nFig. 15. The evaluation performance increases with the model size and the number of examples. (Image source: Brown et al., 2020) XLNet The Autoregressive (AR) model such as GPT and autoencoder (AE) model such as BERT are two most common ways for language modeling. However, each has their own disadvantages: AR does not learn the bidirectional context, which is needed by downstream tasks like reading comprehension and AE assumes masked positions are independent given all other unmasked tokens which oversimplifies the long context dependency.\nXLNet (Yang et al. 2019) generalizes the AE method to incorporate the benefits of AR. XLNet proposed the permutation language modeling objective. For a text sequence, it samples a factorization order $\\mathbf{z}$ and decomposes the likelihood $p_\\theta(\\mathbf{x})$ according to this factorization order,\n $$ \\begin{aligned} \\mathcal{L}_\\text{XLNet} \u0026= - \\mathbb{E}_{\\mathbf{z} \\sim \\mathcal{Z}_T} \\Big[ \\sum_{t=1}^T \\log p_\\theta (X_{z_t} = x \\mid \\mathbf{x}_{\\mathbf{z}_{where $\\mathcal{Z}_T$ is a set of all possible permutation of length $T$; $z_t$ and $\\mathbf{z}_{\u0026lt;t}$ denote the $t$-th element and the first $t-1$ elements of a permutation $\\mathbf{z} \\in \\mathcal{Z}_T$.\nNote that the naive representation of the hidden state of the context, $h_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}})$ in red, does not depend on which position the model tries to predict, as the permutation breaks the default ordering. Therefore, XLNet re-parameterized it to a function of the target position too, $g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ in blue.\nHowever, two different requirements on $g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ lead to a two-stream self-attention design to accommodate:\n When predicting $x_{z_t}$, it should only encode the position $z_t$ but not the content $x_{z_t}$; otherwise it is trivial. This is wrapped into the \u0026ldquo;query representation\u0026rdquo; $g_{z_t} = g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ does not encode $x_{z_t}$. When predicting $x_j$ where $j \u0026gt; t$, it should encode the content $x_{z_t}$ as well to provide the full context. This is the \u0026ldquo;content representation\u0026rdquo; $h_{z_t} = h_\\theta(\\mathbf{x}_{\\leq t})$. Fig. 16. The illustration of two-stream self-attention mechanism in XLNet. (Image source: Yang et al. 2019) Conceptually, the two streams of representations are updated as follows,\n $$ \\begin{aligned} g_{z_t}^{(m)} \u0026\\gets \\text{Attention}(Q = g^{(m-1)}_{z_t}, KV=\\mathbf{h}^{(m-1)}_{\\color{red}{\\mathbf{z}_{Given the difficulty of optimization in permutation language modeling, XLNet is set to only predict the last chunk of tokens in a factorization order.\nThe name in XLNet actually comes from Transformer-XL. It incorporates the design of Transformer-XL to extend the attention span by reusing hidden states from previous segments.\nFig. 17. Comparison of model performance of XLNet with a couple other language models on GLUE, all single-task, no ensembles. (Image source: Yang et al. 2019) BART BART (Lewis et al., 2019) is a denoising autoencoder to recover the original text from a randomly corrupted version. It combines Bidirectional and AutoRegressive Transformer: precisely, jointly training BERT-like bidirectional encoder and GPT-like autoregressive decoder together. The loss is simply just to minimize the negative log-likelihood.\nFig. 18. A schematic comparison of BART with BERT and GPT. (Image source: Lewis et al., 2019) They experimented with a variety of noising transformations, including token masking, token deletion, text infilling (i.e. A randomly sampled text span, which may contain multiple tokens, is replaced with a [MASK] token), sentence permutation, documentation rotation (i.e. A document is rotated to begin with a random token.). The best noising approach they discovered is text infilling and sentence shuffling.\nFig. 19. Comparison of different language modeling pre-training objectives. (Image source: Lewis et al., 2019) Learnings from their experiments:\n The performance of pre-training methods varies significantly across downstream tasks. Token masking is crucial, as the performance is poor when only sentence permutation or documentation rotation is applied. Left-to-right pre-training improves generation. Bidirectional encoders are crucial for SQuAD. The pre-training objective is not the only important factor. Architectural improvements such as relative-position embeddings or segment-level recurrence matter too. Autoregressive language models perform best on ELI5. BART achieves the most consistently strong performance. ELECTRA Most current pre-training large language models demand a lot of computation resources, raising concerns about their cost and accessibility. ELECTRA (\u0026ldquo;Efficiently Learning an Encoder that Classifies Token Replacements Accurately\u0026rdquo;; Clark et al. 2020) aims to improve the pre-training efficiency, which frames the language modeling as a discrimination task instead of generation task.\nFig. 20. Illustration of ELECTRA model architecture. (Image source: Clark et al. 2020) ELECTRA proposes a new pretraining task, called \u0026ldquo;Replaced Token Detection\u0026rdquo; (RTD). Let\u0026rsquo;s randomly sample $k$ positions to be masked. Each selected token in the original text is replaced by a plausible alternative predicted by a small language model, known as the generator $G$. The discriminator $D$ predicts whether each token is original or replaced.\n $$ \\begin{aligned} \\boldsymbol{m} \u0026= [m_1, \\dots, m_k] \\text{ where } m_i \\sim \\text{unif}\\{1, n\\}\\text{ for } i=1, \\dots, k \\\\ \\boldsymbol{x}^\\text{masked} \u0026= \\text{REPLACE}(\\boldsymbol{x}, \\boldsymbol{m}, \\texttt{[MASK]}) \\\\ \\boldsymbol{x}^\\text{corrupt} \u0026= \\text{REPLACE}(\\boldsymbol{x}, \\boldsymbol{m}, \\tilde{\\boldsymbol{x}}) \\text{ where } \\tilde{x}_t \\sim p_G(x_i \\mid \\boldsymbol{x}^\\text{masked}) \\text{ for } i \\in \\boldsymbol{m} \\\\ \\end{aligned} $$ The loss for the generator is the negative log-likelihood just as in other language models. The loss for the discriminator is the cross-entropy. Note that the generator is not adversarially trained to fool the discriminator but simply to optimize the NLL, since their experiments show negative results.\n $$ \\begin{aligned} \\mathcal{L}_\\text{MLM}(\\mathbf{x}, \\theta_G) \u0026= \\mathbb{E}\\Big(\\sum_{i \\in \\boldsymbol{m}} -\\log p_G (x_i \\mid \\boldsymbol{x}^\\text{masked} )\\Big) \\\\ \\mathcal{L}_\\text{Disc}(\\mathbf{x}, \\theta_D) \u0026= \\mathbb{E}\\Big( - \\mathbb{1}[x^\\text{corrupt}_t = x_t] \\log D(\\boldsymbol{x}^\\text{corrupt}, t) - \\mathbb{1}[x^\\text{corrupt}_t \\neq x_t] \\log (1 - \\log D(\\boldsymbol{x}^\\text{corrupt}, t)) \\Big) \\end{aligned} $$ They found it more beneficial to only share the embeddings between generator \u0026amp; discriminator while using a small generator (1/4 to 1/2 the discriminator size), rather than sharing all the weights (i.e. two models have to be the same size then). In addition, joint training of the generator and discriminator works better than two-stage training of each alternatively.\nAfter pretraining the generator is discarded and only the ELECTRA discriminator is fine-tuned further for downstream tasks. The following table shows ELECTRA\u0026rsquo;s performance on the GLUE dev set.\nFig. 21. Comparison of ELECTRA with other language models on the GLUE dev set. (Image source: Clark et al. 2020) Summary Base model Pretraining Tasks CoVe seq2seq NMT model supervised learning using translation dataset. ELMo two-layer biLSTM next token prediction CVT two-layer biLSTM semi-supervised learning using both labeled and unlabeled datasets ULMFiT AWD-LSTM autoregressive pretraining on Wikitext-103 GPT Transformer decoder next token prediction BERT Transformer encoder mask language model + next sentence prediction ALBERT same as BERT but light-weighted mask language model + sentence order prediction GPT-2 Transformer decoder next token prediction RoBERTa same as BERT mask language model (dynamic masking) T5 Transformer encoder + decoder pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which each task is converted into a text-to-text format. GPT-3 Transformer decoder next token prediction XLNet same as BERT permutation language modeling BART BERT encoder + GPT decoder reconstruct text from a noised version ELECTRA same as BERT replace token detection Metric: Perplexity Perplexity is often used as an intrinsic evaluation metric for gauging how well a language model can capture the real word distribution conditioned on the context.\nA perplexity of a discrete proability distribution $p$ is defined as the exponentiation of the entropy:\n $$ 2^{H(p)} = 2^{-\\sum_x p(x) \\log_2 p(x)} $$ Given a sentence with $N$ words, $s = (w_1, \\dots, w_N)$, the entropy looks as follows, simply assuming that each word has the same frequency, $\\frac{1}{N}$:\n $$ H(s) = -\\sum_{i=1}^N P(w_i) \\log_2 p(w_i) = -\\sum_{i=1}^N \\frac{1}{N} \\log_2 p(w_i) $$ The perplexity for the sentence becomes:\n $$ \\begin{aligned} 2^{H(s)} \u0026= 2^{-\\frac{1}{N} \\sum_{i=1}^N \\log_2 p(w_i)} = (2^{\\sum_{i=1}^N \\log_2 p(w_i)})^{-\\frac{1}{N}} = (p(w_1) \\dots p(w_N))^{-\\frac{1}{N}} \\end{aligned} $$ A good language model should predict high word probabilities. Therefore, the smaller perplexity the better.\nCommon Tasks and Datasets Question-Answering\n SQuAD (Stanford Question Answering Dataset): A reading comprehension dataset, consisting of questions posed on a set of Wikipedia articles, where the answer to every question is a span of text. RACE (ReAding Comprehension from Examinations): A large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students. See more QA datasets in a later post. Commonsense Reasoning\n Story Cloze Test: A commonsense reasoning framework for evaluating story understanding and generation. The test requires a system to choose the correct ending to multi-sentence stories from two options. SWAG (Situations With Adversarial Generations): multiple choices; contains 113k sentence-pair completion examples that evaluate grounded common-sense inference Natural Language Inference (NLI): also known as Text Entailment, an exercise to discern in logic whether one sentence can be inferred from another.\n RTE (Recognizing Textual Entailment): A set of datasets initiated by text entailment challenges. SNLI (Stanford Natural Language Inference): A collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels entailment, contradiction, and neutral. MNLI (Multi-Genre NLI): Similar to SNLI, but with a more diverse variety of text styles and topics, collected from transcribed speech, popular fiction, and government reports. QNLI (Question NLI): Converted from SQuAD dataset to be a binary classification task over pairs of (question, sentence). SciTail: An entailment dataset created from multiple-choice science exams and web sentences. Named Entity Recognition (NER): labels sequences of words in a text which are the names of things, such as person and company names, or gene and protein names\n CoNLL 2003 NER task: consists of newswire from the Reuters, concentrating on four types of named entities: persons, locations, organizations and names of miscellaneous entities. OntoNotes 5.0: This corpus contains text in English, Arabic and Chinese, tagged with four different entity types (PER, LOC, ORG, MISC). Reuters Corpus: A large collection of Reuters News stories. Fine-Grained NER (FGN) Sentiment Analysis\n SST (Stanford Sentiment Treebank) IMDb: A large dataset of movie reviews with binary sentiment classification labels. Semantic Role Labeling (SRL): models the predicate-argument structure of a sentence, and is often described as answering \u0026ldquo;Who did what to whom\u0026rdquo;.\n CoNLL-2004 \u0026amp; CoNLL-2005 Sentence similarity: also known as paraphrase detection\n MRPC (MicRosoft Paraphrase Corpus): It contains pairs of sentences extracted from news sources on the web, with annotations indicating whether each pair is semantically equivalent. QQP (Quora Question Pairs) STS Benchmark: Semantic Textual Similarity Sentence Acceptability: a task to annotate sentences for grammatical acceptability.\n CoLA (Corpus of Linguistic Acceptability): a binary single-sentence classification task. Text Chunking: To divide a text in syntactically correlated parts of words.\n CoNLL-2000 Part-of-Speech (POS) Tagging: tag parts of speech to each token, such as noun, verb, adjective, etc. the Wall Street Journal portion of the Penn Treebank (Marcus et al., 1993).\nMachine Translation: See Standard NLP page.\n WMT 2015 English-Czech data (Large) WMT 2014 English-German data (Medium) IWSLT 2015 English-Vietnamese data (Small) Coreference Resolution: cluster mentions in text that refer to the same underlying real world entities.\n CoNLL-2012 Long-range Dependency\n LAMBADA (LAnguage Modeling Broadened to Account for Discourse Aspects): A collection of narrative passages extracted from the BookCorpus and the task is to predict the last word, which require at least 50 tokens of context for a human to successfully predict. Children’s Book Test: is built from books that are freely available in Project Gutenberg. The task is to predict the missing word among 10 candidates. Multi-task benchmark\n GLUE multi-task benchmark: https://gluebenchmark.com decaNLP benmark: https://decanlp.com Unsupervised pretraining dataset\n Books corpus: The corpus contains \u0026ldquo;over 7,000 unique unpublished books from a variety of genres including Adventure, Fantasy, and Romance.\u0026rdquo; 1B Word Language Model Benchmark English Wikipedia: ~2500M words Cited as:\n@article{weng2019LM, title = \u0026quot;Generalized Language Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-01-31-lm/\u0026quot; } Reference [1] Bryan McCann, et al. \u0026ldquo;Learned in translation: Contextualized word vectors.\u0026quot; NIPS. 2017.\n[2] Kevin Clark et al. \u0026ldquo;Semi-Supervised Sequence Modeling with Cross-View Training.\u0026quot; EMNLP 2018.\n[3] Matthew E. Peters, et al. \u0026ldquo;Deep contextualized word representations.\u0026quot; NAACL-HLT 2017.\n[4] OpenAI Blog \u0026ldquo;Improving Language Understanding with Unsupervised Learning\u0026rdquo;, June 11, 2018.\n[5] OpenAI Blog \u0026ldquo;Better Language Models and Their Implications.\u0026quot; Feb 14, 2019.\n[6] Jeremy Howard and Sebastian Ruder. \u0026ldquo;Universal language model fine-tuning for text classification.\u0026quot; ACL 2018.\n[7] Alec Radford et al. \u0026ldquo;Improving Language Understanding by Generative Pre-Training\u0026rdquo;. OpenAI Blog, June 11, 2018.\n[8] Jacob Devlin, et al. \u0026ldquo;BERT: Pre-training of deep bidirectional transformers for language understanding.\u0026quot; arXiv:1810.04805 (2018).\n[9] Mike Schuster, and Kaisuke Nakajima. \u0026ldquo;Japanese and Korean voice search.\u0026quot; ICASSP. 2012.\n[10] Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation\n[11] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[12] Peter J. Liu, et al. \u0026ldquo;Generating wikipedia by summarizing long sequences.\u0026quot; ICLR 2018.\n[13] Sebastian Ruder. \u0026ldquo;10 Exciting Ideas of 2018 in NLP\u0026rdquo; Dec 2018.\n[14] Alec Radford, et al. \u0026ldquo;Language Models are Unsupervised Multitask Learners.\u0026quot;. 2019.\n[15] Rico Sennrich, et al. \u0026ldquo;Neural machine translation of rare words with subword units.\u0026quot; arXiv preprint arXiv:1508.07909. 2015.\n[16] Zhenzhong Lan, et al. \u0026ldquo;ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.\u0026quot; arXiv Preprint arXiv:1909.11942 (2019).\n[17] Yinhan Liu, et al. \u0026ldquo;RoBERTa: A Robustly Optimized BERT Pretraining Approach.\u0026quot; arXiv Preprint arXiv:1907.11692 (2019).\n[18] Tom B Brown, et al. \u0026ldquo;Language Models are Few-Shot Learners\u0026rdquo; NeuriPS 2020.\n[19] Zhilin Yang et al. “XLNet: Generalized Autoregressive Pretraining for Language Understanding.” NeuriPS 2019.\n[20] Mike Lewis et al. “BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension.” ACL 2020.\n[21] Kevin Clark et al. “ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators.” ICLR 2020.\n[22] Colin Raffel, et al. \u0026ldquo;Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer\u0026rdquo; JMLR 2020.\n","permalink":"https://lilianweng.github.io/posts/2019-01-31-lm/","summary":"[Updated on 2019-02-14: add ULMFiT and GPT-2.] [Updated on 2020-02-29: add ALBERT.] [Updated on 2020-10-25: add RoBERTa.] [Updated on 2020-12-13: add T5.] [Updated on 2020-12-30: add GPT-3.] [Updated on 2021-11-13: add XLNet, BART and ELECTRA; Also updated the Summary section.]\nFig. 0. I guess they are Elmo \u0026 Bert? (Image source: here) We have seen amazing progress in NLP in 2018. Large-scale pre-trained language modes like OpenAI GPT and BERT have achieved great performance on a variety of language tasks using generic model architectures.","title":"Generalized Language Models"},{"content":"In Part 3, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nTwo-stage vs One-stage Detectors Models in the R-CNN family are all region-based. The detection happens in two stages: (1) First, the model proposes a set of regions of interests by select search or regional proposal network. The proposed regions are sparse as the potential bounding box candidates can be infinite. (2) Then a classifier only processes the region candidates.\nThe other different approach skips the region proposal stage and runs detection directly over a dense sampling of possible locations. This is how a one-stage object detection algorithm works. This is faster and simpler, but might potentially drag down the performance a bit.\nAll the models introduced in this post are one-stage detectors.\nYOLO: You Only Look Once The YOLO model (\u0026ldquo;You Only Look Once\u0026rdquo;; Redmon et al., 2016) is the very first attempt at building a fast real-time object detector. Because YOLO does not undergo the region proposal step and only predicts over a limited number of bounding boxes, it is able to do inference super fast.\nWorkflow Pre-train a CNN network on image classification task.\n Split an image into $S \\times S$ cells. If an object\u0026rsquo;s center falls into a cell, that cell is \u0026ldquo;responsible\u0026rdquo; for detecting the existence of that object. Each cell predicts (a) the location of $B$ bounding boxes, (b) a confidence score, and (c) a probability of object class conditioned on the existence of an object in the bounding box.\n The coordinates of bounding box are defined by a tuple of 4 values, (center x-coord, center y-coord, width, height) \u0026mdash; $(x, y, w, h)$, where $x$ and $y$ are set to be offset of a cell location. Moreover, $x$, $y$, $w$ and $h$ are normalized by the image width and height, and thus all between (0, 1]. A confidence score indicates the likelihood that the cell contains an object: Pr(containing an object) x IoU(pred, truth); where Pr = probability and IoU = interaction under union. If the cell contains an object, it predicts a probability of this object belonging to every class $C_i, i=1, \\dots, K$: Pr(the object belongs to the class C_i | containing an object). At this stage, the model only predicts one set of class probabilities per cell, regardless of the number of bounding boxes, $B$. In total, one image contains $S \\times S \\times B$ bounding boxes, each box corresponding to 4 location predictions, 1 confidence score, and K conditional probabilities for object classification. The total prediction values for one image is $S \\times S \\times (5B + K)$, which is the tensor shape of the final conv layer of the model. The final layer of the pre-trained CNN is modified to output a prediction tensor of size $S \\times S \\times (5B + K)$.\n Fig. 1. The workflow of YOLO model. (Image source: original paper) Network Architecture The base model is similar to GoogLeNet with inception module replaced by 1x1 and 3x3 conv layers. The final prediction of shape $S \\times S \\times (5B + K)$ is produced by two fully connected layers over the whole conv feature map.\nFig. 2. The network architecture of YOLO. Loss Function The loss consists of two parts, the localization loss for bounding box offset prediction and the classification loss for conditional class probabilities. Both parts are computed as the sum of squared errors. Two scale parameters are used to control how much we want to increase the loss from bounding box coordinate predictions ($\\lambda_\\text{coord}$) and how much we want to decrease the loss of confidence score predictions for boxes without objects ($\\lambda_\\text{noobj}$). Down-weighting the loss contributed by background boxes is important as most of the bounding boxes involve no instance. In the paper, the model sets $\\lambda_\\text{coord} = 5$ and $\\lambda_\\text{noobj} = 0.5$.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loc} \u0026= \\lambda_\\text{coord} \\sum_{i=0}^{S^2} \\sum_{j=0}^B \\mathbb{1}_{ij}^\\text{obj} [(x_i - \\hat{x}_i)^2 + (y_i - \\hat{y}_i)^2 + (\\sqrt{w_i} - \\sqrt{\\hat{w}_i})^2 + (\\sqrt{h_i} - \\sqrt{\\hat{h}_i})^2 ] \\\\ \\mathcal{L}_\\text{cls} \u0026= \\sum_{i=0}^{S^2} \\sum_{j=0}^B \\big( \\mathbb{1}_{ij}^\\text{obj} + \\lambda_\\text{noobj} (1 - \\mathbb{1}_{ij}^\\text{obj})\\big) (C_{ij} - \\hat{C}_{ij})^2 + \\sum_{i=0}^{S^2} \\sum_{c \\in \\mathcal{C}} \\mathbb{1}_i^\\text{obj} (p_i(c) - \\hat{p}_i(c))^2\\\\ \\mathcal{L} \u0026= \\mathcal{L}_\\text{loc} + \\mathcal{L}_\\text{cls} \\end{aligned} $$ NOTE: In the original YOLO paper, the loss function uses $C_i$ instead of $C_{ij}$ as confidence score. I made the correction based on my own understanding, since every bounding box should have its own confidence score. Please kindly let me if you do not agree. Many thanks.\n where,\n $\\mathbb{1}_i^\\text{obj}$: An indicator function of whether the cell i contains an object. $\\mathbb{1}_{ij}^\\text{obj}$: It indicates whether the j-th bounding box of the cell i is \u0026ldquo;responsible\u0026rdquo; for the object prediction (see Fig. 3). $C_{ij}$: The confidence score of cell i, Pr(containing an object) * IoU(pred, truth). $\\hat{C}_{ij}$: The predicted confidence score. $\\mathcal{C}$: The set of all classes. $p_i(c)$: The conditional probability of whether cell i contains an object of class $c \\in \\mathcal{C}$. $\\hat{p}_i(c)$: The predicted conditional class probability. Fig. 3. At one location, in cell i, the model proposes B bounding box candidates and the one that has highest overlap with the ground truth is the \"responsible\" predictor. The loss function only penalizes classification error if an object is present in that grid cell, $\\mathbb{1}_i^\\text{obj} = 1$. It also only penalizes bounding box coordinate error if that predictor is \u0026ldquo;responsible\u0026rdquo; for the ground truth box, $\\mathbb{1}_{ij}^\\text{obj} = 1$.\nAs a one-stage object detector, YOLO is super fast, but it is not good at recognizing irregularly shaped objects or a group of small objects due to a limited number of bounding box candidates.\nSSD: Single Shot MultiBox Detector The Single Shot Detector (SSD; Liu et al, 2016) is one of the first attempts at using convolutional neural network\u0026rsquo;s pyramidal feature hierarchy for efficient detection of objects of various sizes.\nImage Pyramid SSD uses the VGG-16 model pre-trained on ImageNet as its base model for extracting useful image features. On top of VGG16, SSD adds several conv feature layers of decreasing sizes. They can be seen as a pyramid representation of images at different scales. Intuitively large fine-grained feature maps at earlier levels are good at capturing small objects and small coarse-grained feature maps can detect large objects well. In SSD, the detection happens in every pyramidal layer, targeting at objects of various sizes.\nFig. 4. The model architecture of SSD. Workflow Unlike YOLO, SSD does not split the image into grids of arbitrary size but predicts offset of predefined anchor boxes (this is called \u0026ldquo;default boxes\u0026rdquo; in the paper) for every location of the feature map. Each box has a fixed size and position relative to its corresponding cell. All the anchor boxes tile the whole feature map in a convolutional manner.\nFeature maps at different levels have different receptive field sizes. The anchor boxes on different levels are rescaled so that one feature map is only responsible for objects at one particular scale. For example, in Fig. 5 the dog can only be detected in the 4x4 feature map (higher level) while the cat is just captured by the 8x8 feature map (lower level).\nFig. 5. The SSD framework. (a) The training data contains images and ground truth boxes for every object. (b) In a fine-grained feature maps (8 x 8), the anchor boxes of different aspect ratios correspond to smaller area of the raw input. (c) In a coarse-grained feature map (4 x 4), the anchor boxes cover larger area of the raw input. (Image source: original paper) The width, height and the center location of an anchor box are all normalized to be (0, 1). At a location $(i, j)$ of the $\\ell$-th feature layer of size $m \\times n$, $i=1,\\dots,n, j=1,\\dots,m$, we have a unique linear scale proportional to the layer level and 5 different box aspect ratios (width-to-height ratios), in addition to a special scale (why we need this? the paper didn’t explain. maybe just a heuristic trick) when the aspect ratio is 1. This gives us 6 anchor boxes in total per feature cell.\n $$ \\begin{aligned} \\text{level index: } \u0026\\ell = 1, \\dots, L \\\\ \\text{scale of boxes: } \u0026s_\\ell = s_\\text{min} + \\frac{s_\\text{max} - s_\\text{min}}{L - 1} (\\ell - 1) \\\\ \\text{aspect ratio: } \u0026r \\in \\{1, 2, 3, 1/2, 1/3\\}\\\\ \\text{additional scale: } \u0026 s'_\\ell = \\sqrt{s_\\ell s_{\\ell + 1}} \\text{ when } r = 1 \\text{thus, 6 boxes in total.}\\\\ \\text{width: } \u0026w_\\ell^r = s_\\ell \\sqrt{r} \\\\ \\text{height: } \u0026h_\\ell^r = s_\\ell / \\sqrt{r} \\\\ \\text{center location: } \u0026 (x^i_\\ell, y^j_\\ell) = (\\frac{i+0.5}{m}, \\frac{j+0.5}{n}) \\end{aligned} $$ Fig. 6. An example of how the anchor box size is scaled up with the layer index $\\ell$ for $L=6, s\\_\\text{min} = 0.2, s\\_\\text{max} = 0.9$. Only the boxes of aspect ratio $r=1$ are illustrated. At every location, the model outputs 4 offsets and $c$ class probabilities by applying a $3 \\times 3 \\times p$ conv filter (where $p$ is the number of channels in the feature map) for every one of $k$ anchor boxes. Therefore, given a feature map of size $m \\times n$, we need $kmn(c+4)$ prediction filters.\nLoss Function Same as YOLO, the loss function is the sum of a localization loss and a classification loss.\n$\\mathcal{L} = \\frac{1}{N}(\\mathcal{L}_\\text{cls} + \\alpha \\mathcal{L}_\\text{loc})$\nwhere $N$ is the number of matched bounding boxes and $\\alpha$ balances the weights between two losses, picked by cross validation.\nThe localization loss is a smooth L1 loss between the predicted bounding box correction and the true values. The coordinate correction transformation is same as what R-CNN does in bounding box regression.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loc} \u0026= \\sum_{i,j} \\sum_{m\\in\\{x, y, w, h\\}} \\mathbb{1}_{ij}^\\text{match} L_1^\\text{smooth}(d_m^i - t_m^j)^2\\\\ L_1^\\text{smooth}(x) \u0026= \\begin{cases} 0.5 x^2 \u0026 \\text{if } \\vert x \\vert where $\\mathbb{1}_{ij}^\\text{match}$ indicates whether the $i$-th bounding box with coordinates $(p^i_x, p^i_y, p^i_w, p^i_h)$ is matched to the $j$-th ground truth box with coordinates $(g^j_x, g^j_y, g^j_w, g^j_h)$ for any object. $d^i_m, m\\in\\{x, y, w, h\\}$ are the predicted correction terms. See this for how the transformation works.\nThe classification loss is a softmax loss over multiple classes (softmax_cross_entropy_with_logits in tensorflow):\n $$ \\mathcal{L}_\\text{cls} = -\\sum_{i \\in \\text{pos}} \\mathbb{1}_{ij}^k \\log(\\hat{c}_i^k) - \\sum_{i \\in \\text{neg}} \\log(\\hat{c}_i^0)\\text{, where }\\hat{c}_i^k = \\text{softmax}(c_i^k) $$ where $\\mathbb{1}_{ij}^k$ indicates whether the $i$-th bounding box and the $j$-th ground truth box are matched for an object in class $k$. $\\text{pos}$ is the set of matched bounding boxes ($N$ items in total) and $\\text{neg}$ is the set of negative examples. SSD uses hard negative mining to select easily misclassified negative examples to construct this $\\text{neg}$ set: Once all the anchor boxes are sorted by objectiveness confidence score, the model picks the top candidates for training so that neg:pos is at most 3:1.\nYOLOv2 / YOLO9000 YOLOv2 (Redmon \u0026amp; Farhadi, 2017) is an enhanced version of YOLO. YOLO9000 is built on top of YOLOv2 but trained with joint dataset combining the COCO detection dataset and the top 9000 classes from ImageNet.\nYOLOv2 Improvement A variety of modifications are applied to make YOLO prediction more accurate and faster, including:\n1. BatchNorm helps: Add batch norm on all the convolutional layers, leading to significant improvement over convergence.\n2. Image resolution matters: Fine-tuning the base model with high resolution images improves the detection performance.\n3. Convolutional anchor box detection: Rather than predicts the bounding box position with fully-connected layers over the whole feature map, YOLOv2 uses convolutional layers to predict locations of anchor boxes, like in faster R-CNN. The prediction of spatial locations and class probabilities are decoupled. Overall, the change leads to a slight decrease in mAP, but an increase in recall.\n4. K-mean clustering of box dimensions: Different from faster R-CNN that uses hand-picked sizes of anchor boxes, YOLOv2 runs k-mean clustering on the training data to find good priors on anchor box dimensions. The distance metric is designed to rely on IoU scores:\n $$ \\text{dist}(x, c_i) = 1 - \\text{IoU}(x, c_i), i=1,\\dots,k $$ where $x$ is a ground truth box candidate and $c_i$ is one of the centroids. The best number of centroids (anchor boxes) $k$ can be chosen by the elbow method.\nThe anchor boxes generated by clustering provide better average IoU conditioned on a fixed number of boxes.\n5. Direct location prediction: YOLOv2 formulates the bounding box prediction in a way that it would not diverge from the center location too much. If the box location prediction can place the box in any part of the image, like in regional proposal network, the model training could become unstable.\nGiven the anchor box of size $(p_w, p_h)$ at the grid cell with its top left corner at $(c_x, c_y)$, the model predicts the offset and the scale, $(t_x, t_y, t_w, t_h)$ and the corresponding predicted bounding box $b$ has center $(b_x, b_y)$ and size $(b_w, b_h)$. The confidence score is the sigmoid ($\\sigma$) of another output $t_o$.\n $$ \\begin{aligned} b_x \u0026= \\sigma(t_x) + c_x\\\\ b_y \u0026= \\sigma(t_y) + c_y\\\\ b_w \u0026= p_w e^{t_w}\\\\ b_h \u0026= p_h e^{t_h}\\\\ \\text{Pr}(\\text{object}) \u0026\\cdot \\text{IoU}(b, \\text{object}) = \\sigma(t_o) \\end{aligned} $$ Fig. 7. YOLOv2 bounding box location prediction. (Image source: original paper) 6. Add fine-grained features: YOLOv2 adds a passthrough layer to bring fine-grained features from an earlier layer to the last output layer. The mechanism of this passthrough layer is similar to identity mappings in ResNet to extract higher-dimensional features from previous layers. This leads to 1% performance increase.\n7. Multi-scale training: In order to train the model to be robust to input images of different sizes, a new size of input dimension is randomly sampled every 10 batches. Since conv layers of YOLOv2 downsample the input dimension by a factor of 32, the newly sampled size is a multiple of 32.\n8. Light-weighted base model: To make prediction even faster, YOLOv2 adopts a light-weighted base model, DarkNet-19, which has 19 conv layers and 5 max-pooling layers. The key point is to insert avg poolings and 1x1 conv filters between 3x3 conv layers.\nYOLO9000: Rich Dataset Training Because drawing bounding boxes on images for object detection is much more expensive than tagging images for classification, the paper proposed a way to combine small object detection dataset with large ImageNet so that the model can be exposed to a much larger number of object categories. The name of YOLO9000 comes from the top 9000 classes in ImageNet. During joint training, if an input image comes from the classification dataset, it only backpropagates the classification loss.\nThe detection dataset has much fewer and more general labels and, moreover, labels cross multiple datasets are often not mutually exclusive. For example, ImageNet has a label “Persian cat” while in COCO the same image would be labeled as “cat”. Without mutual exclusiveness, it does not make sense to apply softmax over all the classes.\nIn order to efficiently merge ImageNet labels (1000 classes, fine-grained) with COCO/PASCAL (\u0026lt; 100 classes, coarse-grained), YOLO9000 built a hierarchical tree structure with reference to WordNet so that general labels are closer to the root and the fine-grained class labels are leaves. In this way, \u0026ldquo;cat\u0026rdquo; is the parent node of \u0026ldquo;Persian cat\u0026rdquo;.\nFig. 8. The WordTree hierarchy merges labels from COCO and ImageNet. Blue nodes are COCO labels and red nodes are ImageNet labels. (Image source: original paper) To predict the probability of a class node, we can follow the path from the node to the root:\nPr(\u0026quot;persian cat\u0026quot; | contain a \u0026quot;physical object\u0026quot;) = Pr(\u0026quot;persian cat\u0026quot; | \u0026quot;cat\u0026quot;) Pr(\u0026quot;cat\u0026quot; | \u0026quot;animal\u0026quot;) Pr(\u0026quot;animal\u0026quot; | \u0026quot;physical object\u0026quot;) Pr(contain a \u0026quot;physical object\u0026quot;) # confidence score. Note that Pr(contain a \u0026quot;physical object\u0026quot;) is the confidence score, predicted separately in the bounding box detection pipeline. The path of conditional probability prediction can stop at any step, depending on which labels are available.\nRetinaNet The RetinaNet (Lin et al., 2018) is a one-stage dense object detector. Two crucial building blocks are featurized image pyramid and the use of focal loss.\nFocal Loss One issue for object detection model training is an extreme imbalance between background that contains no object and foreground that holds objects of interests. Focal loss is designed to assign more weights on hard, easily misclassified examples (i.e. background with noisy texture or partial object) and to down-weight easy examples (i.e. obviously empty background).\nStarting with a normal cross entropy loss for binary classification,\n $$ \\text{CE}(p, y) = -y\\log p - (1-y)\\log(1-p) $$ where $y \\in \\{0, 1\\}$ is a ground truth binary label, indicating whether a bounding box contains a object, and $p \\in [0, 1]$ is the predicted probability of objectiveness (aka confidence score).\nFor notational convenience,\n $$ \\text{let } p_t = \\begin{cases} p \u0026 \\text{if } y = 1\\\\ 1-p \u0026 \\text{otherwise} \\end{cases}, \\text{then } \\text{CE}(p, y)=\\text{CE}(p_t) = -\\log p_t $$ Easily classified examples with large $p_t \\gg 0.5$, that is, when $p$ is very close to 0 (when y=0) or 1 (when y=1), can incur a loss with non-trivial magnitude. Focal loss explicitly adds a weighting factor $(1-p_t)^\\gamma, \\gamma \\geq 0$ to each term in cross entropy so that the weight is small when $p_t$ is large and therefore easy examples are down-weighted.\n $$ \\text{FL}(p_t) = -(1-p_t)^\\gamma \\log p_t $$ Fig. 9. The focal loss focuses less on easy examples with a factor of $(1-p\\_t)^\\gamma$. (Image source: original paper) For a better control of the shape of the weighting function (see Fig. 10.), RetinaNet uses an $\\alpha$-balanced variant of the focal loss, where $\\alpha=0.25, \\gamma=2$ works the best.\n $$ \\text{FL}(p_t) = -\\alpha (1-p_t)^\\gamma \\log p_t $$ Fig. 10. The plot of focal loss weights $\\alpha (1-p\\_t)^\\gamma$ as a function of $p\\_t$, given different values of $\\alpha$ and $\\gamma$. Featurized Image Pyramid The featurized image pyramid (Lin et al., 2017) is the backbone network for RetinaNet. Following the same approach by image pyramid in SSD, featurized image pyramids provide a basic vision component for object detection at different scales.\nThe key idea of feature pyramid network is demonstrated in Fig. 11. The base structure contains a sequence of pyramid levels, each corresponding to one network stage. One stage contains multiple convolutional layers of the same size and the stage sizes are scaled down by a factor of 2. Let\u0026rsquo;s denote the last layer of the $i$-th stage as $C_i$.\nFig. 11. The illustration of the featurized image pyramid module. (Replot based on figure 3 in FPN paper) Two pathways connect conv layers:\n Bottom-up pathway is the normal feedforward computation. Top-down pathway goes in the inverse direction, adding coarse but semantically stronger feature maps back into the previous pyramid levels of a larger size via lateral connections. First, the higher-level features are upsampled spatially coarser to be 2x larger. For image upscaling, the paper used nearest neighbor upsampling. While there are many image upscaling algorithms such as using deconv, adopting another image scaling method might or might not improve the performance of RetinaNet. The larger feature map undergoes a 1x1 conv layer to reduce the channel dimension. Finally, these two feature maps are merged by element-wise addition. The lateral connections only happen at the last layer in stages, denoted as $\\{C_i\\}$, and the process continues until the finest (largest) merged feature map is generated. The prediction is made out of every merged map after a 3x3 conv layer, $\\{P_i\\}$. According to ablation studies, the importance rank of components of the featurized image pyramid design is as follows: 1x1 lateral connection \u0026gt; detect object across multiple layers \u0026gt; top-down enrichment \u0026gt; pyramid representation (compared to only check the finest layer).\nModel Architecture The featurized pyramid is constructed on top of the ResNet architecture. Recall that ResNet has 5 conv blocks (= network stages / pyramid levels). The last layer of the $i$-th pyramid level, $C_i$, has resolution $2^i$ lower than the raw input dimension.\nRetinaNet utilizes feature pyramid levels $P_3$ to $P_7$:\n $P_3$ to $P_5$ are computed from the corresponding ResNet residual stage from $C_3$ to $C_5$. They are connected by both top-down and bottom-up pathways. $P_6$ is obtained via a 3×3 stride-2 conv on top of $C_5$ $P_7$ applies ReLU and a 3×3 stride-2 conv on $P_6$. Adding higher pyramid levels on ResNet improves the performance for detecting large objects.\nSame as in SSD, detection happens in all pyramid levels by making a prediction out of every merged feature map. Because predictions share the same classifier and the box regressor, they are all formed to have the same channel dimension d=256.\nThere are A=9 anchor boxes per level:\n The base size corresponds to areas of $32^2$ to $512^2$ pixels on $P_3$ to $P_7$ respectively. There are three size ratios, $\\{2^0, 2^{1/3}, 2^{2/3}\\}$. For each size, there are three aspect ratios {1/2, 1, 2}. As usual, for each anchor box, the model outputs a class probability for each of $K$ classes in the classification subnet and regresses the offset from this anchor box to the nearest ground truth object in the box regression subnet. The classification subnet adopts the focal loss introduced above.\nFig. 12. The RetinaNet model architecture uses a FPN backbone on top of ResNet. (Image source: the FPN paper) YOLOv3 YOLOv3 is created by applying a bunch of design tricks on YOLOv2. The changes are inspired by recent advances in the object detection world.\nHere are a list of changes:\n1. Logistic regression for confidence scores: YOLOv3 predicts an confidence score for each bounding box using logistic regression, while YOLO and YOLOv2 uses sum of squared errors for classification terms (see the loss function above). Linear regression of offset prediction leads to a decrease in mAP.\n2. No more softmax for class prediction: When predicting class confidence, YOLOv3 uses multiple independent logistic classifier for each class rather than one softmax layer. This is very helpful especially considering that one image might have multiple labels and not all the labels are guaranteed to be mutually exclusive.\n3. Darknet + ResNet as the base model: The new Darknet-53 still relies on successive 3x3 and 1x1 conv layers, just like the original dark net architecture, but has residual blocks added.\n4. Multi-scale prediction: Inspired by image pyramid, YOLOv3 adds several conv layers after the base feature extractor model and makes prediction at three different scales among these conv layers. In this way, it has to deal with many more bounding box candidates of various sizes overall.\n5. Skip-layer concatenation: YOLOv3 also adds cross-layer connections between two prediction layers (except for the output layer) and earlier finer-grained feature maps. The model first up-samples the coarse feature maps and then merges it with the previous features by concatenation. The combination with finer-grained information makes it better at detecting small objects.\nInterestingly, focal loss does not help YOLOv3, potentially it might be due to the usage of $\\lambda_\\text{noobj}$ and $\\lambda_\\text{coord}$ \u0026mdash; they increase the loss from bounding box location predictions and decrease the loss from confidence predictions for background boxes.\nOverall YOLOv3 performs better and faster than SSD, and worse than RetinaNet but 3.8x faster.\nFig. 13. The comparison of various fast object detection models on speed and mAP performance. (Image source: focal loss paper with additional labels from the YOLOv3 paper.) Cited as:\n@article{weng2018detection4, title = \u0026quot;Object Detection Part 4: Fast Detection Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-12-27-object-recognition-part-4/\u0026quot; } Reference [1] Joseph Redmon, et al. \u0026ldquo;You only look once: Unified, real-time object detection.\u0026quot; CVPR 2016.\n[2] Joseph Redmon and Ali Farhadi. \u0026ldquo;YOLO9000: Better, Faster, Stronger.\u0026quot; CVPR 2017.\n[3] Joseph Redmon, Ali Farhadi. \u0026ldquo;YOLOv3: An incremental improvement.\u0026quot;.\n[4] Wei Liu et al. \u0026ldquo;SSD: Single Shot MultiBox Detector.\u0026quot; ECCV 2016.\n[5] Tsung-Yi Lin, et al. \u0026ldquo;Feature Pyramid Networks for Object Detection.\u0026quot; CVPR 2017.\n[6] Tsung-Yi Lin, et al. \u0026ldquo;Focal Loss for Dense Object Detection.\u0026quot; IEEE transactions on pattern analysis and machine intelligence, 2018.\n[7] \u0026ldquo;What\u0026rsquo;s new in YOLO v3?\u0026quot; by Ayoosh Kathuria on \u0026ldquo;Towards Data Science\u0026rdquo;, Apr 23, 2018.\n","permalink":"https://lilianweng.github.io/posts/2018-12-27-object-recognition-part-4/","summary":"In Part 3, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].","title":"Object Detection Part 4: Fast Detection Models"},{"content":"[Updated on 2019-10-01: thanks to Tianhao, we have this post translated in Chinese!]\nA good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to discover the way to ride a motorcycle fast with little or even no demonstration. Is it possible to design a machine learning model with similar properties \u0026mdash; learning new concepts and skills fast with a few training examples? That\u0026rsquo;s essentially what meta-learning aims to solve.\nWe expect a good meta-learning model capable of well adapting or generalizing to new tasks and new environments that have never been encountered during training time. The adaptation process, essentially a mini learning session, happens during test but with a limited exposure to the new task configurations. Eventually, the adapted model can complete new tasks. This is why meta-learning is also known as learning to learn.\nThe tasks can be any well-defined family of machine learning problems: supervised learning, reinforcement learning, etc. For example, here are a couple concrete meta-learning tasks:\n A classifier trained on non-cat images can tell whether a given image contains a cat after seeing a handful of cat pictures. A game bot is able to quickly master a new game. A mini robot completes the desired task on an uphill surface during test even through it was only trained in a flat surface environment. Define the Meta-Learning Problem In this post, we focus on the case when each desired task is a supervised learning problem like image classification. There is a lot of interesting literature on meta-learning with reinforcement learning problems (aka \u0026ldquo;Meta Reinforcement Learning\u0026rdquo;), but we would not cover them here.\nA Simple View A good meta-learning model should be trained over a variety of learning tasks and optimized for the best performance on a distribution of tasks, including potentially unseen tasks. Each task is associated with a dataset $\\mathcal{D}$, containing both feature vectors and true labels. The optimal model parameters are:\n $$ \\theta^* = \\arg\\min_\\theta \\mathbb{E}_{\\mathcal{D}\\sim p(\\mathcal{D})} [\\mathcal{L}_\\theta(\\mathcal{D})] $$ It looks very similar to a normal learning task, but one dataset is considered as one data sample.\nFew-shot classification is an instantiation of meta-learning in the field of supervised learning. The dataset $\\mathcal{D}$ is often split into two parts, a support set $S$ for learning and a prediction set $B$ for training or testing, $\\mathcal{D}=\\langle S, B\\rangle$. Often we consider a K-shot N-class classification task: the support set contains K labelled examples for each of N classes.\nFig. 1. An example of 4-shot 2-class image classification. (Image thumbnails are from Pinterest) Training in the Same Way as Testing A dataset $\\mathcal{D}$ contains pairs of feature vectors and labels, $\\mathcal{D} = \\{(\\mathbf{x}_i, y_i)\\}$ and each label belongs to a known label set $\\mathcal{L}^\\text{label}$. Let\u0026rsquo;s say, our classifier $f_\\theta$ with parameter $\\theta$ outputs a probability of a data point belonging to the class $y$ given the feature vector $\\mathbf{x}$, $P_\\theta(y\\vert\\mathbf{x})$.\nThe optimal parameters should maximize the probability of true labels across multiple training batches $B \\subset \\mathcal{D}$:\n $$ \\begin{aligned} \\theta^* \u0026= {\\arg\\max}_{\\theta} \\mathbb{E}_{(\\mathbf{x}, y)\\in \\mathcal{D}}[P_\\theta(y \\vert \\mathbf{x})] \u0026\\\\ \\theta^* \u0026= {\\arg\\max}_{\\theta} \\mathbb{E}_{B\\subset \\mathcal{D}}[\\sum_{(\\mathbf{x}, y)\\in B}P_\\theta(y \\vert \\mathbf{x})] \u0026 \\scriptstyle{\\text{; trained with mini-batches.}} \\end{aligned} $$ In few-shot classification, the goal is to reduce the prediction error on data samples with unknown labels given a small support set for \u0026ldquo;fast learning\u0026rdquo; (think of how \u0026ldquo;fine-tuning\u0026rdquo; works). To make the training process mimics what happens during inference, we would like to \u0026ldquo;fake\u0026rdquo; datasets with a subset of labels to avoid exposing all the labels to the model and modify the optimization procedure accordingly to encourage fast learning:\n Sample a subset of labels, $L\\subset\\mathcal{L}^\\text{label}$. Sample a support set $S^L \\subset \\mathcal{D}$ and a training batch $B^L \\subset \\mathcal{D}$. Both of them only contain data points with labels belonging to the sampled label set $L$, $y \\in L, \\forall (x, y) \\in S^L, B^L$. The support set is part of the model input. The final optimization uses the mini-batch $B^L$ to compute the loss and update the model parameters through backpropagation, in the same way as how we use it in the supervised learning. You may consider each pair of sampled dataset $(S^L, B^L)$ as one data point. The model is trained such that it can generalize to other datasets. Symbols in red are added for meta-learning in addition to the supervised learning objective.\n $$ \\theta = \\arg\\max_\\theta \\color{red}{E_{L\\subset\\mathcal{L}}[} E_{\\color{red}{S^L \\subset\\mathcal{D}, }B^L \\subset\\mathcal{D}} [\\sum_{(x, y)\\in B^L} P_\\theta(x, y\\color{red}{, S^L})] \\color{red}{]} $$ The idea is to some extent similar to using a pre-trained model in image classification (ImageNet) or language modeling (big text corpora) when only a limited set of task-specific data samples are available. Meta-learning takes this idea one step further, rather than fine-tuning according to one down-steam task, it optimizes the model to be good at many, if not all.\nLearner and Meta-Learner Another popular view of meta-learning decomposes the model update into two stages:\n A classifier $f_\\theta$ is the \u0026ldquo;learner\u0026rdquo; model, trained for operating a given task; In the meantime, a optimizer $g_\\phi$ learns how to update the learner model\u0026rsquo;s parameters via the support set $S$, $\\theta' = g_\\phi(\\theta, S)$. Then in final optimization step, we need to update both $\\theta$ and $\\phi$ to maximize:\n $$ \\mathbb{E}_{L\\subset\\mathcal{L}}[ \\mathbb{E}_{S^L \\subset\\mathcal{D}, B^L \\subset\\mathcal{D}} [\\sum_{(\\mathbf{x}, y)\\in B^L} P_{g_\\phi(\\theta, S^L)}(y \\vert \\mathbf{x})]] $$ Common Approaches There are three common approaches to meta-learning: metric-based, model-based, and optimization-based. Oriol Vinyals has a nice summary in his talk at meta-learning symposium @ NIPS 2018:\n| \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- |\n Model-based Metric-based Optimization-based Key idea RNN; memory Metric learning Gradient descent How $P_\\theta(y \\vert \\mathbf{x})$ is modeled? $f_\\theta(\\mathbf{x}, S)$ $\\sum_{(\\mathbf{x}_i, y_i) \\in S} k_\\theta(\\mathbf{x}, \\mathbf{x}_i)y_i$ (*) $P_{g_\\phi(\\theta, S^L)}(y \\vert \\mathbf{x})$ (*) $k_\\theta$ is a kernel function measuring the similarity between $\\mathbf{x}_i$ and $\\mathbf{x}$.\nNext we are gonna review classic models in each approach.\nMetric-Based The core idea in metric-based meta-learning is similar to nearest neighbors algorithms (i.e., k-NN classificer and k-means clustering) and kernel density estimation. The predicted probability over a set of known labels $y$ is a weighted sum of labels of support set samples. The weight is generated by a kernel function $k_\\theta$, measuring the similarity between two data samples.\n $$ P_\\theta(y \\vert \\mathbf{x}, S) = \\sum_{(\\mathbf{x}_i, y_i) \\in S} k_\\theta(\\mathbf{x}, \\mathbf{x}_i)y_i $$ To learn a good kernel is crucial to the success of a metric-based meta-learning model. Metric learning is well aligned with this intention, as it aims to learn a metric or distance function over objects. The notion of a good metric is problem-dependent. It should represent the relationship between inputs in the task space and facilitate problem solving.\nAll the models introduced below learn embedding vectors of input data explicitly and use them to design proper kernel functions.\nConvolutional Siamese Neural Network The Siamese Neural Network is composed of two twin networks and their outputs are jointly trained on top with a function to learn the relationship between pairs of input data samples. The twin networks are identical, sharing the same weights and network parameters. In other words, both refer to the same embedding network that learns an efficient embedding to reveal relationship between pairs of data points.\nKoch, Zemel \u0026amp; Salakhutdinov (2015) proposed a method to use the siamese neural network to do one-shot image classification. First, the siamese network is trained for a verification task for telling whether two input images are in the same class. It outputs the probability of two images belonging to the same class. Then, during test time, the siamese network processes all the image pairs between a test image and every image in the support set. The final prediction is the class of the support image with the highest probability.\nFig. 2. The architecture of convolutional siamese neural network for few-show image classification. First, convolutional siamese network learns to encode two images into feature vectors via a embedding function $f_\\theta$ which contains a couple of convolutional layers. The L1-distance between two embeddings is $\\vert f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\vert$. The distance is converted to a probability $p$ by a linear feedforward layer and sigmoid. It is the probability of whether two images are drawn from the same class. Intuitively the loss is cross entropy because the label is binary. $$ \\begin{aligned} p(\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\sigma(\\mathbf{W}\\vert f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\vert) \\\\ \\mathcal{L}(B) \u0026= \\sum_{(\\mathbf{x}_i, \\mathbf{x}_j, y_i, y_j)\\in B} \\mathbf{1}_{y_i=y_j}\\log p(\\mathbf{x}_i, \\mathbf{x}_j) + (1-\\mathbf{1}_{y_i=y_j})\\log (1-p(\\mathbf{x}_i, \\mathbf{x}_j)) \\end{aligned} $$ Images in the training batch $B$ can be augmented with distortion. Of course, you can replace the L1 distance with other distance metric, L2, cosine, etc. Just make sure they are differential and then everything else works the same.\nGiven a support set $S$ and a test image $\\mathbf{x}$, the final predicted class is:\n $$ \\hat{c}_S(\\mathbf{x}) = c(\\arg\\max_{\\mathbf{x}_i \\in S} P(\\mathbf{x}, \\mathbf{x}_i)) $$ where $c(\\mathbf{x})$ is the class label of an image $\\mathbf{x}$ and $\\hat{c}(.)$ is the predicted label.\nThe assumption is that the learned embedding can be generalized to be useful for measuring the distance between images of unknown categories. This is the same assumption behind transfer learning via the adoption of a pre-trained model; for example, the convolutional features learned in the model pre-trained with ImageNet are expected to help other image tasks. However, the benefit of a pre-trained model decreases when the new task diverges from the original task that the model was trained on.\nMatching Networks The task of Matching Networks (Vinyals et al., 2016) is to learn a classifier $c_S$ for any given (small) support set $S=\\{x_i, y_i\\}_{i=1}^k$ (k-shot classification). This classifier defines a probability distribution over output labels $y$ given a test example $\\mathbf{x}$. Similar to other metric-based models, the classifier output is defined as a sum of labels of support samples weighted by attention kernel $a(\\mathbf{x}, \\mathbf{x}_i)$ - which should be proportional to the similarity between $\\mathbf{x}$ and $\\mathbf{x}_i$.\nFig. 3. The architecture of Matching Networks. (Image source: original paper) $$ c_S(\\mathbf{x}) = P(y \\vert \\mathbf{x}, S) = \\sum_{i=1}^k a(\\mathbf{x}, \\mathbf{x}_i) y_i \\text{, where }S=\\{(\\mathbf{x}_i, y_i)\\}_{i=1}^k $$ The attention kernel depends on two embedding functions, $f$ and $g$, for encoding the test sample and the support set samples respectively. The attention weight between two data points is the cosine similarity, $\\text{cosine}(.)$, between their embedding vectors, normalized by softmax:\n $$ a(\\mathbf{x}, \\mathbf{x}_i) = \\frac{\\exp(\\text{cosine}(f(\\mathbf{x}), g(\\mathbf{x}_i))}{\\sum_{j=1}^k\\exp(\\text{cosine}(f(\\mathbf{x}), g(\\mathbf{x}_j))} $$ Simple Embedding In the simple version, an embedding function is a neural network with a single data sample as input. Potentially we can set $f=g$.\nFull Context Embeddings The embedding vectors are critical inputs for building a good classifier. Taking a single data point as input might not be enough to efficiently gauge the entire feature space. Therefore, the Matching Network model further proposed to enhance the embedding functions by taking as input the whole support set $S$ in addition to the original input, so that the learned embedding can be adjusted based on the relationship with other support samples.\n $g_\\theta(\\mathbf{x}_i, S)$ uses a bidirectional LSTM to encode $\\mathbf{x}_i$ in the context of the entire support set $S$.\n $f_\\theta(\\mathbf{x}, S)$ encodes the test sample $\\mathbf{x}$ visa an LSTM with read attention over the support set $S$.\n First the test sample goes through a simple neural network, such as a CNN, to extract basic features, $f'(\\mathbf{x})$. Then an LSTM is trained with a read attention vector over the support set as part of the hidden state: $$ \\begin{aligned} \\hat{\\mathbf{h}}_t, \\mathbf{c}_t \u0026= \\text{LSTM}(f'(\\mathbf{x}), [\\mathbf{h}_{t-1}, \\mathbf{r}_{t-1}], \\mathbf{c}_{t-1}) \\\\ \\mathbf{h}_t \u0026= \\hat{\\mathbf{h}}_t + f'(\\mathbf{x}) \\\\ \\mathbf{r}_{t-1} \u0026= \\sum_{i=1}^k a(\\mathbf{h}_{t-1}, g(\\mathbf{x}_i)) g(\\mathbf{x}_i) \\\\ a(\\mathbf{h}_{t-1}, g(\\mathbf{x}_i)) \u0026= \\text{softmax}(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_i)) = \\frac{\\exp(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_i))}{\\sum_{j=1}^k \\exp(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_j))} \\end{aligned} $$ Eventually $f(\\mathbf{x}, S)=\\mathbf{h}_K$ if we do K steps of \u0026ldquo;read\u0026rdquo;. This embedding method is called \u0026ldquo;Full Contextual Embeddings (FCE)\u0026rdquo;. Interestingly it does help improve the performance on a hard task (few-shot classification on mini ImageNet), but makes no difference on a simple task (Omniglot).\nThe training process in Matching Networks is designed to match inference at test time, see the details in the earlier section. It is worthy of mentioning that the Matching Networks paper refined the idea that training and testing conditions should match.\n $$ \\theta^* = \\arg\\max_\\theta \\mathbb{E}_{L\\subset\\mathcal{L}}[ \\mathbb{E}_{S^L \\subset\\mathcal{D}, B^L \\subset\\mathcal{D}} [\\sum_{(\\mathbf{x}, y)\\in B^L} P_\\theta(y\\vert\\mathbf{x}, S^L)]] $$ Relation Network Relation Network (RN) (Sung et al., 2018) is similar to siamese network but with a few differences:\n The relationship is not captured by a simple L1 distance in the feature space, but predicted by a CNN classifier $g_\\phi$. The relation score between a pair of inputs, $\\mathbf{x}_i$ and $\\mathbf{x}_j$, is $r_{ij} = g_\\phi([\\mathbf{x}_i, \\mathbf{x}_j])$ where $[.,.]$ is concatenation. The objective function is MSE loss instead of cross-entropy, because conceptually RN focuses more on predicting relation scores which is more like regression, rather than binary classification, $\\mathcal{L}(B) = \\sum_{(\\mathbf{x}_i, \\mathbf{x}_j, y_i, y_j)\\in B} (r_{ij} - \\mathbf{1}_{y_i=y_j})^2$. Fig. 4. Relation Network architecture for a 5-way 1-shot problem with one query example. (Image source: original paper) (Note: There is another Relation Network for relational reasoning, proposed by DeepMind. Don\u0026rsquo;t get confused.)\nPrototypical Networks Prototypical Networks (Snell, Swersky \u0026amp; Zemel, 2017) use an embedding function $f_\\theta$ to encode each input into a $M$-dimensional feature vector. A prototype feature vector is defined for every class $c \\in \\mathcal{C}$, as the mean vector of the embedded support data samples in this class.\n $$ \\mathbf{v}_c = \\frac{1}{|S_c|} \\sum_{(\\mathbf{x}_i, y_i) \\in S_c} f_\\theta(\\mathbf{x}_i) $$ Fig. 5. Prototypical networks in the few-shot and zero-shot scenarios. (Image source: original paper) The distribution over classes for a given test input $\\mathbf{x}$ is a softmax over the inverse of distances between the test data embedding and prototype vectors.\n $$ P(y=c\\vert\\mathbf{x})=\\text{softmax}(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_c)) = \\frac{\\exp(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_c))}{\\sum_{c' \\in \\mathcal{C}}\\exp(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_{c'}))} $$ where $d_\\varphi$ can be any distance function as long as $\\varphi$ is differentiable. In the paper, they used the squared euclidean distance.\nThe loss function is the negative log-likelihood: $\\mathcal{L}(\\theta) = -\\log P_\\theta(y=c\\vert\\mathbf{x})$.\nModel-Based Model-based meta-learning models make no assumption on the form of $P_\\theta(y\\vert\\mathbf{x})$. Rather it depends on a model designed specifically for fast learning \u0026mdash; a model that updates its parameters rapidly with a few training steps. This rapid parameter update can be achieved by its internal architecture or controlled by another meta-learner model.\nMemory-Augmented Neural Networks A family of model architectures use external memory storage to facilitate the learning process of neural networks, including Neural Turing Machines and Memory Networks. With an explicit storage buffer, it is easier for the network to rapidly incorporate new information and not to forget in the future. Such a model is known as MANN, short for \u0026ldquo;Memory-Augmented Neural Network\u0026rdquo;. Note that recurrent neural networks with only internal memory such as vanilla RNN or LSTM are not MANNs.\nBecause MANN is expected to encode new information fast and thus to adapt to new tasks after only a few samples, it fits well for meta-learning. Taking the Neural Turing Machine (NTM) as the base model, Santoro et al. (2016) proposed a set of modifications on the training setup and the memory retrieval mechanisms (or \u0026ldquo;addressing mechanisms\u0026rdquo;, deciding how to assign attention weights to memory vectors). Please go through the NTM section in my other post first if you are not familiar with this matter before reading forward.\nAs a quick recap, NTM couples a controller neural network with external memory storage. The controller learns to read and write memory rows by soft attention, while the memory serves as a knowledge repository. The attention weights are generated by its addressing mechanism: content-based + location based.\nFig. 6. The architecture of Neural Turing Machine (NTM). The memory at time t, $\\mathbf{M}\\_t$ is a matrix of size $N \\times M$, containing N vector rows and each has M dimensions. MANN for Meta-Learning To use MANN for meta-learning tasks, we need to train it in a way that the memory can encode and capture information of new tasks fast and, in the meantime, any stored representation is easily and stably accessible.\nThe training described in Santoro et al., 2016 happens in an interesting way so that the memory is forced to hold information for longer until the appropriate labels are presented later. In each training episode, the truth label $y_t$ is presented with one step offset, $(\\mathbf{x}_{t+1}, y_t)$: it is the true label for the input at the previous time step t, but presented as part of the input at time step t+1.\nFig. 7. Task setup in MANN for meta-learning (Image source: original paper). In this way, MANN is motivated to memorize the information of a new dataset, because the memory has to hold the current input until the label is present later and then retrieve the old information to make a prediction accordingly.\nNext let us see how the memory is updated for efficient information retrieval and storage.\nAddressing Mechanism for Meta-Learning Aside from the training process, a new pure content-based addressing mechanism is utilized to make the model better suitable for meta-learning.\n\u0026raquo; How to read from memory? The read attention is constructed purely based on the content similarity.\nFirst, a key feature vector $\\mathbf{k}_t$ is produced at the time step t by the controller as a function of the input $\\mathbf{x}$. Similar to NTM, a read weighting vector $\\mathbf{w}_t^r$ of N elements is computed as the cosine similarity between the key vector and every memory vector row, normalized by softmax. The read vector $\\mathbf{r}_t$ is a sum of memory records weighted by such weightings:\n $$ \\mathbf{r}_i = \\sum_{i=1}^N w_t^r(i)\\mathbf{M}_t(i) \\text{, where } w_t^r(i) = \\text{softmax}(\\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(i)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(i)\\|}) $$ where $M_t$ is the memory matrix at time t and $M_t(i)$ is the i-th row in this matrix.\n\u0026raquo; How to write into memory? The addressing mechanism for writing newly received information into memory operates a lot like the cache replacement policy. The Least Recently Used Access (LRUA) writer is designed for MANN to better work in the scenario of meta-learning. A LRUA write head prefers to write new content to either the least used memory location or the most recently used memory location.\n Rarely used locations: so that we can preserve frequently used information (see LFU); The last used location: the motivation is that once a piece of information is retrieved once, it probably won\u0026rsquo;t be called again for a while (see MRU). There are many cache replacement algorithms and each of them could potentially replace the design here with better performance in different use cases. Furthermore, it would be a good idea to learn the memory usage pattern and addressing strategies rather than arbitrarily set it.\nThe preference of LRUA is carried out in a way that everything is differentiable:\n The usage weight $\\mathbf{w}^u_t$ at time t is a sum of current read and write vectors, in addition to the decayed last usage weight, $\\gamma \\mathbf{w}^u_{t-1}$, where $\\gamma$ is a decay factor. The write vector is an interpolation between the previous read weight (prefer \u0026ldquo;the last used location\u0026rdquo;) and the previous least-used weight (prefer \u0026ldquo;rarely used location\u0026rdquo;). The interpolation parameter is the sigmoid of a hyperparameter $\\alpha$. The least-used weight $\\mathbf{w}^{lu}$ is scaled according to usage weights $\\mathbf{w}_t^u$, in which any dimension remains at 1 if smaller than the n-th smallest element in the vector and 0 otherwise. $$ \\begin{aligned} \\mathbf{w}_t^u \u0026= \\gamma \\mathbf{w}_{t-1}^u + \\mathbf{w}_t^r + \\mathbf{w}_t^w \\\\ \\mathbf{w}_t^r \u0026= \\text{softmax}(\\text{cosine}(\\mathbf{k}_t, \\mathbf{M}_t(i))) \\\\ \\mathbf{w}_t^w \u0026= \\sigma(\\alpha)\\mathbf{w}_{t-1}^r + (1-\\sigma(\\alpha))\\mathbf{w}^{lu}_{t-1}\\\\ \\mathbf{w}_t^{lu} \u0026= \\mathbf{1}_{w_t^u(i) \\leq m(\\mathbf{w}_t^u, n)} \\text{, where }m(\\mathbf{w}_t^u, n)\\text{ is the }n\\text{-th smallest element in vector }\\mathbf{w}_t^u\\text{.} \\end{aligned} $$ Finally, after the least used memory location, indicated by $\\mathbf{w}_t^{lu}$, is set to zero, every memory row is updated:\n $$ \\mathbf{M}_t(i) = \\mathbf{M}_{t-1}(i) + w_t^w(i)\\mathbf{k}_t, \\forall i $$ Meta Networks Meta Networks (Munkhdalai \u0026amp; Yu, 2017), short for MetaNet, is a meta-learning model with architecture and training process designed for rapid generalization across tasks.\nFast Weights The rapid generalization of MetaNet relies on \u0026ldquo;fast weights\u0026rdquo;. There are a handful of papers on this topic, but I haven\u0026rsquo;t read all of them in detail and I failed to find a very concrete definition, only a vague agreement on the concept. Normally weights in the neural networks are updated by stochastic gradient descent in an objective function and this process is known to be slow. One faster way to learn is to utilize one neural network to predict the parameters of another neural network and the generated weights are called fast weights. In comparison, the ordinary SGD-based weights are named slow weights.\nIn MetaNet, loss gradients are used as meta information to populate models that learn fast weights. Slow and fast weights are combined to make predictions in neural networks.\nFig. 8. Combining slow and fast weights in a MLP. $\\bigoplus$ is element-wise sum. (Image source: original paper). Model Components Disclaimer: Below you will find my annotations are different from those in the paper. imo, the paper is poorly written, but the idea is still interesting. So I\u0026rsquo;m presenting the idea in my own language.\n Key components of MetaNet are:\n An embedding function $f_\\theta$, parameterized by $\\theta$, encodes raw inputs into feature vectors. Similar to Siamese Neural Network, these embeddings are trained to be useful for telling whether two inputs are of the same class (verification task). A base learner model $g_\\phi$, parameterized by weights $\\phi$, completes the actual learning task. If we stop here, it looks just like Relation Network. MetaNet, in addition, explicitly models the fast weights of both functions and then aggregates them back into the model (See Fig. 8).\nTherefore we need additional two functions to output fast weights for $f$ and $g$ respectively.\n $F_w$: a LSTM parameterized by $w$ for learning fast weights $\\theta^+$ of the embedding function $f$. It takes as input gradients of $f$\u0026rsquo;s embedding loss for verification task. $G_v$: a neural network parameterized by $v$ learning fast weights $\\phi^+$ for the base learner $g$ from its loss gradients. In MetaNet, the learner\u0026rsquo;s loss gradients are viewed as the meta information of the task. Ok, now let\u0026rsquo;s see how meta networks are trained. The training data contains multiple pairs of datasets: a support set $S=\\{\\mathbf{x}'_i, y'_i\\}_{i=1}^K$ and a test set $U=\\{\\mathbf{x}_i, y_i\\}_{i=1}^L$. Recall that we have four networks and four sets of model parameters to learn, $(\\theta, \\phi, w, v)$.\nFig.9. The MetaNet architecture. Training Process Sample a random pair of inputs at each time step t from the support set $S$, $(\\mathbf{x}'_i, y'_i)$ and $(\\mathbf{x}'_j, y_j)$. Let $\\mathbf{x}_{(t,1)}=\\mathbf{x}'_i$ and $\\mathbf{x}_{(t,2)}=\\mathbf{x}'_j$. for $t = 1, \\dots, K$:\n a. Compute a loss for representation learning; i.e., cross entropy for the verification task: $\\mathcal{L}^\\text{emb}_t = \\mathbf{1}_{y'_i=y'_j} \\log P_t + (1 - \\mathbf{1}_{y'_i=y'_j})\\log(1 - P_t)\\text{, where }P_t = \\sigma(\\mathbf{W}\\vert f_\\theta(\\mathbf{x}_{(t,1)}) - f_\\theta(\\mathbf{x}_{(t,2)})\\vert)$ Compute the task-level fast weights: $\\theta^+ = F_w(\\nabla_\\theta \\mathcal{L}^\\text{emb}_1, \\dots, \\mathcal{L}^\\text{emb}_T)$\n Next go through examples in the support set $S$ and compute the example-level fast weights. Meanwhile, update the memory with learned representations. for $i=1, \\dots, K$:\n a. The base learner outputs a probability distribution: $P(\\hat{y}_i \\vert \\mathbf{x}_i) = g_\\phi(\\mathbf{x}_i)$ and the loss can be cross-entropy or MSE: $\\mathcal{L}^\\text{task}_i = y'_i \\log g_\\phi(\\mathbf{x}'_i) + (1- y'_i) \\log (1 - g_\\phi(\\mathbf{x}'_i))$ b. Extract meta information (loss gradients) of the task and compute the example-level fast weights: $\\phi_i^+ = G_v(\\nabla_\\phi\\mathcal{L}^\\text{task}_i)$ Then store $\\phi^+_i$ into $i$-th location of the \u0026ldquo;value\u0026rdquo; memory $\\mathbf{M}$. d. Encode the support sample into a task-specific input representation using both slow and fast weights: $r'_i = f_{\\theta, \\theta^+}(\\mathbf{x}'_i)$ Then store $r'_i$ into $i$-th location of the \u0026ldquo;key\u0026rdquo; memory $\\mathbf{R}$. Finally it is the time to construct the training loss using the test set $U=\\{\\mathbf{x}_i, y_i\\}_{i=1}^L$. Starts with $\\mathcal{L}_\\text{train}=0$: for $j=1, \\dots, L$:\n a. Encode the test sample into a task-specific input representation: $r_j = f_{\\theta, \\theta^+}(\\mathbf{x}_j)$ b. The fast weights are computed by attending to representations of support set samples in memory $\\mathbf{R}$. The attention function is of your choice. Here MetaNet uses cosine similarity: $$ \\begin{aligned} a_j \u0026= \\text{cosine}(\\mathbf{R}, r_j) = [\\frac{r'_1\\cdot r_j}{\\|r'_1\\|\\cdot\\|r_j\\|}, \\dots, \\frac{r'_N\\cdot r_j}{\\|r'_N\\|\\cdot\\|r_j\\|}]\\\\ \\phi^+_j \u0026= \\text{softmax}(a_j)^\\top \\mathbf{M} \\end{aligned} $$ c. Update the training loss: $\\mathcal{L}_\\text{train} \\leftarrow \\mathcal{L}_\\text{train} + \\mathcal{L}^\\text{task}(g_{\\phi, \\phi^+}(\\mathbf{x}_i), y_i) $ Update all the parameters $(\\theta, \\phi, w, v)$ using $\\mathcal{L}_\\text{train}$.\n Optimization-Based Deep learning models learn through backpropagation of gradients. However, the gradient-based optimization is neither designed to cope with a small number of training samples, nor to converge within a small number of optimization steps. Is there a way to adjust the optimization algorithm so that the model can be good at learning with a few examples? This is what optimization-based approach meta-learning algorithms intend for.\nLSTM Meta-Learner The optimization algorithm can be explicitly modeled. Ravi \u0026amp; Larochelle (2017) did so and named it \u0026ldquo;meta-learner\u0026rdquo;, while the original model for handling the task is called \u0026ldquo;learner\u0026rdquo;. The goal of the meta-learner is to efficiently update the learner\u0026rsquo;s parameters using a small support set so that the learner can adapt to the new task quickly.\nLet\u0026rsquo;s denote the learner model as $M_\\theta$ parameterized by $\\theta$, the meta-learner as $R_\\Theta$ with parameters $\\Theta$, and the loss function $\\mathcal{L}$.\nWhy LSTM? The meta-learner is modeled as a LSTM, because:\n There is similarity between the gradient-based update in backpropagation and the cell-state update in LSTM. Knowing a history of gradients benefits the gradient update; think about how momentum works. The update for the learner\u0026rsquo;s parameters at time step t with a learning rate $\\alpha_t$ is:\n $$ \\theta_t = \\theta_{t-1} - \\alpha_t \\nabla_{\\theta_{t-1}}\\mathcal{L}_t $$ It has the same form as the cell state update in LSTM, if we set forget gate $f_t=1$, input gate $i_t = \\alpha_t$, cell state $c_t = \\theta_t$, and new cell state $\\tilde{c}_t = -\\nabla_{\\theta_{t-1}}\\mathcal{L}_t$:\n $$ \\begin{aligned} c_t \u0026= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c}_t\\\\ \u0026= \\theta_{t-1} - \\alpha_t\\nabla_{\\theta_{t-1}}\\mathcal{L}_t \\end{aligned} $$ While fixing $f_t=1$ and $i_t=\\alpha_t$ might not be the optimal, both of them can be learnable and adaptable to different datasets.\n $$ \\begin{aligned} f_t \u0026= \\sigma(\\mathbf{W}_f \\cdot [\\nabla_{\\theta_{t-1}}\\mathcal{L}_t, \\mathcal{L}_t, \\theta_{t-1}, f_{t-1}] + \\mathbf{b}_f) \u0026 \\scriptstyle{\\text{; how much to forget the old value of parameters.}}\\\\ i_t \u0026= \\sigma(\\mathbf{W}_i \\cdot [\\nabla_{\\theta_{t-1}}\\mathcal{L}_t, \\mathcal{L}_t, \\theta_{t-1}, i_{t-1}] + \\mathbf{b}_i) \u0026 \\scriptstyle{\\text{; corresponding to the learning rate at time step t.}}\\\\ \\tilde{\\theta}_t \u0026= -\\nabla_{\\theta_{t-1}}\\mathcal{L}_t \u0026\\\\ \\theta_t \u0026= f_t \\odot \\theta_{t-1} + i_t \\odot \\tilde{\\theta}_t \u0026\\\\ \\end{aligned} $$ Model Setup Fig. 10. How the learner $M\\_\\theta$ and the meta-learner $R\\_\\Theta$ are trained. (Image source: original paper with more annotations) The training process mimics what happens during test, since it has been proved to be beneficial in Matching Networks. During each training epoch, we first sample a dataset $\\mathcal{D} = (\\mathcal{D}_\\text{train}, \\mathcal{D}_\\text{test}) \\in \\hat{\\mathcal{D}}_\\text{meta-train}$ and then sample mini-batches out of $\\mathcal{D}_\\text{train}$ to update $\\theta$ for $T$ rounds. The final state of the learner parameter $\\theta_T$ is used to train the meta-learner on the test data $\\mathcal{D}_\\text{test}$.\nTwo implementation details to pay extra attention to:\n How to compress the parameter space in LSTM meta-learner? As the meta-learner is modeling parameters of another neural network, it would have hundreds of thousands of variables to learn. Following the idea of sharing parameters across coordinates, To simplify the training process, the meta-learner assumes that the loss $\\mathcal{L}_t$ and the gradient $\\nabla_{\\theta_{t-1}} \\mathcal{L}_t$ are independent. MAML MAML, short for Model-Agnostic Meta-Learning (Finn, et al. 2017) is a fairly general optimization algorithm, compatible with any model that learns through gradient descent.\nLet\u0026rsquo;s say our model is $f_\\theta$ with parameters $\\theta$. Given a task $\\tau_i$ and its associated dataset $(\\mathcal{D}^{(i)}_\\text{train}, \\mathcal{D}^{(i)}_\\text{test})$, we can update the model parameters by one or more gradient descent steps (the following example only contains one step):\n $$ \\theta'_i = \\theta - \\alpha \\nabla_\\theta\\mathcal{L}^{(0)}_{\\tau_i}(f_\\theta) $$ where $\\mathcal{L}^{(0)}$ is the loss computed using the mini data batch with id (0).\nFig. 11. Diagram of MAML. (Image source: original paper) Well, the above formula only optimizes for one task. To achieve a good generalization across a variety of tasks, we would like to find the optimal $\\theta^*$ so that the task-specific fine-tuning is more efficient. Now, we sample a new data batch with id (1) for updating the meta-objective. The loss, denoted as $\\mathcal{L}^{(1)}$, depends on the mini batch (1). The superscripts in $\\mathcal{L}^{(0)}$ and $\\mathcal{L}^{(1)}$ only indicate different data batches, and they refer to the same loss objective for the same task.\n $$ \\begin{aligned} \\theta^* \u0026= \\arg\\min_\\theta \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta'_i}) = \\arg\\min_\\theta \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta - \\alpha\\nabla_\\theta \\mathcal{L}_{\\tau_i}^{(0)}(f_\\theta)}) \u0026 \\\\ \\theta \u0026\\leftarrow \\theta - \\beta \\nabla_{\\theta} \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta - \\alpha\\nabla_\\theta \\mathcal{L}_{\\tau_i}^{(0)}(f_\\theta)}) \u0026 \\scriptstyle{\\text{; updating rule}} \\end{aligned} $$ Fig. 12. The general form of MAML algorithm. (Image source: original paper) First-Order MAML The meta-optimization step above relies on second derivatives. To make the computation less expensive, a modified version of MAML omits second derivatives, resulting in a simplified and cheaper implementation, known as First-Order MAML (FOMAML).\nLet\u0026rsquo;s consider the case of performing $k$ inner gradient steps, $k\\geq1$. Starting with the initial model parameter $\\theta_\\text{meta}$:\n $$ \\begin{aligned} \\theta_0 \u0026= \\theta_\\text{meta}\\\\ \\theta_1 \u0026= \\theta_0 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_0)\\\\ \\theta_2 \u0026= \\theta_1 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_1)\\\\ \u0026\\dots\\\\ \\theta_k \u0026= \\theta_{k-1} - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{k-1}) \\end{aligned} $$ Then in the outer loop, we sample a new data batch for updating the meta-objective.\n $$ \\begin{aligned} \\theta_\\text{meta} \u0026\\leftarrow \\theta_\\text{meta} - \\beta g_\\text{MAML} \u0026 \\scriptstyle{\\text{; update for meta-objective}} \\\\[2mm] \\text{where } g_\\text{MAML} \u0026= \\nabla_{\\theta} \\mathcal{L}^{(1)}(\\theta_k) \u0026\\\\[2mm] \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot (\\nabla_{\\theta_{k-1}} \\theta_k) \\dots (\\nabla_{\\theta_0} \\theta_1) \\cdot (\\nabla_{\\theta} \\theta_0) \u0026 \\scriptstyle{\\text{; following the chain rule}} \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\Big( \\prod_{i=1}^k \\nabla_{\\theta_{i-1}} \\theta_i \\Big) \\cdot I \u0026 \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k \\nabla_{\\theta_{i-1}} (\\theta_{i-1} - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1})) \u0026 \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k (I - \\alpha\\nabla_{\\theta_{i-1}}(\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1}))) \u0026 \\end{aligned} $$ The MAML gradient is:\n $$ g_\\text{MAML} = \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k (I - \\alpha \\color{red}{\\nabla_{\\theta_{i-1}}(\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1}))}) $$ The First-Order MAML ignores the second derivative part in red. It is simplified as follows, equivalent to the derivative of the last inner gradient update result.\n $$ g_\\text{FOMAML} = \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) $$ Reptile Reptile (Nichol, Achiam \u0026amp; Schulman, 2018) is a remarkably simple meta-learning optimization algorithm. It is similar to MAML in many ways, given that both rely on meta-optimization through gradient descent and both are model-agnostic.\nThe Reptile works by repeatedly:\n sampling a task, training on it by multiple gradient descent steps, and then moving the model weights towards the new parameters. See the algorithm below: $\\text{SGD}(\\mathcal{L}_{\\tau_i}, \\theta, k)$ performs stochastic gradient update for k steps on the loss $\\mathcal{L}_{\\tau_i}$ starting with initial parameter $\\theta$ and returns the final parameter vector. The batch version samples multiple tasks instead of one within each iteration. The reptile gradient is defined as $(\\theta - W)/\\alpha$, where $\\alpha$ is the stepsize used by the SGD operation.\nFig. 13. The batched version of Reptile algorithm. (Image source: original paper) At a glance, the algorithm looks a lot like an ordinary SGD. However, because the task-specific optimization can take more than one step. it eventually makes $$\\text{SGD}(\\mathbb{E} \\tau[\\mathcal{L}{\\tau}], \\theta, k)$ diverge from $\\mathbb{E}\\tau [\\text{SGD}(\\mathcal{L}{\\tau}, \\theta, k)]$$ when k \u0026gt; 1.\nThe Optimization Assumption Assuming that a task $\\tau \\sim p(\\tau)$ has a manifold of optimal network configuration, $\\mathcal{W}_{\\tau}^*$. The model $f_\\theta$ achieves the best performance for task $\\tau$ when $\\theta$ lays on the surface of $\\mathcal{W}_{\\tau}^*$. To find a solution that is good across tasks, we would like to find a parameter close to all the optimal manifolds of all tasks:\n $$ \\theta^* = \\arg\\min_\\theta \\mathbb{E}_{\\tau \\sim p(\\tau)} [\\frac{1}{2} \\text{dist}(\\theta, \\mathcal{W}_\\tau^*)^2] $$ Fig. 14. The Reptile algorithm updates the parameter alternatively to be closer to the optimal manifolds of different tasks. (Image source: original paper) Let\u0026rsquo;s use the L2 distance as $\\text{dist}(.)$ and the distance between a point $\\theta$ and a set $\\mathcal{W}_\\tau^*$ equals to the distance between $\\theta$ and a point $W_{\\tau}^*(\\theta)$ on the manifold that is closest to $\\theta$:\n $$ \\text{dist}(\\theta, \\mathcal{W}_{\\tau}^*) = \\text{dist}(\\theta, W_{\\tau}^*(\\theta)) \\text{, where }W_{\\tau}^*(\\theta) = \\arg\\min_{W\\in\\mathcal{W}_{\\tau}^*} \\text{dist}(\\theta, W) $$ The gradient of the squared euclidean distance is:\n $$ \\begin{aligned} \\nabla_\\theta[\\frac{1}{2}\\text{dist}(\\theta, \\mathcal{W}_{\\tau_i}^*)^2] \u0026= \\nabla_\\theta[\\frac{1}{2}\\text{dist}(\\theta, W_{\\tau_i}^*(\\theta))^2] \u0026 \\\\ \u0026= \\nabla_\\theta[\\frac{1}{2}(\\theta - W_{\\tau_i}^*(\\theta))^2] \u0026 \\\\ \u0026= \\theta - W_{\\tau_i}^*(\\theta) \u0026 \\scriptstyle{\\text{; See notes.}} \\end{aligned} $$ Notes: According to the Reptile paper, \u0026ldquo;the gradient of the squared euclidean distance between a point Θ and a set S is the vector 2(Θ − p), where p is the closest point in S to Θ\u0026rdquo;. Technically the closest point in S is also a function of Θ, but I\u0026rsquo;m not sure why the gradient does not need to worry about the derivative of p. (Please feel free to leave me a comment or send me an email about this if you have ideas.)\nThus the update rule for one stochastic gradient step is:\n $$ \\theta = \\theta - \\alpha \\nabla_\\theta[\\frac{1}{2} \\text{dist}(\\theta, \\mathcal{W}_{\\tau_i}^*)^2] = \\theta - \\alpha(\\theta - W_{\\tau_i}^*(\\theta)) = (1-\\alpha)\\theta + \\alpha W_{\\tau_i}^*(\\theta) $$ The closest point on the optimal task manifold $W_{\\tau_i}^*(\\theta)$ cannot be computed exactly, but Reptile approximates it using $\\text{SGD}(\\mathcal{L}_\\tau, \\theta, k)$.\nReptile vs FOMAML To demonstrate the deeper connection between Reptile and MAML, let\u0026rsquo;s expand the update formula with an example performing two gradient steps, k=2 in $\\text{SGD}(.)$. Same as defined above, $\\mathcal{L}^{(0)}$ and $\\mathcal{L}^{(1)}$ are losses using different mini-batches of data. For ease of reading, we adopt two simplified annotations: $g^{(i)}_j = \\nabla_{\\theta} \\mathcal{L}^{(i)}(\\theta_j)$ and $H^{(i)}_j = \\nabla^2_{\\theta} \\mathcal{L}^{(i)}(\\theta_j)$.\n $$ \\begin{aligned} \\theta_0 \u0026= \\theta_\\text{meta}\\\\ \\theta_1 \u0026= \\theta_0 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_0)= \\theta_0 - \\alpha g^{(0)}_0 \\\\ \\theta_2 \u0026= \\theta_1 - \\alpha\\nabla_\\theta\\mathcal{L}^{(1)}(\\theta_1) = \\theta_0 - \\alpha g^{(0)}_0 - \\alpha g^{(1)}_1 \\end{aligned} $$ According to the early section, the gradient of FOMAML is the last inner gradient update result. Therefore, when k=1:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= \\nabla_{\\theta_1} \\mathcal{L}^{(1)}(\\theta_1) = g^{(1)}_1 \\\\ g_\\text{MAML} \u0026= \\nabla_{\\theta_1} \\mathcal{L}^{(1)}(\\theta_1) \\cdot (I - \\alpha\\nabla^2_{\\theta} \\mathcal{L}^{(0)}(\\theta_0)) = g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\end{aligned} $$ The Reptile gradient is defined as:\n $$ g_\\text{Reptile} = (\\theta_0 - \\theta_2) / \\alpha = g^{(0)}_0 + g^{(1)}_1 $$ Up to now we have:\nFig. 15. Reptile versus FOMAML in one loop of meta-optimization. (Image source: slides on Reptile by Yoonho Lee.) $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g^{(1)}_1 \\\\ g_\\text{MAML} \u0026= g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\\\ g_\\text{Reptile} \u0026= g^{(0)}_0 + g^{(1)}_1 \\end{aligned} $$ Next let\u0026rsquo;s try further expand $g^{(1)}_1$ using Taylor expansion. Recall that Taylor expansion of a function $f(x)$ that is differentiable at a number $a$ is:\n $$ f(x) = f(a) + \\frac{f'(a)}{1!}(x-a) + \\frac{f''(a)}{2!}(x-a)^2 + \\dots = \\sum_{i=0}^\\infty \\frac{f^{(i)}(a)}{i!}(x-a)^i $$ We can consider $\\nabla_{\\theta}\\mathcal{L}^{(1)}(.)$ as a function and $\\theta_0$ as a value point. The Taylor expansion of $g_1^{(1)}$ at the value point $\\theta_0$ is:\n $$ \\begin{aligned} g_1^{(1)} \u0026= \\nabla_{\\theta}\\mathcal{L}^{(1)}(\\theta_1) \\\\ \u0026= \\nabla_{\\theta}\\mathcal{L}^{(1)}(\\theta_0) + \\nabla^2_\\theta\\mathcal{L}^{(1)}(\\theta_0)(\\theta_1 - \\theta_0) + \\frac{1}{2}\\nabla^3_\\theta\\mathcal{L}^{(1)}(\\theta_0)(\\theta_1 - \\theta_0)^2 + \\dots \u0026 \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + \\frac{\\alpha^2}{2}\\nabla^3_\\theta\\mathcal{L}^{(1)}(\\theta_0) (g_0^{(0)})^2 + \\dots \u0026 \\scriptstyle{\\text{; because }\\theta_1-\\theta_0=-\\alpha g_0^{(0)}} \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ Plug in the expanded form of $g_1^{(1)}$ into the MAML gradients with one step inner gradient update:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g^{(1)}_1 = g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ g_\\text{MAML} \u0026= g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) - \\alpha H^{(0)}_0 (g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2))\\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + \\alpha^2 \\alpha H^{(0)}_0 H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + O(\\alpha^2) \\end{aligned} $$ The Reptile gradient becomes:\n $$ \\begin{aligned} g_\\text{Reptile} \u0026= g^{(0)}_0 + g^{(1)}_1 \\\\ \u0026= g^{(0)}_0 + g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ So far we have the formula of three types of gradients:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ g_\\text{MAML} \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + O(\\alpha^2)\\\\ g_\\text{Reptile} \u0026= g^{(0)}_0 + g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ During training, we often average over multiple data batches. In our example, the mini batches (0) and (1) are interchangeable since both are drawn at random. The expectation $\\mathbb{E}_{\\tau,0,1}$ is averaged over two data batches, ids (0) and (1), for task $\\tau$.\nLet,\n $A = \\mathbb{E}_{\\tau,0,1} [g_0^{(0)}] = \\mathbb{E}_{\\tau,0,1} [g_0^{(1)}]$; it is the average gradient of task loss. We expect to improve the model parameter to achieve better task performance by following this direction pointed by $A$. $B = \\mathbb{E}_{\\tau,0,1} [H^{(1)}_0 g_0^{(0)}] = \\frac{1}{2}\\mathbb{E}_{\\tau,0,1} [H^{(1)}_0 g_0^{(0)} + H^{(0)}_0 g_0^{(1)}] = \\frac{1}{2}\\mathbb{E}_{\\tau,0,1} [\\nabla_\\theta(g^{(0)}_0 g_0^{(1)})]$; it is the direction (gradient) that increases the inner product of gradients of two different mini batches for the same task. We expect to improve the model parameter to achieve better generalization over different data by following this direction pointed by $B$. To conclude, both MAML and Reptile aim to optimize for the same goal, better task performance (guided by A) and better generalization (guided by B), when the gradient update is approximated by first three leading terms.\n $$ \\begin{aligned} \\mathbb{E}_{\\tau,1,2}[g_\\text{FOMAML}] \u0026= A - \\alpha B + O(\\alpha^2)\\\\ \\mathbb{E}_{\\tau,1,2}[g_\\text{MAML}] \u0026= A - 2\\alpha B + O(\\alpha^2)\\\\ \\mathbb{E}_{\\tau,1,2}[g_\\text{Reptile}] \u0026= 2A - \\alpha B + O(\\alpha^2) \\end{aligned} $$ It is not clear to me whether the ignored term $O(\\alpha^2)$ might play a big impact on the parameter learning. But given that FOMAML is able to obtain a similar performance as the full version of MAML, it might be safe to say higher-level derivatives would not be critical during gradient descent update.\n Cited as:\n@article{weng2018metalearning, title = \u0026quot;Meta-Learning: Learning to Learn Fast\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-11-30-meta-learning/\u0026quot; } Reference [1] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. \u0026ldquo;Human-level concept learning through probabilistic program induction.\u0026quot; Science 350.6266 (2015): 1332-1338.\n[2] Oriol Vinyals' talk on \u0026ldquo;Model vs Optimization Meta Learning\u0026rdquo;\n[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. \u0026ldquo;Siamese neural networks for one-shot image recognition.\u0026quot; ICML Deep Learning Workshop. 2015.\n[4] Oriol Vinyals, et al. \u0026ldquo;Matching networks for one shot learning.\u0026quot; NIPS. 2016.\n[5] Flood Sung, et al. \u0026ldquo;Learning to compare: Relation network for few-shot learning.\u0026quot; CVPR. 2018.\n[6] Jake Snell, Kevin Swersky, and Richard Zemel. \u0026ldquo;Prototypical Networks for Few-shot Learning.\u0026quot; CVPR. 2018.\n[7] Adam Santoro, et al. \u0026ldquo;Meta-learning with memory-augmented neural networks.\u0026quot; ICML. 2016.\n[8] Alex Graves, Greg Wayne, and Ivo Danihelka. \u0026ldquo;Neural turing machines.\u0026quot; arXiv preprint arXiv:1410.5401 (2014).\n[9] Tsendsuren Munkhdalai and Hong Yu. \u0026ldquo;Meta Networks.\u0026quot; ICML. 2017.\n[10] Sachin Ravi and Hugo Larochelle. \u0026ldquo;Optimization as a Model for Few-Shot Learning.\u0026quot; ICLR. 2017.\n[11] Chelsea Finn\u0026rsquo;s BAIR blog on \u0026ldquo;Learning to Learn\u0026rdquo;.\n[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine. \u0026ldquo;Model-agnostic meta-learning for fast adaptation of deep networks.\u0026quot; ICML 2017.\n[13] Alex Nichol, Joshua Achiam, John Schulman. \u0026ldquo;On First-Order Meta-Learning Algorithms.\u0026quot; arXiv preprint arXiv:1803.02999 (2018).\n[14] Slides on Reptile by Yoonho Lee.\n","permalink":"https://lilianweng.github.io/posts/2018-11-30-meta-learning/","summary":"[Updated on 2019-10-01: thanks to Tianhao, we have this post translated in Chinese!]\nA good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to discover the way to ride a motorcycle fast with little or even no demonstration.","title":"Meta-Learning: Learning to Learn Fast"},{"content":"So far, I\u0026rsquo;ve written about two types of generative models, GAN and VAE. Neither of them explicitly learns the probability density function of real data, $p(\\mathbf{x})$ (where $\\mathbf{x} \\in \\mathcal{D}$) \u0026mdash; because it is really hard! Taking the generative model with latent variables as an example, $p(\\mathbf{x}) = \\int p(\\mathbf{x}\\vert\\mathbf{z})p(\\mathbf{z})d\\mathbf{z}$ can hardly be calculated as it is intractable to go through all possible values of the latent code $\\mathbf{z}$.\nFlow-based deep generative models conquer this hard problem with the help of normalizing flows, a powerful statistics tool for density estimation. A good estimation of $p(\\mathbf{x})$ makes it possible to efficiently complete many downstream tasks: sample unobserved but realistic new data points (data generation), predict the rareness of future events (density estimation), infer latent variables, fill in incomplete data samples, etc.\nTypes of Generative Models Here is a quick summary of the difference between GAN, VAE, and flow-based generative models:\n Generative adversarial networks: GAN provides a smart solution to model the data generation, an unsupervised learning problem, as a supervised one. The discriminator model learns to distinguish the real data from the fake samples that are produced by the generator model. Two models are trained as they are playing a minimax game. Variational autoencoders: VAE inexplicitly optimizes the log-likelihood of the data by maximizing the evidence lower bound (ELBO). Flow-based generative models: A flow-based generative model is constructed by a sequence of invertible transformations. Unlike other two, the model explicitly learns the data distribution $p(\\mathbf{x})$ and therefore the loss function is simply the negative log-likelihood. Fig. 1. Comparison of three categories of generative models. Linear Algebra Basics Recap We should understand two key concepts before getting into the flow-based generative model: the Jacobian determinant and the change of variable rule. Pretty basic, so feel free to skip.\nJacobian Matrix and Determinant Given a function of mapping a $n$-dimensional input vector $\\mathbf{x}$ to a $m$-dimensional output vector, $\\mathbf{f}: \\mathbb{R}^n \\mapsto \\mathbb{R}^m$, the matrix of all first-order partial derivatives of this function is called the Jacobian matrix, $\\mathbf{J}$ where one entry on the i-th row and j-th column is $\\mathbf{J}_{ij} = \\frac{\\partial f_i}{\\partial x_j}$.\n $$ \\mathbf{J} = \\begin{bmatrix} \\frac{\\partial f_1}{\\partial x_1} \u0026 \\dots \u0026 \\frac{\\partial f_1}{\\partial x_n} \\\\[6pt] \\vdots \u0026 \\ddots \u0026 \\vdots \\\\[6pt] \\frac{\\partial f_m}{\\partial x_1} \u0026 \\dots \u0026 \\frac{\\partial f_m}{\\partial x_n} \\\\[6pt] \\end{bmatrix} $$ The determinant is one real number computed as a function of all the elements in a squared matrix. Note that the determinant only exists for square matrices. The absolute value of the determinant can be thought of as a measure of \u0026ldquo;how much multiplication by the matrix expands or contracts space\u0026rdquo;.\nThe determinant of a nxn matrix $M$ is:\n $$ \\det M = \\det \\begin{bmatrix} a_{11} \u0026 a_{12} \u0026 \\dots \u0026 a_{1n} \\\\ a_{21} \u0026 a_{22} \u0026 \\dots \u0026 a_{2n} \\\\ \\vdots \u0026 \\vdots \u0026 \u0026 \\vdots \\\\ a_{n1} \u0026 a_{n2} \u0026 \\dots \u0026 a_{nn} \\\\ \\end{bmatrix} = \\sum_{j_1 j_2 \\dots j_n} (-1)^{\\tau(j_1 j_2 \\dots j_n)} a_{1j_1} a_{2j_2} \\dots a_{nj_n} $$ where the subscript under the summation $j_1 j_2 \\dots j_n$ are all permutations of the set {1, 2, \u0026hellip;, n}, so there are $n!$ items in total; $\\tau(.)$ indicates the signature of a permutation.\nThe determinant of a square matrix $M$ detects whether it is invertible: If $\\det(M)=0$ then $M$ is not invertible (a singular matrix with linearly dependent rows or columns; or any row or column is all 0); otherwise, if $\\det(M)\\neq 0$, $M$ is invertible.\nThe determinant of the product is equivalent to the product of the determinants: $\\det(AB) = \\det(A)\\det(B)$. (proof)\nChange of Variable Theorem Let\u0026rsquo;s review the change of variable theorem specifically in the context of probability density estimation, starting with a single variable case.\nGiven a random variable $z$ and its known probability density function $z \\sim \\pi(z)$, we would like to construct a new random variable using a 1-1 mapping function $x = f(z)$. The function $f$ is invertible, so $z=f^{-1}(x)$. Now the question is how to infer the unknown probability density function of the new variable, $p(x)$?\n $$ \\begin{aligned} \u0026 \\int p(x)dx = \\int \\pi(z)dz = 1 \\scriptstyle{\\text{ ; Definition of probability distribution.}}\\\\ \u0026 p(x) = \\pi(z) \\left\\vert\\frac{dz}{dx}\\right\\vert = \\pi(f^{-1}(x)) \\left\\vert\\frac{d f^{-1}}{dx}\\right\\vert = \\pi(f^{-1}(x)) \\vert (f^{-1})'(x) \\vert \\end{aligned} $$ By definition, the integral $\\int \\pi(z)dz$ is the sum of an infinite number of rectangles of infinitesimal width $\\Delta z$. The height of such a rectangle at position $z$ is the value of the density function $\\pi(z)$. When we substitute the variable, $z = f^{-1}(x)$ yields $\\frac{\\Delta z}{\\Delta x} = (f^{-1}(x))'$ and $\\Delta z = (f^{-1}(x))' \\Delta x$. Here $\\vert(f^{-1}(x))'\\vert$ indicates the ratio between the area of rectangles defined in two different coordinate of variables $z$ and $x$ respectively.\nThe multivariable version has a similar format:\n $$ \\begin{aligned} \\mathbf{z} \u0026\\sim \\pi(\\mathbf{z}), \\mathbf{x} = f(\\mathbf{z}), \\mathbf{z} = f^{-1}(\\mathbf{x}) \\\\ p(\\mathbf{x}) \u0026= \\pi(\\mathbf{z}) \\left\\vert \\det \\dfrac{d \\mathbf{z}}{d \\mathbf{x}} \\right\\vert = \\pi(f^{-1}(\\mathbf{x})) \\left\\vert \\det \\dfrac{d f^{-1}}{d \\mathbf{x}} \\right\\vert \\end{aligned} $$ where $\\det \\frac{\\partial f}{\\partial\\mathbf{z}}$ is the Jacobian determinant of the function $f$. The full proof of the multivariate version is out of the scope of this post; ask Google if interested ;)\nWhat is Normalizing Flows? Being able to do good density estimation has direct applications in many machine learning problems, but it is very hard. For example, since we need to run backward propagation in deep learning models, the embedded probability distribution (i.e. posterior $p(\\mathbf{z}\\vert\\mathbf{x})$) is expected to be simple enough to calculate the derivative easily and efficiently. That is why Gaussian distribution is often used in latent variable generative models, even though most of real world distributions are much more complicated than Gaussian.\nHere comes a Normalizing Flow (NF) model for better and more powerful distribution approximation. A normalizing flow transforms a simple distribution into a complex one by applying a sequence of invertible transformation functions. Flowing through a chain of transformations, we repeatedly substitute the variable for the new one according to the change of variables theorem and eventually obtain a probability distribution of the final target variable.\nFig. 2. Illustration of a normalizing flow model, transforming a simple distribution $p\\_0(\\mathbf{z}\\_0)$ to a complex one $p\\_K(\\mathbf{z}\\_K)$ step by step. As defined in Fig. 2,\n $$ \\begin{aligned} \\mathbf{z}_{i-1} \u0026\\sim p_{i-1}(\\mathbf{z}_{i-1}) \\\\ \\mathbf{z}_i \u0026= f_i(\\mathbf{z}_{i-1})\\text{, thus }\\mathbf{z}_{i-1} = f_i^{-1}(\\mathbf{z}_i) \\\\ p_i(\\mathbf{z}_i) \u0026= p_{i-1}(f_i^{-1}(\\mathbf{z}_i)) \\left\\vert \\det\\dfrac{d f_i^{-1}}{d \\mathbf{z}_i} \\right\\vert \\end{aligned} $$ Then let\u0026rsquo;s convert the equation to be a function of $\\mathbf{z}_i$ so that we can do inference with the base distribution.\n $$ \\begin{aligned} p_i(\\mathbf{z}_i) \u0026= p_{i-1}(f_i^{-1}(\\mathbf{z}_i)) \\left\\vert \\det\\dfrac{d f_i^{-1}}{d \\mathbf{z}_i} \\right\\vert \\\\ \u0026= p_{i-1}(\\mathbf{z}_{i-1}) \\left\\vert \\det \\color{red}{\\Big(\\dfrac{d f_i}{d\\mathbf{z}_{i-1}}\\Big)^{-1}} \\right\\vert \u0026 \\scriptstyle{\\text{; According to the inverse func theorem.}} \\\\ \u0026= p_{i-1}(\\mathbf{z}_{i-1}) \\color{red}{\\left\\vert \\det \\dfrac{d f_i}{d\\mathbf{z}_{i-1}} \\right\\vert^{-1}} \u0026 \\scriptstyle{\\text{; According to a property of Jacobians of invertible func.}} \\\\ \\log p_i(\\mathbf{z}_i) \u0026= \\log p_{i-1}(\\mathbf{z}_{i-1}) - \\log \\left\\vert \\det \\dfrac{d f_i}{d\\mathbf{z}_{i-1}} \\right\\vert \\end{aligned} $$ (*) A note on the \u0026ldquo;inverse function theorem\u0026rdquo;: If $y=f(x)$ and $x=f^{-1}(y)$, we have:\n $$ \\dfrac{df^{-1}(y)}{dy} = \\dfrac{dx}{dy} = (\\dfrac{dy}{dx})^{-1} = (\\dfrac{df(x)}{dx})^{-1} $$ (*) A note on \u0026ldquo;Jacobians of invertible function\u0026rdquo;: The determinant of the inverse of an invertible matrix is the inverse of the determinant: $\\det(M^{-1}) = (\\det(M))^{-1}$, because $\\det(M)\\det(M^{-1}) = \\det(M \\cdot M^{-1}) = \\det(I) = 1$.\nGiven such a chain of probability density functions, we know the relationship between each pair of consecutive variables. We can expand the equation of the output $\\mathbf{x}$ step by step until tracing back to the initial distribution $\\mathbf{z}_0$.\n $$ \\begin{aligned} \\mathbf{x} = \\mathbf{z}_K \u0026= f_K \\circ f_{K-1} \\circ \\dots \\circ f_1 (\\mathbf{z}_0) \\\\ \\log p(\\mathbf{x}) = \\log \\pi_K(\\mathbf{z}_K) \u0026= \\log \\pi_{K-1}(\\mathbf{z}_{K-1}) - \\log\\left\\vert\\det\\dfrac{d f_K}{d \\mathbf{z}_{K-1}}\\right\\vert \\\\ \u0026= \\log \\pi_{K-2}(\\mathbf{z}_{K-2}) - \\log\\left\\vert\\det\\dfrac{d f_{K-1}}{d\\mathbf{z}_{K-2}}\\right\\vert - \\log\\left\\vert\\det\\dfrac{d f_K}{d\\mathbf{z}_{K-1}}\\right\\vert \\\\ \u0026= \\dots \\\\ \u0026= \\log \\pi_0(\\mathbf{z}_0) - \\sum_{i=1}^K \\log\\left\\vert\\det\\dfrac{d f_i}{d\\mathbf{z}_{i-1}}\\right\\vert \\end{aligned} $$ The path traversed by the random variables $\\mathbf{z}_i = f_i(\\mathbf{z}_{i-1})$ is the flow and the full chain formed by the successive distributions $\\pi_i$ is called a normalizing flow. Required by the computation in the equation, a transformation function $f_i$ should satisfy two properties:\n It is easily invertible. Its Jacobian determinant is easy to compute. Models with Normalizing Flows With normalizing flows in our toolbox, the exact log-likelihood of input data $\\log p(\\mathbf{x})$ becomes tractable. As a result, the training criterion of flow-based generative model is simply the negative log-likelihood (NLL) over the training dataset $\\mathcal{D}$:\n $$ \\mathcal{L}(\\mathcal{D}) = - \\frac{1}{\\vert\\mathcal{D}\\vert}\\sum_{\\mathbf{x} \\in \\mathcal{D}} \\log p(\\mathbf{x}) $$ RealNVP The RealNVP (Real-valued Non-Volume Preserving; Dinh et al., 2017) model implements a normalizing flow by stacking a sequence of invertible bijective transformation functions. In each bijection $f: \\mathbf{x} \\mapsto \\mathbf{y}$, known as affine coupling layer, the input dimensions are split into two parts:\n The first $d$ dimensions stay same; The second part, $d+1$ to $D$ dimensions, undergo an affine transformation (\u0026ldquo;scale-and-shift\u0026rdquo;) and both the scale and shift parameters are functions of the first $d$ dimensions. $$ \\begin{aligned} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} \\odot \\exp({s(\\mathbf{x}_{1:d})}) + t(\\mathbf{x}_{1:d}) \\end{aligned} $$ where $s(.)$ and $t(.)$ are scale and translation functions and both map $\\mathbb{R}^d \\mapsto \\mathbb{R}^{D-d}$. The $\\odot$ operation is the element-wise product.\nNow let\u0026rsquo;s check whether this transformation satisfy two basic properties for a flow transformation.\nCondition 1: \u0026ldquo;It is easily invertible.\u0026rdquo;\nYes and it is fairly straightforward.\n $$ \\begin{cases} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} \\odot \\exp({s(\\mathbf{x}_{1:d})}) + t(\\mathbf{x}_{1:d}) \\end{cases} \\Leftrightarrow \\begin{cases} \\mathbf{x}_{1:d} \u0026= \\mathbf{y}_{1:d} \\\\ \\mathbf{x}_{d+1:D} \u0026= (\\mathbf{y}_{d+1:D} - t(\\mathbf{y}_{1:d})) \\odot \\exp(-s(\\mathbf{y}_{1:d})) \\end{cases} $$ Condition 2: \u0026ldquo;Its Jacobian determinant is easy to compute.\u0026rdquo;\nYes. It is not hard to get the Jacobian matrix and determinant of this transformation. The Jacobian is a lower triangular matrix.\n $$ \\mathbf{J} = \\begin{bmatrix} \\mathbb{I}_d \u0026 \\mathbf{0}_{d\\times(D-d)} \\\\[5pt] \\frac{\\partial \\mathbf{y}_{d+1:D}}{\\partial \\mathbf{x}_{1:d}} \u0026 \\text{diag}(\\exp(s(\\mathbf{x}_{1:d}))) \\end{bmatrix} $$ Hence the determinant is simply the product of terms on the diagonal.\n $$ \\det(\\mathbf{J}) = \\prod_{j=1}^{D-d}\\exp(s(\\mathbf{x}_{1:d}))_j = \\exp(\\sum_{j=1}^{D-d} s(\\mathbf{x}_{1:d})_j) $$ So far, the affine coupling layer looks perfect for constructing a normalizing flow :)\nEven better, since (i) computing $f^-1$ does not require computing the inverse of $s$ or $t$ and (ii) computing the Jacobian determinant does not involve computing the Jacobian of $s$ or $t$, those functions can be arbitrarily complex; i.e. both $s$ and $t$ can be modeled by deep neural networks.\nIn one affine coupling layer, some dimensions (channels) remain unchanged. To make sure all the inputs have a chance to be altered, the model reverses the ordering in each layer so that different components are left unchanged. Following such an alternating pattern, the set of units which remain identical in one transformation layer are always modified in the next. Batch normalization is found to help training models with a very deep stack of coupling layers.\nFurthermore, RealNVP can work in a multi-scale architecture to build a more efficient model for large inputs. The multi-scale architecture applies several \u0026ldquo;sampling\u0026rdquo; operations to normal affine layers, including spatial checkerboard pattern masking, squeezing operation, and channel-wise masking. Read the paper for more details on the multi-scale architecture.\nNICE The NICE (Non-linear Independent Component Estimation; Dinh, et al. 2015) model is a predecessor of RealNVP. The transformation in NICE is the affine coupling layer without the scale term, known as additive coupling layer.\n $$ \\begin{cases} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} + m(\\mathbf{x}_{1:d}) \\end{cases} \\Leftrightarrow \\begin{cases} \\mathbf{x}_{1:d} \u0026= \\mathbf{y}_{1:d} \\\\ \\mathbf{x}_{d+1:D} \u0026= \\mathbf{y}_{d+1:D} - m(\\mathbf{y}_{1:d}) \\end{cases} $$ Glow The Glow (Kingma and Dhariwal, 2018) model extends the previous reversible generative models, NICE and RealNVP, and simplifies the architecture by replacing the reverse permutation operation on the channel ordering with invertible 1x1 convolutions.\nFig. 3. One step of flow in the Glow model. (Image source: Kingma and Dhariwal, 2018) There are three substeps in one step of flow in Glow.\nSubstep 1: Activation normalization (short for \u0026ldquo;actnorm\u0026rdquo;)\nIt performs an affine transformation using a scale and bias parameter per channel, similar to batch normalization, but works for mini-batch size 1. The parameters are trainable but initialized so that the first minibatch of data have mean 0 and standard deviation 1 after actnorm.\nSubstep 2: Invertible 1x1 conv\nBetween layers of the RealNVP flow, the ordering of channels is reversed so that all the data dimensions have a chance to be altered. A 1×1 convolution with equal number of input and output channels is a generalization of any permutation of the channel ordering.\nSay, we have an invertible 1x1 convolution of an input $h \\times w \\times c$ tensor $\\mathbf{h}$ with a weight matrix $\\mathbf{W}$ of size $c \\times c$. The output is a $h \\times w \\times c$ tensor, labeled as $f = \\texttt{conv2d}(\\mathbf{h}; \\mathbf{W})$. In order to apply the change of variable rule, we need to compute the Jacobian determinant $\\vert \\det\\partial f / \\partial\\mathbf{h}\\vert$.\nBoth the input and output of 1x1 convolution here can be viewed as a matrix of size $h \\times w$. Each entry $\\mathbf{x}_{ij}$ ($i=1,\\dots,h, j=1,\\dots,w$) in $\\mathbf{h}$ is a vector of $c$ channels and each entry is multiplied by the weight matrix $\\mathbf{W}$ to obtain the corresponding entry $\\mathbf{y}_{ij}$ in the output matrix respectively. The derivative of each entry is $\\partial \\mathbf{x}_{ij} \\mathbf{W} / \\partial\\mathbf{x}_{ij} = \\mathbf{W}$ and there are $h \\times w$ such entries in total:\n $$ \\log \\left\\vert\\det \\frac{\\partial\\texttt{conv2d}(\\mathbf{h}; \\mathbf{W})}{\\partial\\mathbf{h}}\\right\\vert = \\log (\\vert\\det\\mathbf{W}\\vert^{h \\cdot w}\\vert) = h \\cdot w \\cdot \\log \\vert\\det\\mathbf{W}\\vert $$ The inverse 1x1 convolution depends on the inverse matrix $\\mathbf{W}^{-1}$. Since the weight matrix is relatively small, the amount of computation for the matrix determinant (tf.linalg.det) and inversion (tf.linalg.inv) is still under control.\nSubstep 3: Affine coupling layer\nThe design is same as in RealNVP.\nFig. 4. Three substeps in one step of flow in Glow. (Image source: Kingma and Dhariwal, 2018) Models with Autoregressive Flows The autoregressive constraint is a way to model sequential data, $\\mathbf{x} = [x_1, \\dots, x_D]$: each output only depends on the data observed in the past, but not on the future ones. In other words, the probability of observing $x_i$ is conditioned on $x_1, \\dots, x_{i-1}$ and the product of these conditional probabilities gives us the probability of observing the full sequence:\n $$ p(\\mathbf{x}) = \\prod_{i=1}^{D} p(x_i\\vert x_1, \\dots, x_{i-1}) = \\prod_{i=1}^{D} p(x_i\\vert x_{1:i-1}) $$ How to model the conditional density is of your choice. It can be a univariate Gaussian with mean and standard deviation computed as a function of $x_{1:i-1}$, or a multilayer neural network with $x_{1:i-1}$ as the input.\nIf a flow transformation in a normalizing flow is framed as an autoregressive model \u0026mdash; each dimension in a vector variable is conditioned on the previous dimensions \u0026mdash; this is an autoregressive flow.\nThis section starts with several classic autoregressive models (MADE, PixelRNN, WaveNet) and then we dive into autoregressive flow models (MAF and IAF).\nMADE MADE (Masked Autoencoder for Distribution Estimation; Germain et al., 2015) is a specially designed architecture to enforce the autoregressive property in the autoencoder efficiently. When using an autoencoder to predict the conditional probabilities, rather than feeding the autoencoder with input of different observation windows $D$ times, MADE removes the contribution from certain hidden units by multiplying binary mask matrices so that each input dimension is reconstructed only from previous dimensions in a given ordering in a single pass.\nIn a multilayer fully-connected neural network, say, we have $L$ hidden layers with weight matrices $\\mathbf{W}^1, \\dots, \\mathbf{W}^L$ and an output layer with weight matrix $\\mathbf{V}$. The output $\\hat{\\mathbf{x}}$ has each dimension $\\hat{x}_i = p(x_i\\vert x_{1:i-1})$.\nWithout any mask, the computation through layers looks like the following:\n $$ \\begin{aligned} \\mathbf{h}^0 \u0026= \\mathbf{x} \\\\ \\mathbf{h}^l \u0026= \\text{activation}^l(\\mathbf{W}^l\\mathbf{h}^{l-1} + \\mathbf{b}^l) \\\\ \\hat{\\mathbf{x}} \u0026= \\sigma(\\mathbf{V}\\mathbf{h}^L + \\mathbf{c}) \\end{aligned} $$ Fig. 5. Demonstration of how MADE works in a three-layer feed-forward neural network. (Image source: Germain et al., 2015) To zero out some connections between layers, we can simply element-wise multiply every weight matrix by a binary mask matrix. Each hidden node is assigned with a random \u0026ldquo;connectivity integer\u0026rdquo; between $1$ and $D-1$; the assigned value for the $k$-th unit in the $l$-th layer is denoted by $m^l_k$. The binary mask matrix is determined by element-wise comparing values of two nodes in two layers.\n $$ \\begin{aligned} \\mathbf{h}^l \u0026= \\text{activation}^l((\\mathbf{W}^l \\color{red}{\\odot \\mathbf{M}^{\\mathbf{W}^l}}) \\mathbf{h}^{l-1} + \\mathbf{b}^l) \\\\ \\hat{\\mathbf{x}} \u0026= \\sigma((\\mathbf{V} \\color{red}{\\odot \\mathbf{M}^{\\mathbf{V}}}) \\mathbf{h}^L + \\mathbf{c}) \\\\ M^{\\mathbf{W}^l}_{k', k} \u0026= \\mathbf{1}_{m^l_{k'} \\geq m^{l-1}_k} = \\begin{cases} 1, \u0026 \\text{if } m^l_{k'} \\geq m^{l-1}_k\\\\ 0, \u0026 \\text{otherwise} \\end{cases} \\\\ M^{\\mathbf{V}}_{d, k} \u0026= \\mathbf{1}_{d \\geq m^L_k} = \\begin{cases} 1, \u0026 \\text{if } d m^L_k\\\\ 0, \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ A unit in the current layer can only be connected to other units with equal or smaller numbers in the previous layer and this type of dependency easily propagates through the network up to the output layer. Once the numbers are assigned to all the units and layers, the ordering of input dimensions is fixed and the conditional probability is produced with respect to it. See a great illustration in Fig. 5. To make sure all the hidden units are connected to the input and output layers through some paths, the $m^l_k$ is sampled to be equal or greater than the minimal connectivity integer in the previous layer, $\\min_{k'} m_{k'}^{l-1}$.\nMADE training can be further facilitated by:\n Order-agnostic training: shuffle the input dimensions, so that MADE is able to model any arbitrary ordering; can create an ensemble of autoregressive models at the runtime. Connectivity-agnostic training: to avoid a model being tied up to a specific connectivity pattern constraints, resample $m^l_k$ for each training minibatch. PixelRNN PixelRNN (Oord et al, 2016) is a deep generative model for images. The image is generated one pixel at a time and each new pixel is sampled conditional on the pixels that have been seen before.\nLet\u0026rsquo;s consider an image of size $n \\times n$, $\\mathbf{x} = \\{x_1, \\dots, x_{n^2}\\}$, the model starts generating pixels from the top left corner, from left to right and top to bottom (See Fig. 6).\nFig. 6. The context for generating one pixel in PixelRNN. (Image source: Oord et al, 2016) Every pixel $x_i$ is sampled from a probability distribution conditional over the the past context: pixels above it or on the left of it when in the same row. The definition of such context looks pretty arbitrary, because how visual attention is attended to an image is more flexible. Somehow magically a generative model with such a strong assumption works.\nOne implementation that could capture the entire context is the Diagonal BiLSTM. First, apply the skewing operation by offsetting each row of the input feature map by one position with respect to the previous row, so that computation for each row can be parallelized. Then the LSTM states are computed with respect to the current pixel and the pixels on the left.\nFig. 7. (a) PixelRNN with diagonal BiLSTM. (b) Skewing operation that offsets each row in the feature map by one with regards to the row above. (Image source: Oord et al, 2016) $$ \\begin{aligned} \\lbrack \\mathbf{o}_i, \\mathbf{f}_i, \\mathbf{i}_i, \\mathbf{g}_i \\rbrack \u0026= \\sigma(\\mathbf{K}^{ss} \\circledast \\mathbf{h}_{i-1} + \\mathbf{K}^{is} \\circledast \\mathbf{x}_i) \u0026 \\scriptstyle{\\text{; }\\sigma\\scriptstyle{\\text{ is tanh for g, but otherwise sigmoid; }}\\circledast\\scriptstyle{\\text{ is convolution operation.}}} \\\\ \\mathbf{c}_i \u0026= \\mathbf{f}_i \\odot \\mathbf{c}_{i-1} + \\mathbf{i}_i \\odot \\mathbf{g}_i \u0026 \\scriptstyle{\\text{; }}\\odot\\scriptstyle{\\text{ is elementwise product.}}\\\\ \\mathbf{h}_i \u0026= \\mathbf{o}_i \\odot \\tanh(\\mathbf{c}_i) \\end{aligned} $$ where $\\circledast$ denotes the convolution operation and $\\odot$ is the element-wise multiplication. The input-to-state component $\\mathbf{K}^{is}$ is a 1x1 convolution, while the state-to-state recurrent component is computed with a column-wise convolution $\\mathbf{K}^{ss}$ with a kernel of size 2x1.\nThe diagonal BiLSTM layers are capable of processing an unbounded context field, but expensive to compute due to the sequential dependency between states. A faster implementation uses multiple convolutional layers without pooling to define a bounded context box. The convolution kernel is masked so that the future context is not seen, similar to MADE. This convolution version is called PixelCNN.\nFig. 8. PixelCNN with masked convolution constructed by an elementwise product of a mask tensor and the convolution kernel before applying it. (Image source: http://slazebni.cs.illinois.edu/spring17/lec13_advanced.pdf) WaveNet WaveNet (Van Den Oord, et al. 2016) is very similar to PixelCNN but applied to 1-D audio signals. WaveNet consists of a stack of causal convolution which is a convolution operation designed to respect the ordering: the prediction at a certain timestamp can only consume the data observed in the past, no dependency on the future. In PixelCNN, the causal convolution is implemented by masked convolution kernel. The causal convolution in WaveNet is simply to shift the output by a number of timestamps to the future so that the output is aligned with the last input element.\nOne big drawback of convolution layer is a very limited size of receptive field. The output can hardly depend on the input hundreds or thousands of timesteps ago, which can be a crucial requirement for modeling long sequences. WaveNet therefore adopts dilated convolution (animation), where the kernel is applied to an evenly-distributed subset of samples in a much larger receptive field of the input.\nFig. 9. Visualization of WaveNet models with a stack of (top) causal convolution layers and (bottom) dilated convolution layers. (Image source: Van Den Oord, et al. 2016) WaveNet uses the gated activation unit as the non-linear layer, as it is found to work significantly better than ReLU for modeling 1-D audio data. The residual connection is applied after the gated activation.\n $$ \\mathbf{z} = \\tanh(\\mathbf{W}_{f,k}\\circledast\\mathbf{x})\\odot\\sigma(\\mathbf{W}_{g,k}\\circledast\\mathbf{x}) $$ where $\\mathbf{W}_{f,k}$ and $\\mathbf{W}_{g,k}$ are convolution filter and gate weight matrix of the $k$-th layer, respectively; both are learnable.\nMasked Autoregressive Flow Masked Autoregressive Flow (MAF; Papamakarios et al., 2017) is a type of normalizing flows, where the transformation layer is built as an autoregressive neural network. MAF is very similar to Inverse Autoregressive Flow (IAF) introduced later. See more discussion on the relationship between MAF and IAF in the next section.\nGiven two random variables, $\\mathbf{z} \\sim \\pi(\\mathbf{z})$ and $\\mathbf{x} \\sim p(\\mathbf{x})$ and the probability density function $\\pi(\\mathbf{z})$ is known, MAF aims to learn $p(\\mathbf{x})$. MAF generates each $x_i$ conditioned on the past dimensions $\\mathbf{x}_{1:i-1}$.\nPrecisely the conditional probability is an affine transformation of $\\mathbf{z}$, where the scale and shift terms are functions of the observed part of $\\mathbf{x}$.\n Data generation, producing a new $\\mathbf{x}$: $x_i \\sim p(x_i\\vert\\mathbf{x}_{1:i-1}) = z_i \\odot \\sigma_i(\\mathbf{x}_{1:i-1}) + \\mu_i(\\mathbf{x}_{1:i-1})\\text{, where }\\mathbf{z} \\sim \\pi(\\mathbf{z})$\n Density estimation, given a known $\\mathbf{x}$: $p(\\mathbf{x}) = \\prod_{i=1}^D p(x_i\\vert\\mathbf{x}_{1:i-1})$\nThe generation procedure is sequential, so it is slow by design. While density estimation only needs one pass the network using architecture like MADE. The transformation function is trivial to inverse and the Jacobian determinant is easy to compute too.\nInverse Autoregressive Flow Similar to MAF, Inverse autoregressive flow (IAF; Kingma et al., 2016) models the conditional probability of the target variable as an autoregressive model too, but with a reversed flow, thus achieving a much efficient sampling process.\nFirst, let\u0026rsquo;s reverse the affine transformation in MAF:\n $$ z_i = \\frac{x_i - \\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} = -\\frac{\\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} + x_i \\odot \\frac{1}{\\sigma_i(\\mathbf{x}_{1:i-1})} $$ If let:\n $$ \\begin{aligned} \u0026 \\tilde{\\mathbf{x}} = \\mathbf{z}\\text{, }\\tilde{p}(.) = \\pi(.)\\text{, }\\tilde{\\mathbf{x}} \\sim \\tilde{p}(\\tilde{\\mathbf{x}}) \\\\ \u0026 \\tilde{\\mathbf{z}} = \\mathbf{x} \\text{, }\\tilde{\\pi}(.) = p(.)\\text{, }\\tilde{\\mathbf{z}} \\sim \\tilde{\\pi}(\\tilde{\\mathbf{z}})\\\\ \u0026 \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1}) = \\tilde{\\mu}_i(\\mathbf{x}_{1:i-1}) = -\\frac{\\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} \\\\ \u0026 \\tilde{\\sigma}(\\tilde{\\mathbf{z}}_{1:i-1}) = \\tilde{\\sigma}(\\mathbf{x}_{1:i-1}) = \\frac{1}{\\sigma_i(\\mathbf{x}_{1:i-1})} \\end{aligned} $$ Then we would have,\n $$ \\tilde{x}_i \\sim p(\\tilde{x}_i\\vert\\tilde{\\mathbf{z}}_{1:i}) = \\tilde{z}_i \\odot \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1}) + \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1}) \\text{, where }\\tilde{\\mathbf{z}} \\sim \\tilde{\\pi}(\\tilde{\\mathbf{z}}) $$ IAF intends to estimate the probability density function of $\\tilde{\\mathbf{x}}$ given that $\\tilde{\\pi}(\\tilde{\\mathbf{z}})$ is already known. The inverse flow is an autoregressive affine transformation too, same as in MAF, but the scale and shift terms are autoregressive functions of observed variables from the known distribution $\\tilde{\\pi}(\\tilde{\\mathbf{z}})$. See the comparison between MAF and IAF in Fig. 10.\nFig. 10. Comparison of MAF and IAF. The variable with known density is in green while the unknown one is in red. Computations of the individual elements $\\tilde{x}_i$ do not depend on each other, so they are easily parallelizable (only one pass using MADE). The density estimation for a known $\\tilde{\\mathbf{x}}$ is not efficient, because we have to recover the value of $\\tilde{z}_i$ in a sequential order, $\\tilde{z}_i = (\\tilde{x}_i - \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1})) / \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1})$, thus D times in total.\n Base distribution Target distribution Model Data generation Density estimation MAF $\\mathbf{z}\\sim\\pi(\\mathbf{z})$ $\\mathbf{x}\\sim p(\\mathbf{x})$ $x_i = z_i \\odot \\sigma_i(\\mathbf{x}_{1:i-1}) + \\mu_i(\\mathbf{x}_{1:i-1})$ Sequential; slow One pass; fast IAF $\\tilde{\\mathbf{z}}\\sim\\tilde{\\pi}(\\tilde{\\mathbf{z}})$ $\\tilde{\\mathbf{x}}\\sim\\tilde{p}(\\tilde{\\mathbf{x}})$ $\\tilde{x}_i = \\tilde{z}_i \\odot \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1}) + \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1})$ One pass; fast Sequential; slow \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- VAE + Flows In Variational Autoencoder, if we want to model the posterior $p(\\mathbf{z}\\vert\\mathbf{x})$ as a more complicated distribution rather than simple Gaussian. Intuitively we can use normalizing flow to transform the base Gaussian for better density approximation. The encoder then would predict a set of scale and shift terms $(\\mu_i, \\sigma_i)$ which are all functions of input $\\mathbf{x}$. Read the paper for more details if interested.\n If you notice mistakes and errors in this post, don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them right away!\nSee you in the next post :D\n Cited as:\n@article{weng2018flow, title = \u0026quot;Flow-based Deep Generative Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-10-13-flow-models/\u0026quot; } Reference [1] Danilo Jimenez Rezende, and Shakir Mohamed. \u0026ldquo;Variational inference with normalizing flows.\u0026quot; ICML 2015.\n[2] Normalizing Flows Tutorial, Part 1: Distributions and Determinants by Eric Jang.\n[3] Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows by Eric Jang.\n[4] Normalizing Flows by Adam Kosiorek.\n[5] Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. \u0026ldquo;Density estimation using Real NVP.\u0026quot; ICLR 2017.\n[6] Laurent Dinh, David Krueger, and Yoshua Bengio. \u0026ldquo;NICE: Non-linear independent components estimation.\u0026quot; ICLR 2015 Workshop track.\n[7] Diederik P. Kingma, and Prafulla Dhariwal. \u0026ldquo;Glow: Generative flow with invertible 1x1 convolutions.\u0026quot; arXiv:1807.03039 (2018).\n[8] Germain, Mathieu, Karol Gregor, Iain Murray, and Hugo Larochelle. \u0026ldquo;Made: Masked autoencoder for distribution estimation.\u0026quot; ICML 2015.\n[9] Aaron van den Oord, Nal Kalchbrenner, and Koray Kavukcuoglu. \u0026ldquo;Pixel recurrent neural networks.\u0026quot; ICML 2016.\n[10] Diederik P. Kingma, et al. \u0026ldquo;Improved variational inference with inverse autoregressive flow.\u0026quot; NIPS. 2016.\n[11] George Papamakarios, Iain Murray, and Theo Pavlakou. \u0026ldquo;Masked autoregressive flow for density estimation.\u0026quot; NIPS 2017.\n[12] Jianlin Su, and Guang Wu. \u0026ldquo;f-VAEs: Improve VAEs with Conditional Flows.\u0026quot; arXiv:1809.05861 (2018).\n[13] Van Den Oord, Aaron, et al. \u0026ldquo;WaveNet: A generative model for raw audio.\u0026quot; SSW. 2016.\n","permalink":"https://lilianweng.github.io/posts/2018-10-13-flow-models/","summary":"So far, I\u0026rsquo;ve written about two types of generative models, GAN and VAE. Neither of them explicitly learns the probability density function of real data, $p(\\mathbf{x})$ (where $\\mathbf{x} \\in \\mathcal{D}$) \u0026mdash; because it is really hard! Taking the generative model with latent variables as an example, $p(\\mathbf{x}) = \\int p(\\mathbf{x}\\vert\\mathbf{z})p(\\mathbf{z})d\\mathbf{z}$ can hardly be calculated as it is intractable to go through all possible values of the latent code $\\mathbf{z}$.\nFlow-based deep generative models conquer this hard problem with the help of normalizing flows, a powerful statistics tool for density estimation.","title":"Flow-based Deep Generative Models"},{"content":"[Updated on 2019-07-18: add a section on VQ-VAE \u0026amp; VQ-VAE-2.] [Updated on 2019-07-26: add a section on TD-VAE.] \nAutocoder is invented to reconstruct high-dimensional data using a neural network model with a narrow bottleneck layer in the middle (oops, this is probably not true for Variational Autoencoder, and we will investigate it in details in later sections). A nice byproduct is dimension reduction: the bottleneck layer captures a compressed latent encoding. Such a low-dimensional representation can be used as en embedding vector in various applications (i.e. search), help data compression, or reveal the underlying data generative factors.\nNotation Symbol Mean $\\mathcal{D}$ The dataset, $\\mathcal{D} = \\{ \\mathbf{x}^{(1)}, \\mathbf{x}^{(2)}, \\dots, \\mathbf{x}^{(n)} \\}$, contains $n$ data samples; $\\vert\\mathcal{D}\\vert =n $. $\\mathbf{x}^{(i)}$ Each data point is a vector of $d$ dimensions, $\\mathbf{x}^{(i)} = [x^{(i)}_1, x^{(i)}_2, \\dots, x^{(i)}_d]$. $\\mathbf{x}$ One data sample from the dataset, $\\mathbf{x} \\in \\mathcal{D}$. $\\mathbf{x}’$ The reconstructed version of $\\mathbf{x}$. $\\tilde{\\mathbf{x}}$ The corrupted version of $\\mathbf{x}$. $\\mathbf{z}$ The compressed code learned in the bottleneck layer. $a_j^{(l)}$ The activation function for the $j$-th neuron in the $l$-th hidden layer. $g_{\\phi}(.)$ The encoding function parameterized by $\\phi$. $f_{\\theta}(.)$ The decoding function parameterized by $\\theta$. $q_{\\phi}(\\mathbf{z}\\vert\\mathbf{x})$ Estimated posterior probability function, also known as probabilistic encoder. $p_{\\theta}(\\mathbf{x}\\vert\\mathbf{z})$ Likelihood of generating true data sample given the latent code, also known as probabilistic decoder. Autoencoder Autoencoder is a neural network designed to learn an identity function in an unsupervised way to reconstruct the original input while compressing the data in the process so as to discover a more efficient and compressed representation. The idea was originated in the 1980s, and later promoted by the seminal paper by Hinton \u0026amp; Salakhutdinov, 2006.\nIt consists of two networks:\n Encoder network: It translates the original high-dimension input into the latent low-dimensional code. The input size is larger than the output size. Decoder network: The decoder network recovers the data from the code, likely with larger and larger output layers. Fig. 1. Illustration of autoencoder model architecture. The encoder network essentially accomplishes the dimensionality reduction, just like how we would use Principal Component Analysis (PCA) or Matrix Factorization (MF) for. In addition, the autoencoder is explicitly optimized for the data reconstruction from the code. A good intermediate representation not only can capture latent variables, but also benefits a full decompression process.\nThe model contains an encoder function $g(.)$ parameterized by $\\phi$ and a decoder function $f(.)$ parameterized by $\\theta$. The low-dimensional code learned for input $\\mathbf{x}$ in the bottleneck layer is $\\mathbf{z} = g_\\phi(\\mathbf{x})$ and the reconstructed input is $\\mathbf{x}' = f_\\theta(g_\\phi(\\mathbf{x}))$.\nThe parameters $(\\theta, \\phi)$ are learned together to output a reconstructed data sample same as the original input, $\\mathbf{x} \\approx f_\\theta(g_\\phi(\\mathbf{x}))$, or in other words, to learn an identity function. There are various metrics to quantify the difference between two vectors, such as cross entropy when the activation function is sigmoid, or as simple as MSE loss:\n $$ L_\\text{AE}(\\theta, \\phi) = \\frac{1}{n}\\sum_{i=1}^n (\\mathbf{x}^{(i)} - f_\\theta(g_\\phi(\\mathbf{x}^{(i)})))^2 $$ Denoising Autoencoder Since the autoencoder learns the identity function, we are facing the risk of \u0026ldquo;overfitting\u0026rdquo; when there are more network parameters than the number of data points.\nTo avoid overfitting and improve the robustness, Denoising Autoencoder (Vincent et al. 2008) proposed a modification to the basic autoencoder. The input is partially corrupted by adding noises to or masking some values of the input vector in a stochastic manner, $\\tilde{\\mathbf{x}} \\sim \\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$. Then the model is trained to recover the original input (note: not the corrupt one).\n $$ \\begin{aligned} \\tilde{\\mathbf{x}}^{(i)} \u0026\\sim \\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}}^{(i)} \\vert \\mathbf{x}^{(i)})\\\\ L_\\text{DAE}(\\theta, \\phi) \u0026= \\frac{1}{n} \\sum_{i=1}^n (\\mathbf{x}^{(i)} - f_\\theta(g_\\phi(\\tilde{\\mathbf{x}}^{(i)})))^2 \\end{aligned} $$ where $\\mathcal{M}_\\mathcal{D}$ defines the mapping from the true data samples to the noisy or corrupted ones.\nFig. 2. Illustration of denoising autoencoder model architecture. This design is motivated by the fact that humans can easily recognize an object or a scene even the view is partially occluded or corrupted. To \u0026ldquo;repair\u0026rdquo; the partially destroyed input, the denoising autoencoder has to discover and capture relationship between dimensions of input in order to infer missing pieces.\nFor high dimensional input with high redundancy, like images, the model is likely to depend on evidence gathered from a combination of many input dimensions to recover the denoised version rather than to overfit one dimension. This builds up a good foundation for learning robust latent representation.\nThe noise is controlled by a stochastic mapping $\\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$, and it is not specific to a particular type of corruption process (i.e. masking noise, Gaussian noise, salt-and-pepper noise, etc.). Naturally the corruption process can be equipped with prior knowledge\nIn the experiment of the original DAE paper, the noise is applied in this way: a fixed proportion of input dimensions are selected at random and their values are forced to 0. Sounds a lot like dropout, right? Well, the denoising autoencoder was proposed in 2008, 4 years before the dropout paper (Hinton, et al. 2012) ;)\nFig. 3. Stacking denoising autoencoders. (Image source: Vincent et al., 2010) -- Sparse Autoencoder Sparse Autoencoder applies a \u0026ldquo;sparse\u0026rdquo; constraint on the hidden unit activation to avoid overfitting and improve robustness. It forces the model to only have a small number of hidden units being activated at the same time, or in other words, one hidden neuron should be inactivate most of time.\nRecall that common activation functions include sigmoid, tanh, relu, leaky relu, etc. A neuron is activated when the value is close to 1 and inactivate with a value close to 0.\nLet’s say there are $s_l$ neurons in the $l$-th hidden layer and the activation function for the $j$-th neuron in this layer is labelled as $a^{(l)}_j(.)$, $j=1, \\dots, s_l$. The fraction of activation of this neuron $\\hat{\\rho}_j$ is expected to be a small number $\\rho$, known as sparsity parameter; a common config is $\\rho = 0.05$.\n $$ \\hat{\\rho}_j^{(l)} = \\frac{1}{n} \\sum_{i=1}^n [a_j^{(l)}(\\mathbf{x}^{(i)})] \\approx \\rho $$ This constraint is achieved by adding a penalty term into the loss function. The KL-divergence $D_\\text{KL}$ measures the difference between two Bernoulli distributions, one with mean $\\rho$ and the other with mean $\\hat{\\rho}_j^{(l)}$. The hyperparameter $\\beta$ controls how strong the penalty we want to apply on the sparsity loss.\n $$ \\begin{aligned} L_\\text{SAE}(\\theta) \u0026= L(\\theta) + \\beta \\sum_{l=1}^L \\sum_{j=1}^{s_l} D_\\text{KL}(\\rho \\| \\hat{\\rho}_j^{(l)}) \\\\ \u0026= L(\\theta) + \\beta \\sum_{l=1}^L \\sum_{j=1}^{s_l} \\rho\\log\\frac{\\rho}{\\hat{\\rho}_j^{(l)}} + (1-\\rho)\\log\\frac{1-\\rho}{1-\\hat{\\rho}_j^{(l)}} \\end{aligned} $$ Fig. 4. The KL divergence between a Bernoulli distribution with mean $\\rho=0.25$ and a Bernoulli distribution with mean $0 \\leq \\hat{\\rho} \\leq 1$. $k$-Sparse Autoencoder\nIn $k$-Sparse Autoencoder (Makhzani and Frey, 2013), the sparsity is enforced by only keeping the top k highest activations in the bottleneck layer with linear activation function. First we run feedforward through the encoder network to get the compressed code: $\\mathbf{z} = g(\\mathbf{x})$. Sort the values in the code vector $\\mathbf{z}$. Only the k largest values are kept while other neurons are set to 0. This can be done in a ReLU layer with an adjustable threshold too. Now we have a sparsified code: $\\mathbf{z}’ = \\text{Sparsify}(\\mathbf{z})$. Compute the output and the loss from the sparsified code, $L = |\\mathbf{x} - f(\\mathbf{z}') |_2^2$. And, the back-propagation only goes through the top k activated hidden units!\nFig. 5. Filters of the k-sparse autoencoder for different sparsity levels k, learnt from MNIST with 1000 hidden units.. (Image source: Makhzani and Frey, 2013) Contractive Autoencoder Similar to sparse autoencoder, Contractive Autoencoder (Rifai, et al, 2011) encourages the learned representation to stay in a contractive space for better robustness.\nIt adds a term in the loss function to penalize the representation being too sensitive to the input, and thus improve the robustness to small perturbations around the training data points. The sensitivity is measured by the Frobenius norm of the Jacobian matrix of the encoder activations with respect to the input:\n $$ \\|J_f(\\mathbf{x})\\|_F^2 = \\sum_{ij} \\Big( \\frac{\\partial h_j(\\mathbf{x})}{\\partial x_i} \\Big)^2 $$ where $h_j$ is one unit output in the compressed code $\\mathbf{z} = f(x)$.\nThis penalty term is the sum of squares of all partial derivatives of the learned encoding with respect to input dimensions. The authors claimed that empirically this penalty was found to carve a representation that corresponds to a lower-dimensional non-linear manifold, while staying more invariant to majority directions orthogonal to the manifold.\nVAE: Variational Autoencoder The idea of Variational Autoencoder (Kingma \u0026amp; Welling, 2014), short for VAE, is actually less similar to all the autoencoder models above, but deeply rooted in the methods of variational bayesian and graphical model.\nInstead of mapping the input into a fixed vector, we want to map it into a distribution. Let’s label this distribution as $p_\\theta$, parameterized by $\\theta$. The relationship between the data input $\\mathbf{x}$ and the latent encoding vector $\\mathbf{z}$ can be fully defined by:\n Prior $p_\\theta(\\mathbf{z})$ Likelihood $p_\\theta(\\mathbf{x}\\vert\\mathbf{z})$ Posterior $p_\\theta(\\mathbf{z}\\vert\\mathbf{x})$ Assuming that we know the real parameter $\\theta^{*}$ for this distribution. In order to generate a sample that looks like a real data point $\\mathbf{x}^{(i)}$, we follow these steps:\n First, sample a $\\mathbf{z}^{(i)}$ from a prior distribution $p_{\\theta^*}(\\mathbf{z})$. Then a value $\\mathbf{x}^{(i)}$ is generated from a conditional distribution $p_{\\theta^*}(\\mathbf{x} \\vert \\mathbf{z} = \\mathbf{z}^{(i)})$. The optimal parameter $\\theta^{*}$ is the one that maximizes the probability of generating real data samples:\n $$ \\theta^{*} = \\arg\\max_\\theta \\prod_{i=1}^n p_\\theta(\\mathbf{x}^{(i)}) $$ Commonly we use the log probabilities to convert the product on RHS to a sum:\n $$ \\theta^{*} = \\arg\\max_\\theta \\sum_{i=1}^n \\log p_\\theta(\\mathbf{x}^{(i)}) $$ Now let’s update the equation to better demonstrate the data generation process so as to involve the encoding vector:\n $$ p_\\theta(\\mathbf{x}^{(i)}) = \\int p_\\theta(\\mathbf{x}^{(i)}\\vert\\mathbf{z}) p_\\theta(\\mathbf{z}) d\\mathbf{z} $$ Unfortunately it is not easy to compute $p_\\theta(\\mathbf{x}^{(i)})$ in this way, as it is very expensive to check all the possible values of $\\mathbf{z}$ and sum them up. To narrow down the value space to facilitate faster search, we would like to introduce a new approximation function to output what is a likely code given an input $\\mathbf{x}$, $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$, parameterized by $\\phi$.\nFig. 6. The graphical model involved in Variational Autoencoder. Solid lines denote the generative distribution $p\\_\\theta(.)$ and dashed lines denote the distribution $q\\_\\phi (\\mathbf{z}\\vert\\mathbf{x})$ to approximate the intractable posterior $p\\_\\theta (\\mathbf{z}\\vert\\mathbf{x})$. Now the structure looks a lot like an autoencoder:\n The conditional probability $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ defines a generative model, similar to the decoder $f_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ introduced above. $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ is also known as probabilistic decoder. The approximation function $q_\\phi(\\mathbf{z} \\vert \\mathbf{x})$ is the probabilistic encoder, playing a similar role as $g_\\phi(\\mathbf{z} \\vert \\mathbf{x})$ above. Loss Function: ELBO The estimated posterior $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$ should be very close to the real one $p_\\theta(\\mathbf{z}\\vert\\mathbf{x})$. We can use Kullback-Leibler divergence to quantify the distance between these two distributions. KL divergence $D_\\text{KL}(X|Y)$ measures how much information is lost if the distribution Y is used to represent X.\nIn our case we want to minimize $D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) | p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) )$ with respect to $\\phi$.\nBut why use $D_\\text{KL}(q_\\phi | p_\\theta)$ (reversed KL) instead of $D_\\text{KL}(p_\\theta | q_\\phi)$ (forward KL)? Eric Jang has a great explanation in his post on Bayesian Variational methods. As a quick recap:\nFig. 7. Forward and reversed KL divergence have different demands on how to match two distributions. (Image source: blog.evjang.com/2016/08/variational-bayes.html) Forward KL divergence: $D_\\text{KL}(P|Q) = \\mathbb{E}_{z\\sim P(z)} \\log\\frac{P(z)}{Q(z)}$; we have to ensure that Q(z)\u0026gt;0 wherever P(z)\u0026gt;0. The optimized variational distribution $q(z)$ has to cover over the entire $p(z)$. Reversed KL divergence: $D_\\text{KL}(Q|P) = \\mathbb{E}_{z\\sim Q(z)} \\log\\frac{Q(z)}{P(z)}$; minimizing the reversed KL divergence squeezes the $Q(z)$ under $P(z)$. Let\u0026rsquo;s now expand the equation:\n $$ \\begin{aligned} \u0026 D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) \u0026 \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z} \\vert \\mathbf{x})} d\\mathbf{z} \u0026 \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})p_\\theta(\\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }p(z \\vert x) = p(z, x) / p(x)} \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\big( \\log p_\\theta(\\mathbf{x}) + \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} \\big) d\\mathbf{z} \u0026 \\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x})\\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }\\int q(z \\vert x) dz = 1}\\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x})\\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{x}\\vert\\mathbf{z})p_\\theta(\\mathbf{z})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }p(z, x) = p(x \\vert z) p(z)} \\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}[\\log \\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z})} - \\log p_\\theta(\\mathbf{x} \\vert \\mathbf{z})] \u0026\\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) - \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) \u0026 \\end{aligned} $$ So we have:\n $$ D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) =\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) - \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) $$ Once rearrange the left and right hand side of the equation,\n $$ \\log p_\\theta(\\mathbf{x}) - D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) = \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) $$ The LHS of the equation is exactly what we want to maximize when learning the true distributions: we want to maximize the (log-)likelihood of generating real data (that is $\\log p_\\theta(\\mathbf{x})$) and also minimize the difference between the real and estimated posterior distributions (the term $D_\\text{KL}$ works like a regularizer). Note that $p_\\theta(\\mathbf{x})$ is fixed with respect to $q_\\phi$.\nThe negation of the above defines our loss function:\n $$ \\begin{aligned} L_\\text{VAE}(\\theta, \\phi) \u0026= -\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) )\\\\ \u0026= - \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) + D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}) ) \\\\ \\theta^{*}, \\phi^{*} \u0026= \\arg\\min_{\\theta, \\phi} L_\\text{VAE} \\end{aligned} $$ In Variational Bayesian methods, this loss function is known as the variational lower bound, or evidence lower bound. The \u0026ldquo;lower bound\u0026rdquo; part in the name comes from the fact that KL divergence is always non-negative and thus $-L_\\text{VAE}$ is the lower bound of $\\log p_\\theta (\\mathbf{x})$.\n $$ -L_\\text{VAE} = \\log p_\\theta(\\mathbf{x}) - D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) \\leq \\log p_\\theta(\\mathbf{x}) $$ Therefore by minimizing the loss, we are maximizing the lower bound of the probability of generating real data samples.\nReparameterization Trick The expectation term in the loss function invokes generating samples from $\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$. Sampling is a stochastic process and therefore we cannot backpropagate the gradient. To make it trainable, the reparameterization trick is introduced: It is often possible to express the random variable $\\mathbf{z}$ as a deterministic variable $\\mathbf{z} = \\mathcal{T}_\\phi(\\mathbf{x}, \\boldsymbol{\\epsilon})$, where $\\boldsymbol{\\epsilon}$ is an auxiliary independent random variable, and the transformation function $\\mathcal{T}_\\phi$ parameterized by $\\phi$ converts $\\boldsymbol{\\epsilon}$ to $\\mathbf{z}$.\nFor example, a common choice of the form of $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$ is a multivariate Gaussian with a diagonal covariance structure:\n $$ \\begin{aligned} \\mathbf{z} \u0026\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x}^{(i)}) = \\mathcal{N}(\\mathbf{z}; \\boldsymbol{\\mu}^{(i)}, \\boldsymbol{\\sigma}^{2(i)}\\boldsymbol{I}) \u0026 \\\\ \\mathbf{z} \u0026= \\boldsymbol{\\mu} + \\boldsymbol{\\sigma} \\odot \\boldsymbol{\\epsilon} \\text{, where } \\boldsymbol{\\epsilon} \\sim \\mathcal{N}(0, \\boldsymbol{I}) \u0026 \\scriptstyle{\\text{; Reparameterization trick.}} \\end{aligned} $$ where $\\odot$ refers to element-wise product.\nFig. 8. Illustration of how the reparameterization trick makes the $\\mathbf{z}$ sampling process trainable.(Image source: Slide 12 in Kingma’s NIPS 2015 workshop talk) The reparameterization trick works for other types of distributions too, not only Gaussian. In the multivariate Gaussian case, we make the model trainable by learning the mean and variance of the distribution, $\\mu$ and $\\sigma$, explicitly using the reparameterization trick, while the stochasticity remains in the random variable $\\boldsymbol{\\epsilon} \\sim \\mathcal{N}(0, \\boldsymbol{I})$.\nFig. 9. Illustration of variational autoencoder model with the multivariate Gaussian assumption. Beta-VAE If each variable in the inferred latent representation $\\mathbf{z}$ is only sensitive to one single generative factor and relatively invariant to other factors, we will say this representation is disentangled or factorized. One benefit that often comes with disentangled representation is good interpretability and easy generalization to a variety of tasks.\nFor example, a model trained on photos of human faces might capture the gentle, skin color, hair color, hair length, emotion, whether wearing a pair of glasses and many other relatively independent factors in separate dimensions. Such a disentangled representation is very beneficial to facial image generation.\nβ-VAE (Higgins et al., 2017) is a modification of Variational Autoencoder with a special emphasis to discover disentangled latent factors. Following the same incentive in VAE, we want to maximize the probability of generating real data, while keeping the distance between the real and estimated posterior distributions small (say, under a small constant $\\delta$):\n $$ \\begin{aligned} \u0026\\max_{\\phi, \\theta} \\mathbb{E}_{\\mathbf{x}\\sim\\mathcal{D}}[\\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z})]\\\\ \u0026\\text{subject to } D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) We can rewrite it as a Lagrangian with a Lagrangian multiplier $\\beta$ under the KKT condition. The above optimization problem with only one inequality constraint is equivalent to maximizing the following equation $\\mathcal{F}(\\theta, \\phi, \\beta)$:\n $$ \\begin{aligned} \\mathcal{F}(\\theta, \\phi, \\beta) \u0026= \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta(D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) - \\delta) \u0026 \\\\ \u0026 = \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) + \\beta \\delta \u0026 \\\\ \u0026 \\geq \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) \u0026 \\scriptstyle{\\text{; Because }\\beta,\\delta\\geq 0} \\end{aligned} $$ The loss function of $\\beta$-VAE is defined as:\n $$ L_\\text{BETA}(\\phi, \\beta) = - \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) + \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) $$ where the Lagrangian multiplier $\\beta$ is considered as a hyperparameter.\nSince the negation of $L_\\text{BETA}(\\phi, \\beta)$ is the lower bound of the Lagrangian $\\mathcal{F}(\\theta, \\phi, \\beta)$. Minimizing the loss is equivalent to maximizing the Lagrangian and thus works for our initial optimization problem.\nWhen $\\beta=1$, it is same as VAE. When $\\beta \u0026gt; 1$, it applies a stronger constraint on the latent bottleneck and limits the representation capacity of $\\mathbf{z}$. For some conditionally independent generative factors, keeping them disentangled is the most efficient representation. Therefore a higher $\\beta$ encourages more efficient latent encoding and further encourages the disentanglement. Meanwhile, a higher $\\beta$ may create a trade-off between reconstruction quality and the extent of disentanglement.\nBurgess, et al. (2017) discussed the distentangling in $\\beta$-VAE in depth with an inspiration by the information bottleneck theory and further proposed a modification to $\\beta$-VAE to better control the encoding representation capacity.\nVQ-VAE and VQ-VAE-2 The VQ-VAE (“Vector Quantised-Variational AutoEncoder”; van den Oord, et al. 2017) model learns a discrete latent variable by the encoder, since discrete representations may be a more natural fit for problems like language, speech, reasoning, etc.\nVector quantisation (VQ) is a method to map $K$-dimensional vectors into a finite set of “code” vectors. The process is very much similar to KNN algorithm. The optimal centroid code vector that a sample should be mapped to is the one with minimum euclidean distance.\nLet $\\mathbf{e} \\in \\mathbb{R}^{K \\times D}, i=1, \\dots, K$ be the latent embedding space (also known as \u0026ldquo;codebook\u0026rdquo;) in VQ-VAE, where $K$ is the number of latent variable categories and $D$ is the embedding size. An individual embedding vector is $\\mathbf{e}_i \\in \\mathbb{R}^{D}, i=1, \\dots, K$.\nThe encoder output $E(\\mathbf{x}) = \\mathbf{z}_e$ goes through a nearest-neighbor lookup to match to one of $K$ embedding vectors and then this matched code vector becomes the input for the decoder $D(.)$:\n $$ \\mathbf{z}_q(\\mathbf{x}) = \\text{Quantize}(E(\\mathbf{x})) = \\mathbf{e}_k \\text{ where } k = \\arg\\min_i \\|E(\\mathbf{x}) - \\mathbf{e}_i \\|_2 $$ Note that the discrete latent variables can have different shapes in differnet applications; for example, 1D for speech, 2D for image and 3D for video.\nFig. 10. The architecture of VQ-VAE (Image source: van den Oord, et al. 2017) Because argmin() is non-differentiable on a discrete space, the gradients $\\nabla_z L$ from decoder input $\\mathbf{z}_q$ is copied to the encoder output $\\mathbf{z}_e$. Other than reconstruction loss, VQ-VAE also optimizes:\n VQ loss: The L2 error between the embedding space and the encoder outputs. Commitment loss: A measure to encourage the encoder output to stay close to the embedding space and to prevent it from fluctuating too frequently from one code vector to another. $$ L = \\underbrace{\\|\\mathbf{x} - D(\\mathbf{e}_k)\\|_2^2}_{\\textrm{reconstruction loss}} + \\underbrace{\\|\\text{sg}[E(\\mathbf{x})] - \\mathbf{e}_k\\|_2^2}_{\\textrm{VQ loss}} + \\underbrace{\\beta \\|E(\\mathbf{x}) - \\text{sg}[\\mathbf{e}_k]\\|_2^2}_{\\textrm{commitment loss}} $$ where $\\text{sq}[.]$ is the stop_gradient operator.\nThe embedding vectors in the codebook is updated through EMA (exponential moving average). Given a code vector $\\mathbf{e}_i$, say we have $n_i$ encoder output vectors, $\\{\\mathbf{z}_{i,j}\\}_{j=1}^{n_i}$, that are quantized to $\\mathbf{e}_i$:\n $$ N_i^{(t)} = \\gamma N_i^{(t-1)} + (1-\\gamma)n_i^{(t)}\\;\\;\\; \\mathbf{m}_i^{(t)} = \\gamma \\mathbf{m}_i^{(t-1)} + (1-\\gamma)\\sum_{j=1}^{n_i^{(t)}}\\mathbf{z}_{i,j}^{(t)}\\;\\;\\; \\mathbf{e}_i^{(t)} = \\mathbf{m}_i^{(t)} / N_i^{(t)} $$ where $(t)$ refers to batch sequence in time. $N_i$ and $\\mathbf{m}_i$ are accumulated vector count and volume, respectively.\nVQ-VAE-2 (Ali Razavi, et al. 2019) is a two-level hierarchical VQ-VAE combined with self-attention autoregressive model.\n Stage 1 is to train a hierarchical VQ-VAE: The design of hierarchical latent variables intends to separate local patterns (i.e., texture) from global information (i.e., object shapes). The training of the larger bottom level codebook is conditioned on the smaller top level code too, so that it does not have to learn everything from scratch. Stage 2 is to learn a prior over the latent discrete codebook so that we sample from it and generate images. In this way, the decoder can receive input vectors sampled from a similar distribution as the one in training. A powerful autoregressive model enhanced with multi-headed self-attention layers is used to capture the prior distribution (like PixelSNAIL; Chen et al 2017). Considering that VQ-VAE-2 depends on discrete latent variables configured in a simple hierarchical setting, the quality of its generated images are pretty amazing.\nFig. 11. Architecture of hierarchical VQ-VAE and multi-stage image generation. (Image source: Ali Razavi, et al. 2019) Fig. 12. The VQ-VAE-2 algorithm. (Image source: Ali Razavi, et al. 2019) TD-VAE TD-VAE (“Temporal Difference VAE”; Gregor et al., 2019) works with sequential data. It relies on three main ideas, described below.\nFig. 13. State-space model as a Markov Chain model. 1. State-Space Models In (latent) state-space models, a sequence of unobserved hidden states $\\mathbf{z} = (z_1, \\dots, z_T)$ determine the observation states $\\mathbf{x} = (x_1, \\dots, x_T)$. Each time step in the Markov chain model in Fig. 13 can be trained in a similar manner as in Fig. 6, where the intractable posterior $p(z \\vert x)$ is approximated by a function $q(z \\vert x)$.\n2. Belief State An agent should learn to encode all the past states to reason about the future, named as belief state, $b_t = belief(x_1, \\dots, x_t) = belief(b_{t-1}, x_t)$. Given this, the distribution of future states conditioned on the past can be written as $p(x_{t+1}, \\dots, x_T \\vert x_1, \\dots, x_t) \\approx p(x_{t+1}, \\dots, x_T \\vert b_t)$. The hidden states in a recurrent policy are used as the agent\u0026rsquo;s belief state in TD-VAE. Thus we have $b_t = \\text{RNN}(b_{t-1}, x_t)$.\n3. Jumpy Prediction Further, an agent is expected to imagine distant futures based on all the information gathered so far, suggesting the capability of making jumpy predictions, that is, predicting states several steps further into the future.\nRecall what we have learned from the variance lower bound above:\n $$ \\begin{aligned} \\log p(x) \u0026\\geq \\log p(x) - D_\\text{KL}(q(z|x)\\|p(z|x)) \\\\ \u0026= \\mathbb{E}_{z\\sim q} \\log p(x|z) - D_\\text{KL}(q(z|x)\\|p(z)) \\\\ \u0026= \\mathbb{E}_{z \\sim q} \\log p(x|z) - \\mathbb{E}_{z \\sim q} \\log \\frac{q(z|x)}{p(z)} \\\\ \u0026= \\mathbb{E}_{z \\sim q}[\\log p(x|z) -\\log q(z|x) + \\log p(z)] \\\\ \u0026= \\mathbb{E}_{z \\sim q}[\\log p(x, z) -\\log q(z|x)] \\\\ \\log p(x) \u0026\\geq \\mathbb{E}_{z \\sim q}[\\log p(x, z) -\\log q(z|x)] \\end{aligned} $$ Now let\u0026rsquo;s model the distribution of the state $x_t$ as a probability function conditioned on all the past states $x_{\u0026lt;t}$ and two latent variables, $z_t$ and $z_{t-1}$, at current time step and one step back:\n $$ \\log p(x_t|x_{Continue expanding the equation:\n $$ \\begin{aligned} \u0026 \\log p(x_t|x_{Notice two things:\n The red terms can be ignored according to Markov assumptions. The blue term is expanded according to Markov assumptions. The green term is expanded to include an one-step prediction back to the past as a smoothing distribution. Precisely, there are four types of distributions to learn:\n $p_D(.)$ is the decoder distribution: $p(x_t \\mid z_t)$ is the encoder by the common definition; $p(x_t \\mid z_t) \\to p_D(x_t \\mid z_t)$; $p_T(.)$ is the transition distribution: $p(z_t \\mid z_{t-1})$ captures the sequential dependency between latent variables; $p(z_t \\mid z_{t-1}) \\to p_T(z_t \\mid z_{t-1})$; $p_B(.)$ is the belief distribution: Both $p(z_{t-1} \\mid x_{\u0026lt;t})$ and $q(z_t \\mid x_{\\leq t})$ can use the belief states to predict the latent variables; $p(z_{t-1} \\mid x_{\u0026lt;t}) \\to p_B(z_{t-1} \\mid b_{t-1})$; $q(z_{t} \\mid x_{\\leq t}) \\to p_B(z_t \\mid b_t)$; $p_S(.)$ is the smoothing distribution: The back-to-past smoothing term $q(z_{t-1} \\mid z_t, x_{\\leq t})$ can be rewritten to be dependent of belief states too; $q(z_{t-1} \\mid z_t, x_{\\leq t}) \\to p_S(z_{t-1} \\mid z_t, b_{t-1}, b_t)$; To incorporate the idea of jumpy prediction, the sequential ELBO has to not only work on $t, t+1$, but also two distant timestamp $t_1 \u0026lt; t_2$. Here is the final TD-VAE objective function to maximize:\n $$ J_{t_1, t_2} = \\mathbb{E}[ \\log p_D(x_{t_2}|z_{t_2}) + \\log p_B(z_{t_1}|b_{t_1}) + \\log p_T(z_{t_2}|z_{t_1}) - \\log p_B(z_{t_2}|b_{t_2}) - \\log p_S(z_{t_1}|z_{t_2}, b_{t_1}, b_{t_2})] $$ Fig. 14. A detailed overview of TD-VAE architecture, very nicely done. (Image source: TD-VAE paper) Cited as:\n@article{weng2018VAE, title = \u0026quot;From Autoencoder to Beta-VAE\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-08-12-vae/\u0026quot; } References [1] Geoffrey E. Hinton, and Ruslan R. Salakhutdinov. \u0026ldquo;Reducing the dimensionality of data with neural networks.\u0026quot; Science 313.5786 (2006): 504-507.\n[2] Pascal Vincent, et al. \u0026ldquo;Extracting and composing robust features with denoising autoencoders.\u0026quot; ICML, 2008.\n[3] Pascal Vincent, et al. \u0026ldquo;Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion.\u0026quot;. Journal of machine learning research 11.Dec (2010): 3371-3408.\n[4] Geoffrey E. Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R. Salakhutdinov. \u0026ldquo;Improving neural networks by preventing co-adaptation of feature detectors.\u0026rdquo; arXiv preprint arXiv:1207.0580 (2012).\n[5] Sparse Autoencoder by Andrew Ng.\n[6] Alireza Makhzani, Brendan Frey (2013). \u0026ldquo;k-sparse autoencoder\u0026rdquo;. ICLR 2014.\n[7] Salah Rifai, et al. \u0026ldquo;Contractive auto-encoders: Explicit invariance during feature extraction.\u0026quot; ICML, 2011.\n[8] Diederik P. Kingma, and Max Welling. \u0026ldquo;Auto-encoding variational bayes.\u0026quot; ICLR 2014.\n[9] Tutorial - What is a variational autoencoder? on jaan.io\n[10] Youtube tutorial: Variational Autoencoders by Arxiv Insights\n[11] \u0026ldquo;A Beginner\u0026rsquo;s Guide to Variational Methods: Mean-Field Approximation\u0026rdquo; by Eric Jang.\n[12] Carl Doersch. \u0026ldquo;Tutorial on variational autoencoders.\u0026quot; arXiv:1606.05908, 2016.\n[13] Irina Higgins, et al. \u0026quot;$\\beta$-VAE: Learning basic visual concepts with a constrained variational framework.\u0026quot; ICLR 2017.\n[14] Christopher P. Burgess, et al. \u0026ldquo;Understanding disentangling in beta-VAE.\u0026quot; NIPS 2017.\n[15] Aaron van den Oord, et al. \u0026ldquo;Neural Discrete Representation Learning\u0026rdquo; NIPS 2017.\n[16] Ali Razavi, et al. \u0026ldquo;Generating Diverse High-Fidelity Images with VQ-VAE-2\u0026rdquo;. arXiv preprint arXiv:1906.00446 (2019).\n[17] Xi Chen, et al. \u0026ldquo;PixelSNAIL: An Improved Autoregressive Generative Model.\u0026quot; arXiv preprint arXiv:1712.09763 (2017).\n[18] Karol Gregor, et al. \u0026ldquo;Temporal Difference Variational Auto-Encoder.\u0026quot; ICLR 2019.\n","permalink":"https://lilianweng.github.io/posts/2018-08-12-vae/","summary":"[Updated on 2019-07-18: add a section on VQ-VAE \u0026amp; VQ-VAE-2.] [Updated on 2019-07-26: add a section on TD-VAE.] \nAutocoder is invented to reconstruct high-dimensional data using a neural network model with a narrow bottleneck layer in the middle (oops, this is probably not true for Variational Autoencoder, and we will investigate it in details in later sections). A nice byproduct is dimension reduction: the bottleneck layer captures a compressed latent encoding.","title":"From Autoencoder to Beta-VAE"},{"content":"[Updated on 2018-10-28: Add Pointer Network and the link to my implementation of Transformer.] [Updated on 2018-11-06: Add a link to the implementation of Transformer model.] [Updated on 2018-11-18: Add Neural Turing Machines.] [Updated on 2019-07-18: Correct the mistake on using the term \u0026ldquo;self-attention\u0026rdquo; when introducing the show-attention-tell paper; moved it to Self-Attention section.] [Updated on 2020-04-07: A follow-up post on improved Transformer models is here.]\nAttention is, to some extent, motivated by how we pay visual attention to different regions of an image or correlate words in one sentence. Take the picture of a Shiba Inu in Fig. 1 as an example.\nFig. 1. A Shiba Inu in a men’s outfit. The credit of the original photo goes to Instagram @mensweardog. Human visual attention allows us to focus on a certain region with \u0026ldquo;high resolution\u0026rdquo; (i.e. look at the pointy ear in the yellow box) while perceiving the surrounding image in \u0026ldquo;low resolution\u0026rdquo; (i.e. now how about the snowy background and the outfit?), and then adjust the focal point or do the inference accordingly. Given a small patch of an image, pixels in the rest provide clues what should be displayed there. We expect to see a pointy ear in the yellow box because we have seen a dog’s nose, another pointy ear on the right, and Shiba\u0026rsquo;s mystery eyes (stuff in the red boxes). However, the sweater and blanket at the bottom would not be as helpful as those doggy features.\nSimilarly, we can explain the relationship between words in one sentence or close context. When we see \u0026ldquo;eating\u0026rdquo;, we expect to encounter a food word very soon. The color term describes the food, but probably not so much with \u0026ldquo;eating\u0026rdquo; directly.\nFig. 2. One word \"attends\" to other words in the same sentence differently. In a nutshell, attention in deep learning can be broadly interpreted as a vector of importance weights: in order to predict or infer one element, such as a pixel in an image or a word in a sentence, we estimate using the attention vector how strongly it is correlated with (or \u0026ldquo;attends to\u0026rdquo; as you may have read in many papers) other elements and take the sum of their values weighted by the attention vector as the approximation of the target.\nWhat’s Wrong with Seq2Seq Model? The seq2seq model was born in the field of language modeling (Sutskever, et al. 2014). Broadly speaking, it aims to transform an input sequence (source) to a new one (target) and both sequences can be of arbitrary lengths. Examples of transformation tasks include machine translation between multiple languages in either text or audio, question-answer dialog generation, or even parsing sentences into grammar trees.\nThe seq2seq model normally has an encoder-decoder architecture, composed of:\n An encoder processes the input sequence and compresses the information into a context vector (also known as sentence embedding or \u0026ldquo;thought\u0026rdquo; vector) of a fixed length. This representation is expected to be a good summary of the meaning of the whole source sequence. A decoder is initialized with the context vector to emit the transformed output. The early work only used the last state of the encoder network as the decoder initial state. Both the encoder and decoder are recurrent neural networks, i.e. using LSTM or GRU units.\nFig. 3. The encoder-decoder model, translating the sentence \"she is eating a green apple\" to Chinese. The visualization of both encoder and decoder is unrolled in time. A critical and apparent disadvantage of this fixed-length context vector design is incapability of remembering long sentences. Often it has forgotten the first part once it completes processing the whole input. The attention mechanism was born (Bahdanau et al., 2015) to resolve this problem.\nBorn for Translation The attention mechanism was born to help memorize long source sentences in neural machine translation (NMT). Rather than building a single context vector out of the encoder\u0026rsquo;s last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.\nWhile the context vector has access to the entire input sequence, we don’t need to worry about forgetting. The alignment between the source and target is learned and controlled by the context vector. Essentially the context vector consumes three pieces of information:\n encoder hidden states; decoder hidden states; alignment between source and target. Fig. 4. The encoder-decoder model with additive attention mechanism in Bahdanau et al., 2015. Definition Now let’s define the attention mechanism introduced in NMT in a scientific way. Say, we have a source sequence $\\mathbf{x}$ of length $n$ and try to output a target sequence $\\mathbf{y}$ of length $m$:\n $$ \\begin{aligned} \\mathbf{x} \u0026= [x_1, x_2, \\dots, x_n] \\\\ \\mathbf{y} \u0026= [y_1, y_2, \\dots, y_m] \\end{aligned} $$ (Variables in bold indicate that they are vectors; same for everything else in this post.)\nThe encoder is a bidirectional RNN (or other recurrent network setting of your choice) with a forward hidden state $\\overrightarrow{\\boldsymbol{h}}_i$ and a backward one $\\overleftarrow{\\boldsymbol{h}}_i$. A simple concatenation of two represents the encoder state. The motivation is to include both the preceding and following words in the annotation of one word.\n $$ \\boldsymbol{h}_i = [\\overrightarrow{\\boldsymbol{h}}_i^\\top; \\overleftarrow{\\boldsymbol{h}}_i^\\top]^\\top, i=1,\\dots,n $$ The decoder network has hidden state $\\boldsymbol{s}_t=f(\\boldsymbol{s}_{t-1}, y_{t-1}, \\mathbf{c}_t)$ for the output word at position t, $t=1,\\dots,m$, where the context vector $\\mathbf{c}_t$ is a sum of hidden states of the input sequence, weighted by alignment scores:\n $$ \\begin{aligned} \\mathbf{c}_t \u0026= \\sum_{i=1}^n \\alpha_{t,i} \\boldsymbol{h}_i \u0026 \\small{\\text{; Context vector for output }y_t}\\\\ \\alpha_{t,i} \u0026= \\text{align}(y_t, x_i) \u0026 \\small{\\text{; How well two words }y_t\\text{ and }x_i\\text{ are aligned.}}\\\\ \u0026= \\frac{\\exp(\\text{score}(\\boldsymbol{s}_{t-1}, \\boldsymbol{h}_i))}{\\sum_{i'=1}^n \\exp(\\text{score}(\\boldsymbol{s}_{t-1}, \\boldsymbol{h}_{i'}))} \u0026 \\small{\\text{; Softmax of some predefined alignment score.}}. \\end{aligned} $$ The alignment model assigns a score $\\alpha_{t,i}$ to the pair of input at position i and output at position t, $(y_t, x_i)$, based on how well they match. The set of $\\{\\alpha_{t, i}\\}$ are weights defining how much of each source hidden state should be considered for each output. In Bahdanau\u0026rsquo;s paper, the alignment score $\\alpha$ is parametrized by a feed-forward network with a single hidden layer and this network is jointly trained with other parts of the model. The score function is therefore in the following form, given that tanh is used as the non-linear activation function:\n $$ \\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_t; \\boldsymbol{h}_i]) $$ where both $\\mathbf{v}_a$ and $\\mathbf{W}_a$ are weight matrices to be learned in the alignment model.\nThe matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words.\nFig. 5. Alignment matrix of \"L'accord sur l'Espace économique européen a été signé en août 1992\" (French) and its English translation \"The agreement on the European Economic Area was signed in August 1992\". (Image source: Fig 3 in Bahdanau et al., 2015) Check out this nice tutorial by Tensorflow team for more implementation instructions.\nA Family of Attention Mechanisms With the help of the attention, the dependencies between source and target sequences are not restricted by the in-between distance anymore! Given the big improvement by attention in machine translation, it soon got extended into the computer vision field (Xu et al. 2015) and people started exploring various other forms of attention mechanisms (Luong, et al., 2015; Britz et al., 2017; Vaswani, et al., 2017).\nSummary Below is a summary table of several popular attention mechanisms and corresponding alignment score functions:\n Name Alignment score function Citation Content-base attention $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\text{cosine}[\\boldsymbol{s}_t, \\boldsymbol{h}_i]$ Graves2014 Additive(*) $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_{t-1}; \\boldsymbol{h}_i])$ Bahdanau2015 Location-Base $\\alpha_{t,i} = \\text{softmax}(\\mathbf{W}_a \\boldsymbol{s}_t)$Note: This simplifies the softmax alignment to only depend on the target position. Luong2015 General $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\boldsymbol{s}_t^\\top\\mathbf{W}_a\\boldsymbol{h}_i$where $\\mathbf{W}_a$ is a trainable weight matrix in the attention layer. Luong2015 Dot-Product $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\boldsymbol{s}_t^\\top\\boldsymbol{h}_i$ Luong2015 Scaled Dot-Product(^) $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\frac{\\boldsymbol{s}_t^\\top\\boldsymbol{h}_i}{\\sqrt{n}}$Note: very similar to the dot-product attention except for a scaling factor; where n is the dimension of the source hidden state. Vaswani2017 (*) Referred to as \u0026ldquo;concat\u0026rdquo; in Luong, et al., 2015 and as \u0026ldquo;additive attention\u0026rdquo; in Vaswani, et al., 2017. (^) It adds a scaling factor $1/\\sqrt{n}$, motivated by the concern when the input is large, the softmax function may have an extremely small gradient, hard for efficient learning.\nHere are a summary of broader categories of attention mechanisms:\n Name Definition Citation Self-Attention(\u0026amp;) Relating different positions of the same input sequence. Theoretically the self-attention can adopt any score functions above, but just replace the target sequence with the same input sequence. Cheng2016 Global/Soft Attending to the entire input state space. Xu2015 Local/Hard Attending to the part of input state space; i.e. a patch of the input image. Xu2015; Luong2015 (\u0026amp;) Also, referred to as \u0026ldquo;intra-attention\u0026rdquo; in Cheng et al., 2016 and some other papers.\nSelf-Attention Self-attention, also known as intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the same sequence. It has been shown to be very useful in machine reading, abstractive summarization, or image description generation.\nThe long short-term memory network paper used self-attention to do machine reading. In the example below, the self-attention mechanism enables us to learn the correlation between the current words and the previous part of the sentence.\nFig. 6. The current word is in red and the size of the blue shade indicates the activation level. (Image source: Cheng et al., 2016) Soft vs Hard Attention In the show, attend and tell paper, attention mechanism is applied to images to generate captions. The image is first encoded by a CNN to extract features. Then a LSTM decoder consumes the convolution features to produce descriptive words one by one, where the weights are learned through attention. The visualization of the attention weights clearly demonstrates which regions of the image the model is paying attention to so as to output a certain word.\nFig. 7. \"A woman is throwing a frisbee in a park.\" (Image source: Fig. 6(b) in Xu et al. 2015) This paper first proposed the distinction between \u0026ldquo;soft\u0026rdquo; vs \u0026ldquo;hard\u0026rdquo; attention, based on whether the attention has access to the entire image or only a patch:\n Soft Attention: the alignment weights are learned and placed \u0026ldquo;softly\u0026rdquo; over all patches in the source image; essentially the same type of attention as in Bahdanau et al., 2015. Pro: the model is smooth and differentiable. Con: expensive when the source input is large. Hard Attention: only selects one patch of the image to attend to at a time. Pro: less calculation at the inference time. Con: the model is non-differentiable and requires more complicated techniques such as variance reduction or reinforcement learning to train. (Luong, et al., 2015) Global vs Local Attention Luong, et al., 2015 proposed the \u0026ldquo;global\u0026rdquo; and \u0026ldquo;local\u0026rdquo; attention. The global attention is similar to the soft attention, while the local one is an interesting blend between hard and soft, an improvement over the hard attention to make it differentiable: the model first predicts a single aligned position for the current target word and a window centered around the source position is then used to compute a context vector.\nFig. 8. Global vs local attention (Image source: Fig 2 \u0026 3 in Luong, et al., 2015) Neural Turing Machines Alan Turing in 1936 proposed a minimalistic model of computation. It is composed of a infinitely long tape and a head to interact with the tape. The tape has countless cells on it, each filled with a symbol: 0, 1 or blank (\u0026quot; \u0026ldquo;). The operation head can read symbols, edit symbols and move left/right on the tape. Theoretically a Turing machine can simulate any computer algorithm, irrespective of how complex or expensive the procedure might be. The infinite memory gives a Turing machine an edge to be mathematically limitless. However, infinite memory is not feasible in real modern computers and then we only consider Turing machine as a mathematical model of computation.\nFig. 9. How a Turing machine looks like: a tape + a head that handles the tape. (Image source: http://aturingmachine.com/) Neural Turing Machine (NTM, Graves, Wayne \u0026amp; Danihelka, 2014) is a model architecture for coupling a neural network with external memory storage. The memory mimics the Turing machine tape and the neural network controls the operation heads to read from or write to the tape. However, the memory in NTM is finite, and thus it probably looks more like a “Neural von Neumann Machine”.\nNTM contains two major components, a controller neural network and a memory bank. Controller: is in charge of executing operations on the memory. It can be any type of neural network, feed-forward or recurrent. Memory: stores processed information. It is a matrix of size $N \\times M$, containing N vector rows and each has $M$ dimensions.\nIn one update iteration, the controller processes the input and interacts with the memory bank accordingly to generate output. The interaction is handled by a set of parallel read and write heads. Both read and write operations are “blurry” by softly attending to all the memory addresses.\nFig 10. Neural Turing Machine Architecture. Reading and Writing When reading from the memory at time t, an attention vector of size $N$, $\\mathbf{w}_t$ controls how much attention to assign to different memory locations (matrix rows). The read vector $\\mathbf{r}_t$ is a sum weighted by attention intensity:\n $$ \\mathbf{r}_t = \\sum_{i=1}^N w_t(i)\\mathbf{M}_t(i)\\text{, where }\\sum_{i=1}^N w_t(i)=1, \\forall i: 0 \\leq w_t(i) \\leq 1 $$ where $w_t(i)$ is the $i$-th element in $\\mathbf{w}_t$ and $\\mathbf{M}_t(i)$ is the $i$-th row vector in the memory.\nWhen writing into the memory at time t, as inspired by the input and forget gates in LSTM, a write head first wipes off some old content according to an erase vector $\\mathbf{e}_t$ and then adds new information by an add vector $\\mathbf{a}_t$.\n $$ \\begin{aligned} \\tilde{\\mathbf{M}}_t(i) \u0026= \\mathbf{M}_{t-1}(i) [\\mathbf{1} - w_t(i)\\mathbf{e}_t] \u0026\\scriptstyle{\\text{; erase}}\\\\ \\mathbf{M}_t(i) \u0026= \\tilde{\\mathbf{M}}_t(i) + w_t(i) \\mathbf{a}_t \u0026\\scriptstyle{\\text{; add}} \\end{aligned} $$ Attention Mechanisms In Neural Turing Machine, how to generate the attention distribution $\\mathbf{w}_t$ depends on the addressing mechanisms: NTM uses a mixture of content-based and location-based addressings.\nContent-based addressing\nThe content-addressing creates attention vectors based on the similarity between the key vector $\\mathbf{k}_t$ extracted by the controller from the input and memory rows. The content-based attention scores are computed as cosine similarity and then normalized by softmax. In addition, NTM adds a strength multiplier $\\beta_t$ to amplify or attenuate the focus of the distribution.\n $$ w_t^c(i) = \\text{softmax}(\\beta_t \\cdot \\text{cosine}[\\mathbf{k}_t, \\mathbf{M}_t(i)]) = \\frac{\\exp(\\beta_t \\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(i)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(i)\\|})}{\\sum_{j=1}^N \\exp(\\beta_t \\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(j)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(j)\\|})} $$ Interpolation\nThen an interpolation gate scalar $g_t$ is used to blend the newly generated content-based attention vector with the attention weights in the last time step:\n $$ \\mathbf{w}_t^g = g_t \\mathbf{w}_t^c + (1 - g_t) \\mathbf{w}_{t-1} $$ Location-based addressing\nThe location-based addressing sums up the values at different positions in the attention vector, weighted by a weighting distribution over allowable integer shifts. It is equivalent to a 1-d convolution with a kernel $\\mathbf{s}_t(.)$, a function of the position offset. There are multiple ways to define this distribution. See Fig. 11. for inspiration.\nFig. 11. Two ways to represent the shift weighting distribution $\\mathbf{s}\\_t$. Finally the attention distribution is enhanced by a sharpening scalar $\\gamma_t \\geq 1$.\n $$ \\begin{aligned} \\tilde{w}_t(i) \u0026= \\sum_{j=1}^N w_t^g(j) s_t(i-j) \u0026 \\scriptstyle{\\text{; circular convolution}}\\\\ w_t(i) \u0026= \\frac{\\tilde{w}_t(i)^{\\gamma_t}}{\\sum_{j=1}^N \\tilde{w}_t(j)^{\\gamma_t}} \u0026 \\scriptstyle{\\text{; sharpen}} \\end{aligned} $$ The complete process of generating the attention vector $\\mathbf{w}_t$ at time step t is illustrated in Fig. 12. All the parameters produced by the controller are unique for each head. If there are multiple read and write heads in parallel, the controller would output multiple sets.\nFig. 12. Flow diagram of the addressing mechanisms in Neural Turing Machine. (Image source: Graves, Wayne \u0026 Danihelka, 2014) Pointer Network In problems like sorting or travelling salesman, both input and output are sequential data. Unfortunately, they cannot be easily solved by classic seq-2-seq or NMT models, given that the discrete categories of output elements are not determined in advance, but depends on the variable input size. The Pointer Net (Ptr-Net; Vinyals, et al. 2015) is proposed to resolve this type of problems: When the output elements correspond to positions in an input sequence. Rather than using attention to blend hidden units of an encoder into a context vector (See Fig. 8), the Pointer Net applies attention over the input elements to pick one as the output at each decoder step.\nFig. 13. The architecture of a Pointer Network model. (Image source: Vinyals, et al. 2015) The Ptr-Net outputs a sequence of integer indices, $\\boldsymbol{c} = (c_1, \\dots, c_m)$ given a sequence of input vectors $\\boldsymbol{x} = (x_1, \\dots, x_n)$ and $1 \\leq c_i \\leq n$. The model still embraces an encoder-decoder framework. The encoder and decoder hidden states are denoted as $(\\boldsymbol{h}_1, \\dots, \\boldsymbol{h}_n)$ and $(\\boldsymbol{s}_1, \\dots, \\boldsymbol{s}_m)$, respectively. Note that $\\mathbf{s}_i$ is the output gate after cell activation in the decoder. The Ptr-Net applies additive attention between states and then normalizes it by softmax to model the output conditional probability:\n $$ \\begin{aligned} y_i \u0026= p(c_i \\vert c_1, \\dots, c_{i-1}, \\boldsymbol{x}) \\\\ \u0026= \\text{softmax}(\\text{score}(\\boldsymbol{s}_t; \\boldsymbol{h}_i)) = \\text{softmax}(\\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_t; \\boldsymbol{h}_i])) \\end{aligned} $$ The attention mechanism is simplified, as Ptr-Net does not blend the encoder states into the output with attention weights. In this way, the output only responds to the positions but not the input content.\nTransformer \u0026ldquo;Attention is All you Need\u0026rdquo; (Vaswani, et al., 2017), without a doubt, is one of the most impactful and interesting paper in 2017. It presented a lot of improvements to the soft attention and make it possible to do seq2seq modeling without recurrent network units. The proposed \u0026ldquo;transformer\u0026rdquo; model is entirely built on the self-attention mechanisms without using sequence-aligned recurrent architecture.\nThe secret recipe is carried in its model architecture.\nKey, Value and Query The major component in the transformer is the unit of multi-head self-attention mechanism. The transformer views the encoded representation of the input as a set of key-value pairs, $(\\mathbf{K}, \\mathbf{V})$, both of dimension $n$ (input sequence length); in the context of NMT, both the keys and values are the encoder hidden states. In the decoder, the previous output is compressed into a query ($\\mathbf{Q}$ of dimension $m$) and the next output is produced by mapping this query and the set of keys and values.\nThe transformer adopts the scaled dot-product attention: the output is a weighted sum of the values, where the weight assigned to each value is determined by the dot-product of the query with all the keys:\n $$ \\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{n}})\\mathbf{V} $$ Multi-Head Self-Attention Fig. 14. Multi-head scaled dot-product attention mechanism. (Image source: Fig 2 in Vaswani, et al., 2017) Rather than only computing the attention once, the multi-head mechanism runs through the scaled dot-product attention multiple times in parallel. The independent attention outputs are simply concatenated and linearly transformed into the expected dimensions. I assume the motivation is because ensembling always helps? ;) According to the paper, \u0026ldquo;multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.\u0026quot;\n $$ \\begin{aligned} \\text{MultiHead}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) \u0026= [\\text{head}_1; \\dots; \\text{head}_h]\\mathbf{W}^O \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{Q}\\mathbf{W}^Q_i, \\mathbf{K}\\mathbf{W}^K_i, \\mathbf{V}\\mathbf{W}^V_i) \\end{aligned} $$ where $\\mathbf{W}^Q_i$, $\\mathbf{W}^K_i$, $\\mathbf{W}^V_i$, and $\\mathbf{W}^O$ are parameter matrices to be learned.\nEncoder Fig. 15. The transformer’s encoder. (Image source: Vaswani, et al., 2017) The encoder generates an attention-based representation with capability to locate a specific piece of information from a potentially infinitely-large context.\n A stack of N=6 identical layers. Each layer has a multi-head self-attention layer and a simple position-wise fully connected feed-forward network. Each sub-layer adopts a residual connection and a layer normalization. All the sub-layers output data of the same dimension $d_\\text{model} = 512$. Decoder Fig. 16. The transformer’s decoder. (Image source: Vaswani, et al., 2017) The decoder is able to retrieval from the encoded representation.\n A stack of N = 6 identical layers Each layer has two sub-layers of multi-head attention mechanisms and one sub-layer of fully-connected feed-forward network. Similar to the encoder, each sub-layer adopts a residual connection and a layer normalization. The first multi-head attention sub-layer is modified to prevent positions from attending to subsequent positions, as we don’t want to look into the future of the target sequence when predicting the current position. Full Architecture Finally here is the complete view of the transformer\u0026rsquo;s architecture:\n Both the source and target sequences first go through embedding layers to produce data of the same dimension $d_\\text{model} =512$. To preserve the position information, a sinusoid-wave-based positional encoding is applied and summed with the embedding output. A softmax and linear layer are added to the final decoder output. Fig. 17. The full model architecture of the transformer. (Image source: Fig 1 \u0026 2 in Vaswani, et al., 2017.) Try to implement the transformer model is an interesting experience, here is mine: lilianweng/transformer-tensorflow. Read the comments in the code if you are interested.\nSNAIL The transformer has no recurrent or convolutional structure, even with the positional encoding added to the embedding vector, the sequential order is only weakly incorporated. For problems sensitive to the positional dependency like reinforcement learning, this can be a big problem.\nThe Simple Neural Attention Meta-Learner (SNAIL) (Mishra et al., 2017) was developed partially to resolve the problem with positioning in the transformer model by combining the self-attention mechanism in transformer with temporal convolutions. It has been demonstrated to be good at both supervised learning and reinforcement learning tasks.\nFig. 18. SNAIL model architecture (Image source: Mishra et al., 2017) SNAIL was born in the field of meta-learning, which is another big topic worthy of a post by itself. But in simple words, the meta-learning model is expected to be generalizable to novel, unseen tasks in the similar distribution. Read this nice introduction if interested.\nSelf-Attention GAN Self-Attention GAN (SAGAN; Zhang et al., 2018) adds self-attention layers into GAN to enable both the generator and the discriminator to better model relationships between spatial regions.\nThe classic DCGAN (Deep Convolutional GAN) represents both discriminator and generator as multi-layer convolutional networks. However, the representation capacity of the network is restrained by the filter size, as the feature of one pixel is limited to a small local region. In order to connect regions far apart, the features have to be dilute through layers of convolutional operations and the dependencies are not guaranteed to be maintained.\nAs the (soft) self-attention in the vision context is designed to explicitly learn the relationship between one pixel and all other positions, even regions far apart, it can easily capture global dependencies. Hence GAN equipped with self-attention is expected to handle details better, hooray!\nFig. 19. Convolution operation and self-attention have access to regions of very different sizes. The SAGAN adopts the non-local neural network to apply the attention computation. The convolutional image feature maps $\\mathbf{x}$ is branched out into three copies, corresponding to the concepts of key, value, and query in the transformer:\n Key: $f(\\mathbf{x}) = \\mathbf{W}_f \\mathbf{x}$ Query: $g(\\mathbf{x}) = \\mathbf{W}_g \\mathbf{x}$ Value: $h(\\mathbf{x}) = \\mathbf{W}_h \\mathbf{x}$ Then we apply the dot-product attention to output the self-attention feature maps:\n $$ \\begin{aligned} \\alpha_{i,j} \u0026= \\text{softmax}(f(\\mathbf{x}_i)^\\top g(\\mathbf{x}_j)) \\\\ \\mathbf{o}_j \u0026= \\mathbf{W}_v \\Big( \\sum_{i=1}^N \\alpha_{i,j} h(\\mathbf{x}_i) \\Big) \\end{aligned} $$ Fig. 20. The self-attention mechanism in SAGAN. (Image source: Fig. 2 in Zhang et al., 2018) Note that $\\alpha_{i,j}$ is one entry in the attention map, indicating how much attention the model should pay to the $i$-th position when synthesizing the $j$-th location. $\\mathbf{W}_f$, $\\mathbf{W}_g$, and $\\mathbf{W}_h$ are all 1x1 convolution filters. If you feel that 1x1 conv sounds like a weird concept (i.e., isn\u0026rsquo;t it just to multiply the whole feature map with one number?), watch this short tutorial by Andrew Ng. The output $\\mathbf{o}_j$ is a column vector of the final output $\\mathbf{o}= (\\mathbf{o}_1, \\mathbf{o}_2, \\dots, \\mathbf{o}_j, \\dots, \\mathbf{o}_N)$.\nFurthermore, the output of the attention layer is multiplied by a scale parameter and added back to the original input feature map:\n $$ \\mathbf{y} = \\mathbf{x}_i + \\gamma \\mathbf{o}_i $$ While the scaling parameter $\\gamma$ is increased gradually from 0 during the training, the network is configured to first rely on the cues in the local regions and then gradually learn to assign more weight to the regions that are further away.\nFig. 21. 128×128 example images generated by SAGAN for different classes. (Image source: Partial Fig. 6 in Zhang et al., 2018) Cited as:\n@article{weng2018attention, title = \u0026quot;Attention? Attention!\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-06-24-attention/\u0026quot; } References [1] \u0026ldquo;Attention and Memory in Deep Learning and NLP.\u0026quot; - Jan 3, 2016 by Denny Britz\n[2] \u0026ldquo;Neural Machine Translation (seq2seq) Tutorial\u0026rdquo;\n[3] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. \u0026ldquo;Neural machine translation by jointly learning to align and translate.\u0026quot; ICLR 2015.\n[4] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua Bengio. \u0026ldquo;Show, attend and tell: Neural image caption generation with visual attention.\u0026quot; ICML, 2015.\n[5] Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. \u0026ldquo;Sequence to sequence learning with neural networks.\u0026quot; NIPS 2014.\n[6] Thang Luong, Hieu Pham, Christopher D. Manning. \u0026ldquo;Effective Approaches to Attention-based Neural Machine Translation.\u0026quot; EMNLP 2015.\n[7] Denny Britz, Anna Goldie, Thang Luong, and Quoc Le. \u0026ldquo;Massive exploration of neural machine translation architectures.\u0026quot; ACL 2017.\n[8] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[9] Jianpeng Cheng, Li Dong, and Mirella Lapata. \u0026ldquo;Long short-term memory-networks for machine reading.\u0026quot; EMNLP 2016.\n[10] Xiaolong Wang, et al. \u0026ldquo;Non-local Neural Networks.\u0026quot; CVPR 2018\n[11] Han Zhang, Ian Goodfellow, Dimitris Metaxas, and Augustus Odena. \u0026ldquo;Self-Attention Generative Adversarial Networks.\u0026quot; arXiv preprint arXiv:1805.08318 (2018).\n[12] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. \u0026ldquo;A simple neural attentive meta-learner.\u0026quot; ICLR 2018.\n[13] \u0026ldquo;WaveNet: A Generative Model for Raw Audio\u0026rdquo; - Sep 8, 2016 by DeepMind.\n[14] Oriol Vinyals, Meire Fortunato, and Navdeep Jaitly. \u0026ldquo;Pointer networks.\u0026quot; NIPS 2015.\n[15] Alex Graves, Greg Wayne, and Ivo Danihelka. \u0026ldquo;Neural turing machines.\u0026quot; arXiv preprint arXiv:1410.5401 (2014).\n","permalink":"https://lilianweng.github.io/posts/2018-06-24-attention/","summary":"[Updated on 2018-10-28: Add Pointer Network and the link to my implementation of Transformer.] [Updated on 2018-11-06: Add a link to the implementation of Transformer model.] [Updated on 2018-11-18: Add Neural Turing Machines.] [Updated on 2019-07-18: Correct the mistake on using the term \u0026ldquo;self-attention\u0026rdquo; when introducing the show-attention-tell paper; moved it to Self-Attention section.] [Updated on 2020-04-07: A follow-up post on improved Transformer models is here.]\nAttention is, to some extent, motivated by how we pay visual attention to different regions of an image or correlate words in one sentence.","title":"Attention? Attention!"},{"content":"The full implementation is available in lilianweng/deep-reinforcement-learning-gym\nIn the previous two posts, I have introduced the algorithms of many deep reinforcement learning models. Now it is the time to get our hands dirty and practice how to implement the models in the wild. The implementation is gonna be built in Tensorflow and OpenAI gym environment. The full version of the code in this tutorial is available in [lilian/deep-reinforcement-learning-gym].\nEnvironment Setup Make sure you have Homebrew installed: /usr/bin/ruby -e \u0026#34;$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)\u0026#34; I would suggest starting a virtualenv for your development. It makes life so much easier when you have multiple projects with conflicting requirements; i.e. one works in Python 2.7 while the other is only compatible with Python 3.5+. # Install python virtualenv brew install pyenv-virtualenv # Create a virtual environment of any name you like with Python 3.6.4 support pyenv virtualenv 3.6.4 workspace # Activate the virtualenv named \u0026#34;workspace\u0026#34; pyenv activate workspace [*] For every new installation below, please make sure you are in the virtualenv.\nInstall OpenAI gym according to the instruction. For a minimal installation, run: git clone https://github.com/openai/gym.git cd gym pip install -e . If you are interested in playing with Atari games or other advanced packages, please continue to get a couple of system packages installed.\nbrew install cmake boost boost-python sdl2 swig wget For Atari, go to the gym directory and pip install it. This post is pretty helpful if you have troubles with ALE (arcade learning environment) installation.\npip install -e \u0026#39;.[atari]\u0026#39; Finally clone the \u0026ldquo;playground\u0026rdquo; code and install the requirements. git clone git@github.com:lilianweng/deep-reinforcement-learning-gym.git cd deep-reinforcement-learning-gym pip install -e . # install the \u0026#34;playground\u0026#34; project. pip install -r requirements.txt # install required packages. Gym Environment The OpenAI Gym toolkit provides a set of physical simulation environments, games, and robot simulators that we can play with and design reinforcement learning agents for. An environment object can be initialized by gym.make(\u0026quot;{environment name}\u0026quot;:\nimport gym env = gym.make(\u0026#34;MsPacman-v0\u0026#34;) The formats of action and observation of an environment are defined by env.action_space and env.observation_space, respectively.\nTypes of gym spaces:\n gym.spaces.Discrete(n): discrete values from 0 to n-1. gym.spaces.Box: a multi-dimensional vector of numeric values, the upper and lower bounds of each dimension are defined by Box.low and Box.high. We interact with the env through two major api calls:\nob = env.reset()\n Resets the env to the original setting. Returns the initial observation. ob_next, reward, done, info = env.step(action)\n Applies one action in the env which should be compatible with env.action_space. Gets back the new observation ob_next (env.observation_space), a reward (float), a done flag (bool), and other meta information (dict). If done=True, the episode is complete and we should reset the env to restart. Read more here. Naive Q-Learning Q-learning (Watkins \u0026amp; Dayan, 1992) learns the action value (\u0026ldquo;Q-value\u0026rdquo;) and update it according to the Bellman equation. The key point is while estimating what is the next action, it does not follow the current policy but rather adopt the best Q value (the part in red) independently.\n $$ Q(s, a) \\leftarrow (1 - \\alpha) Q(s, a) + \\alpha (r + \\gamma \\color{red}{\\max_{a' \\in \\mathcal{A}} Q(s', a')}) $$ In a naive implementation, the Q value for all (s, a) pairs can be simply tracked in a dict. No complicated machine learning model is involved yet.\nfrom collections import defaultdict Q = defaultdict(float) gamma = 0.99 # Discounting factor alpha = 0.5 # soft update param env = gym.make(\u0026#34;CartPole-v0\u0026#34;) actions = range(env.action_space) def update_Q(s, r, a, s_next, done): max_q_next = max([Q[s_next, a] for a in actions]) # Do not include the next state\u0026#39;s value if currently at the terminal state. Q[s, a] += alpha * (r + gamma * max_q_next * (1.0 - done) - Q[s, a]) Most gym environments have a multi-dimensional continuous observation space (gym.spaces.Box). To make sure our Q dictionary will not explode by trying to memorize an infinite number of keys, we apply a wrapper to discretize the observation. The concept of wrappers is very powerful, with which we are capable to customize observation, action, step function, etc. of an env. No matter how many wrappers are applied, env.unwrapped always gives back the internal original environment object.\nimport gym class DiscretizedObservationWrapper(gym.ObservationWrapper): \u0026#34;\u0026#34;\u0026#34;This wrapper converts a Box observation into a single integer. \u0026#34;\u0026#34;\u0026#34; def __init__(self, env, n_bins=10, low=None, high=None): super().__init__(env) assert isinstance(env.observation_space, Box) low = self.observation_space.low if low is None else low high = self.observation_space.high if high is None else high self.n_bins = n_bins self.val_bins = [np.linspace(l, h, n_bins + 1) for l, h in zip(low.flatten(), high.flatten())] self.observation_space = Discrete(n_bins ** low.flatten().shape[0]) def _convert_to_one_number(self, digits): return sum([d * ((self.n_bins + 1) ** i) for i, d in enumerate(digits)]) def observation(self, observation): digits = [np.digitize([x], bins)[0] for x, bins in zip(observation.flatten(), self.val_bins)] return self._convert_to_one_number(digits) env = DiscretizedObservationWrapper( env, n_bins=8, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5] ) Let\u0026rsquo;s plug in the interaction with a gym env and update the Q function every time a new transition is generated. When picking the action, we use ε-greedy to force exploration.\nimport gym import numpy as np n_steps = 100000 epsilon = 0.1 # 10% chances to apply a random action def act(ob): if np.random.random() \u0026lt; epsilon: # action_space.sample() is a convenient function to get a random action # that is compatible with this given action space. return env.action_space.sample() # Pick the action with highest q value. qvals = {a: q[state, a] for a in actions} max_q = max(qvals.values()) # In case multiple actions have the same maximum q value. actions_with_max_q = [a for a, q in qvals.items() if q == max_q] return np.random.choice(actions_with_max_q) ob = env.reset() rewards = [] reward = 0.0 for step in range(n_steps): a = act(ob) ob_next, r, done, _ = env.step(a) update_Q(ob, r, a, ob_next, done) reward += r if done: rewards.append(reward) reward = 0.0 ob = env.reset() else: ob = ob_next Often we start with a high epsilon and gradually decrease it during the training, known as \u0026ldquo;epsilon annealing\u0026rdquo;. The full code of QLearningPolicy is available here.\nDeep Q-Network Deep Q-network is a seminal piece of work to make the training of Q-learning more stable and more data-efficient, when the Q value is approximated with a nonlinear function. Two key ingredients are experience replay and a separately updated target network.\nThe main loss function looks like the following,\n $$ \\begin{aligned} \u0026 Y(s, a, r, s') = r + \\gamma \\max_{a'} Q_{\\theta^{-}}(s', a') \\\\ \u0026 \\mathcal{L}(\\theta) = \\mathbb{E}_{(s, a, r, s') \\sim U(D)} \\Big[ \\big( Y(s, a, r, s') - Q_\\theta(s, a) \\big)^2 \\Big] \\end{aligned} $$ The Q network can be a multi-layer dense neural network, a convolutional network, or a recurrent network, depending on the problem. In the full implementation of the DQN policy, it is determined by the model_type parameter, one of (\u0026ldquo;dense\u0026rdquo;, \u0026ldquo;conv\u0026rdquo;, \u0026ldquo;lstm\u0026rdquo;).\nIn the following example, I\u0026rsquo;m using a 2-layer densely connected neural network to learn Q values for the cart pole balancing problem.\nimport gym env = gym.make(\u0026#39;CartPole-v1\u0026#39;) # The observation space is `Box(4,)`, a 4-element vector. observation_size = env.observation_space.shape[0] We have a helper function for creating the networks below:\nimport tensorflow as tf def dense_nn(inputs, layers_sizes, scope_name): \u0026#34;\u0026#34;\u0026#34;Creates a densely connected multi-layer neural network. inputs: the input tensor layers_sizes (list\u0026lt;int\u0026gt;): defines the number of units in each layer. The output layer has the size layers_sizes[-1]. \u0026#34;\u0026#34;\u0026#34; with tf.variable_scope(scope_name): for i, size in enumerate(layers_sizes): inputs = tf.layers.dense( inputs, size, # Add relu activation only for internal layers. activation=tf.nn.relu if i \u0026lt; len(layers_sizes) - 1 else None, kernel_initializer=tf.contrib.layers.xavier_initializer(), name=scope_name + \u0026#39;_l\u0026#39; + str(i) ) return inputs The Q-network and the target network are updated with a batch of transitions (state, action, reward, state_next, done_flag). The input tensors are:\nbatch_size = 32 # A tunable hyperparameter. states = tf.placeholder(tf.float32, shape=(batch_size, observation_size), name=\u0026#39;state\u0026#39;) states_next = tf.placeholder(tf.float32, shape=(batch_size, observation_size), name=\u0026#39;state_next\u0026#39;) actions = tf.placeholder(tf.int32, shape=(batch_size,), name=\u0026#39;action\u0026#39;) rewards = tf.placeholder(tf.float32, shape=(batch_size,), name=\u0026#39;reward\u0026#39;) done_flags = tf.placeholder(tf.float32, shape=(batch_size,), name=\u0026#39;done\u0026#39;) We have two networks of the same structure. Both have the same network architectures with the state observation as the inputs and Q values over all the actions as the outputs.\nq = dense(states, [32, 32, 2], name=\u0026#39;Q_primary\u0026#39;) q_target = dense(states_next, [32, 32, 2], name=\u0026#39;Q_target\u0026#39;) The target network \u0026ldquo;Q_target\u0026rdquo; takes the states_next tensor as the input, because we use its prediction to select the optimal next state in the Bellman equation.\n# The prediction by the primary Q network for the actual actions. action_one_hot = tf.one_hot(actions, act_size, 1.0, 0.0, name=\u0026#39;action_one_hot\u0026#39;) pred = tf.reduce_sum(q * action_one_hot, reduction_indices=-1, name=\u0026#39;q_acted\u0026#39;) # The optimization target defined by the Bellman equation and the target network. max_q_next_by_target = tf.reduce_max(q_target, axis=-1) y = rewards + (1. - done_flags) * gamma * max_q_next_by_target # The loss measures the mean squared error between prediction and target. loss = tf.reduce_mean(tf.square(pred - tf.stop_gradient(y)), name=\u0026#34;loss_mse_train\u0026#34;) optimizer = tf.train.AdamOptimizer(0.001).minimize(loss, name=\u0026#34;adam_optim\u0026#34;) Note that tf.stop_gradient() on the target y, because the target network should stay fixed during the loss-minimizing gradient update.\nThe target network is updated by copying the primary Q network parameters over every C number of steps (\u0026ldquo;hard update\u0026rdquo;) or polyak averaging towards the primary network (\u0026ldquo;soft update\u0026rdquo;)\n# Get all the variables in the Q primary network. q_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=\u0026#34;Q_primary\u0026#34;) # Get all the variables in the Q target network. q_target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=\u0026#34;Q_target\u0026#34;) assert len(q_vars) == len(q_target_vars) def update_target_q_net_hard(): # Hard update sess.run([v_t.assign(v) for v_t, v in zip(q_target_vars, q_vars)]) def update_target_q_net_soft(tau=0.05): # Soft update: polyak averaging. sess.run([v_t.assign(v_t * (1. - tau) + v * tau) for v_t, v in zip(q_target_vars, q_vars)]) Double Q-Learning If we look into the standard form of the Q value target, $Y(s, a) = r + \\gamma \\max_{a' \\in \\mathcal{A}} Q_\\theta (s', a')$, it is easy to notice that we use $Q_\\theta$ to select the best next action at state s' and then apply the action value predicted by the same $Q_\\theta$. This two-step reinforcing procedure could potentially lead to overestimation of an (already) overestimated value, further leading to training instability. The solution proposed by double Q-learning (Hasselt, 2010) is to decouple the action selection and action value estimation by using two Q networks, $Q_1$ and $Q_2$: when $Q_1$ is being updated, $Q_2$ decides the best next action, and vice versa.\n $$ Y_1(s, a, r, s') = r + \\gamma Q_1 (s', \\arg\\max_{a' \\in \\mathcal{A}}Q_2(s', a'))\\\\ Y_2(s, a, r, s') = r + \\gamma Q_2 (s', \\arg\\max_{a' \\in \\mathcal{A}}Q_1(s', a')) $$ To incorporate double Q-learning into DQN, the minimum modification (Hasselt, Guez, \u0026amp; Silver, 2016) is to use the primary Q network to select the action while the action value is estimated by the target network:\n $$ Y(s, a, r, s') = r + \\gamma Q_{\\theta^{-}}(s', \\arg\\max_{a' \\in \\mathcal{A}} Q_\\theta(s', a')) $$ In the code, we add a new tensor for getting the action selected by the primary Q network as the input and a tensor operation for selecting this action.\nactions_next = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action_next\u0026#39;) actions_selected_by_q = tf.argmax(q, axis=-1, name=\u0026#39;action_selected\u0026#39;) The prediction target y in the loss function becomes:\nactions_next_flatten = actions_next + tf.range(0, batch_size) * q_target.shape[1] max_q_next_target = tf.gather(tf.reshape(q_target, [-1]), actions_next_flatten) y = rewards + (1. - done_flags) * gamma * max_q_next_by_target Here I used tf.gather() to select the action values of interests.\n(Image source: tf.gather() docs) During the episode rollout, we compute the actions_next by feeding the next states' data into the actions_selected_by_q operation.\n# batch_data is a dict with keys, ‘s\u0026#39;, ‘a\u0026#39;, ‘r\u0026#39;, ‘s_next\u0026#39; and ‘done\u0026#39;, containing a batch of transitions. actions_next = sess.run(actions_selected_by_q, {states: batch_data[\u0026#39;s_next\u0026#39;]}) Dueling Q-Network The dueling Q-network (Wang et al., 2016) is equipped with an enhanced network architecture: the output layer branches out into two heads, one for predicting state value, V, and the other for advantage, A. The Q-value is then reconstructed, $Q(s, a) = V(s) + A(s, a)$.\n $$ \\begin{aligned} A(s, a) \u0026= Q(s, a) - V(s)\\\\ V(s) \u0026= \\sum_a Q(s, a) \\pi(a \\vert s) = \\sum_a (V(s) + A(s, a)) \\pi(a \\vert s) = V(s) + \\sum_a A(s, a)\\pi(a \\vert s)\\\\ \\text{Thus, }\u0026 \\sum_a A(s, a)\\pi(a \\vert s) = 0 \\end{aligned} $$ To make sure the estimated advantage values sum up to zero, $\\sum_a A(s, a)\\pi(a \\vert s) = 0$, we deduct the mean value from the prediction.\n $$ Q(s, a) = V(s) + (A(s, a) - \\frac{1}{|\\mathcal{A}|} \\sum_a A(s, a)) $$ The code change is straightforward:\nq_hidden = dense_nn(states, [32], name=\u0026#39;Q_primary_hidden\u0026#39;) adv = dense_nn(q_hidden, [32, env.action_space.n], name=\u0026#39;Q_primary_adv\u0026#39;) v = dense_nn(q_hidden, [32, 1], name=\u0026#39;Q_primary_v\u0026#39;) # Average dueling q = v + (adv - tf.reduce_mean(adv, reduction_indices=1, keepdims=True)) (Image source: Wang et al., 2016) Check the code for the complete flow.\nMonte-Carlo Policy Gradient I reviewed a number of popular policy gradient methods in my last post. Monte-Carlo policy gradient, also known as REINFORCE, is a classic on-policy method that learns the policy model explicitly. It uses the return estimated from a full on-policy trajectory and updates the policy parameters with policy gradient.\nThe returns are computed during rollouts and then fed into the Tensorflow graph as inputs.\n# Inputs states = tf.placeholder(tf.float32, shape=(None, obs_size), name=\u0026#39;state\u0026#39;) actions = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action\u0026#39;) returns = tf.placeholder(tf.float32, shape=(None,), name=\u0026#39;return\u0026#39;) The policy network is contructed. We update the policy parameters by minimizing the loss function, $\\mathcal{L} = - (G_t - V(s)) \\log \\pi(a \\vert s)$. tf.nn.sparse_softmax_cross_entropy_with_logits() asks for the raw logits as inputs, rather then the probabilities after softmax, and that\u0026rsquo;s why we do not have a softmax layer on top of the policy network.\n# Policy network pi = dense_nn(states, [32, 32, env.action_space.n], name=\u0026#39;pi_network\u0026#39;) sampled_actions = tf.squeeze(tf.multinomial(pi, 1)) # For sampling actions according to probabilities. with tf.variable_scope(\u0026#39;pi_optimize\u0026#39;): loss_pi = tf.reduce_mean( returns * tf.nn.sparse_softmax_cross_entropy_with_logits( logits=pi, labels=actions), name=\u0026#39;loss_pi\u0026#39;) optim_pi = tf.train.AdamOptimizer(0.001).minimize(loss_pi, name=\u0026#39;adam_optim_pi\u0026#39;) During the episode rollout, the return is calculated as follows:\n# env = gym.make(...) # gamma = 0.99 # sess = tf.Session(...) def act(ob): return sess.run(sampled_actions, {states: [ob]}) for _ in range(n_episodes): ob = env.reset() done = False obs = [] actions = [] rewards = [] returns = [] while not done: a = act(ob) new_ob, r, done, info = env.step(a) obs.append(ob) actions.append(a) rewards.append(r) ob = new_ob # Estimate returns backwards. return_so_far = 0.0 for r in rewards[::-1]: return_so_far = gamma * return_so_far + r returns.append(return_so_far) returns = returns[::-1] # Update the policy network with the data from one episode. sess.run([optim_pi], feed_dict={ states: np.array(obs), actions: np.array(actions), returns: np.array(returns), }) The full implementation of REINFORCE is here.\nActor-Critic The actor-critic algorithm learns two models at the same time, the actor for learning the best policy and the critic for estimating the state value.\n Initialize the actor network, $\\pi(a \\vert s)$ and the critic, $V(s)$ Collect a new transition (s, a, r, s'): Sample the action $a \\sim \\pi(a \\vert s)$ for the current state s, and get the reward r and the next state s'. Compute the TD target during episode rollout, $G_t = r + \\gamma V(s')$ and TD error, $\\delta_t = r + \\gamma V(s') - V(s)$. Update the critic network by minimizing the critic loss: $L_c = (V(s) - G_t)$. Update the actor network by minimizing the actor loss: $L_a = - \\delta_t \\log \\pi(a \\vert s)$. Set s' = s and repeat step 2.-5. Overall the implementation looks pretty similar to REINFORCE with an extra critic network. The full implementation is here.\n# Inputs states = tf.placeholder(tf.float32, shape=(None, observation_size), name=\u0026#39;state\u0026#39;) actions = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action\u0026#39;) td_targets = tf.placeholder(tf.float32, shape=(None,), name=\u0026#39;td_target\u0026#39;) # Actor: action probabilities actor = dense_nn(states, [32, 32, env.action_space.n], name=\u0026#39;actor\u0026#39;) # Critic: action value (Q-value) critic = dense_nn(states, [32, 32, 1], name=\u0026#39;critic\u0026#39;) action_ohe = tf.one_hot(actions, act_size, 1.0, 0.0, name=\u0026#39;action_one_hot\u0026#39;) pred_value = tf.reduce_sum(critic * action_ohe, reduction_indices=-1, name=\u0026#39;q_acted\u0026#39;) td_errors = td_targets - tf.reshape(pred_value, [-1]) with tf.variable_scope(\u0026#39;critic_train\u0026#39;): loss_c = tf.reduce_mean(tf.square(td_errors)) optim_c = tf.train.AdamOptimizer(0.01).minimize(loss_c) with tf.variable_scope(\u0026#39;actor_train\u0026#39;): loss_a = tf.reduce_mean( tf.stop_gradient(td_errors) * tf.nn.sparse_softmax_cross_entropy_with_logits( logits=actor, labels=actions), name=\u0026#39;loss_actor\u0026#39;) optim_a = tf.train.AdamOptimizer(0.01).minimize(loss_a) train_ops = [optim_c, optim_a] The tensorboard graph is always helpful: References [1] Tensorflow API Docs\n[2] Christopher JCH Watkins, and Peter Dayan. \u0026ldquo;Q-learning.\u0026quot; Machine learning 8.3-4 (1992): 279-292.\n[3] Hado Van Hasselt, Arthur Guez, and David Silver. \u0026ldquo;Deep Reinforcement Learning with Double Q-Learning.\u0026quot; AAAI. Vol. 16. 2016.\n[4] Hado van Hasselt. \u0026ldquo;Double Q-learning.\u0026quot; NIPS, 23:2613–2621, 2010.\n[5] Ziyu Wang, et al. Dueling network architectures for deep reinforcement learning. ICML. 2016.\n","permalink":"https://lilianweng.github.io/posts/2018-05-05-drl-implementation/","summary":"The full implementation is available in lilianweng/deep-reinforcement-learning-gym\nIn the previous two posts, I have introduced the algorithms of many deep reinforcement learning models. Now it is the time to get our hands dirty and practice how to implement the models in the wild. The implementation is gonna be built in Tensorflow and OpenAI gym environment. The full version of the code in this tutorial is available in [lilian/deep-reinforcement-learning-gym].\nEnvironment Setup Make sure you have Homebrew installed: /usr/bin/ruby -e \u0026#34;$(curl -fsSL https://raw.","title":"Implementing Deep Reinforcement Learning Models with Tensorflow + OpenAI Gym"},{"content":"[Updated on 2018-06-30: add two new policy gradient methods, SAC and D4PG.] [Updated on 2018-09-30: add a new policy gradient method, TD3.] [Updated on 2019-02-09: add SAC with automatically adjusted temperature]. [Updated on 2019-06-26: Thanks to Chanseok, we have a version of this post in Korean]. [Updated on 2019-09-12: add a new policy gradient method SVPG.] [Updated on 2019-12-22: add a new policy gradient method IMPALA.] [Updated on 2020-10-15: add a new policy gradient method PPG \u0026amp; some new discussion in PPO.] [Updated on 2021-09-19: Thanks to Wenhao \u0026amp; 爱吃猫的鱼, we have this post in Chinese1 \u0026amp; Chinese2].\nWhat is Policy Gradient Policy gradient is an approach to solve reinforcement learning problems. If you haven\u0026rsquo;t looked into the field of reinforcement learning, please first read the section \u0026ldquo;A (Long) Peek into Reinforcement Learning \u0026raquo; Key Concepts\u0026rdquo; for the problem definition and key concepts.\nNotations Here is a list of notations to help you read through equations in the post easily.\n Symbol Meaning $s \\in \\mathcal{S}$ States. $a \\in \\mathcal{A}$ Actions. $r \\in \\mathcal{R}$ Rewards. $S_t, A_t, R_t$ State, action, and reward at time step $t$ of one trajectory. I may occasionally use $s_t, a_t, r_t$ as well. $\\gamma$ Discount factor; penalty to uncertainty of future rewards; $0\u0026lt;\\gamma \\leq 1$. $G_t$ Return; or discounted future reward; $G_t = \\sum_{k=0}^{\\infty} \\gamma^k R_{t+k+1}$. $P(s', r \\vert s, a)$ Transition probability of getting to the next state $s'$ from the current state $s$ with action $a$ and reward $r$. $\\pi(a \\vert s)$ Stochastic policy (agent behavior strategy); $\\pi_\\theta(.)$ is a policy parameterized by $\\theta$. $\\mu(s)$ Deterministic policy; we can also label this as $\\pi(s)$, but using a different letter gives better distinction so that we can easily tell when the policy is stochastic or deterministic without further explanation. Either $\\pi$ or $\\mu$ is what a reinforcement learning algorithm aims to learn. $V(s)$ State-value function measures the expected return of state $s$; $V_w(.)$ is a value function parameterized by $w$. $V^\\pi(s)$ The value of state $s$ when we follow a policy $\\pi$; $V^\\pi (s) = \\mathbb{E}_{a\\sim \\pi} [G_t \\vert S_t = s]$. $Q(s, a)$ Action-value function is similar to $V(s)$, but it assesses the expected return of a pair of state and action $(s, a)$; $Q_w(.)$ is a action value function parameterized by $w$. $Q^\\pi(s, a)$ Similar to $V^\\pi(.)$, the value of (state, action) pair when we follow a policy $\\pi$; $Q^\\pi(s, a) = \\mathbb{E}_{a\\sim \\pi} [G_t \\vert S_t = s, A_t = a]$. $A(s, a)$ Advantage function, $A(s, a) = Q(s, a) - V(s)$; it can be considered as another version of Q-value with lower variance by taking the state-value off as the baseline. Policy Gradient The goal of reinforcement learning is to find an optimal behavior strategy for the agent to obtain optimal rewards. The policy gradient methods target at modeling and optimizing the policy directly. The policy is usually modeled with a parameterized function respect to $\\theta$, $\\pi_\\theta(a \\vert s)$. The value of the reward (objective) function depends on this policy and then various algorithms can be applied to optimize $\\theta$ for the best reward.\nThe reward function is defined as:\n $$ J(\\theta) = \\sum_{s \\in \\mathcal{S}} d^\\pi(s) V^\\pi(s) = \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s) Q^\\pi(s, a) $$ where $d^\\pi(s)$ is the stationary distribution of Markov chain for $\\pi_\\theta$ (on-policy state distribution under $\\pi$). For simplicity, the parameter $\\theta$ would be omitted for the policy $\\pi_\\theta$ when the policy is present in the subscript of other functions; for example, $d^{\\pi}$ and $Q^\\pi$ should be $d^{\\pi_\\theta}$ and $Q^{\\pi_\\theta}$ if written in full.\nImagine that you can travel along the Markov chain\u0026rsquo;s states forever, and eventually, as the time progresses, the probability of you ending up with one state becomes unchanged \u0026mdash; this is the stationary probability for $\\pi_\\theta$. $d^\\pi(s) = \\lim_{t \\to \\infty} P(s_t = s \\vert s_0, \\pi_\\theta)$ is the probability that $s_t=s$ when starting from $s_0$ and following policy $\\pi_\\theta$ for t steps. Actually, the existence of the stationary distribution of Markov chain is one main reason for why PageRank algorithm works. If you want to read more, check this.\nIt is natural to expect policy-based methods are more useful in the continuous space. Because there is an infinite number of actions and (or) states to estimate the values for and hence value-based approaches are way too expensive computationally in the continuous space. For example, in generalized policy iteration, the policy improvement step $\\arg\\max_{a \\in \\mathcal{A}} Q^\\pi(s, a)$ requires a full scan of the action space, suffering from the curse of dimensionality.\nUsing gradient ascent, we can move $\\theta$ toward the direction suggested by the gradient $\\nabla_\\theta J(\\theta)$ to find the best $\\theta$ for $\\pi_\\theta$ that produces the highest return.\nPolicy Gradient Theorem Computing the gradient $\\nabla_\\theta J(\\theta)$ is tricky because it depends on both the action selection (directly determined by $\\pi_\\theta$) and the stationary distribution of states following the target selection behavior (indirectly determined by $\\pi_\\theta$). Given that the environment is generally unknown, it is difficult to estimate the effect on the state distribution by a policy update.\nLuckily, the policy gradient theorem comes to save the world! Woohoo! It provides a nice reformation of the derivative of the objective function to not involve the derivative of the state distribution $d^\\pi(.)$ and simplify the gradient computation $\\nabla_\\theta J(\\theta)$ a lot.\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\\\ \u0026\\propto \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \\end{aligned} $$ Proof of Policy Gradient Theorem This session is pretty dense, as it is the time for us to go through the proof (Sutton \u0026amp; Barto, 2017; Sec. 13.1) and figure out why the policy gradient theorem is correct.\nWe first start with the derivative of the state value function:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta V^\\pi(s) \\\\ =\u0026 \\nabla_\\theta \\Big(\\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s)Q^\\pi(s, a) \\Big) \u0026 \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\nabla_\\theta Q^\\pi(s, a)} \\Big) \u0026 \\scriptstyle{\\text{; Derivative product rule.}} \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\nabla_\\theta \\sum_{s', r} P(s',r \\vert s,a)(r + V^\\pi(s'))} \\Big) \u0026 \\scriptstyle{\\text{; Extend } Q^\\pi \\text{ with future state value.}} \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\sum_{s', r} P(s',r \\vert s,a) \\nabla_\\theta V^\\pi(s')} \\Big) \u0026 \\scriptstyle{P(s',r \\vert s,a) \\text{ or } r \\text{ is not a func of }\\theta}\\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\sum_{s'} P(s' \\vert s,a) \\nabla_\\theta V^\\pi(s')} \\Big) \u0026 \\scriptstyle{\\text{; Because } P(s' \\vert s, a) = \\sum_r P(s', r \\vert s, a)} \\end{aligned} $$ Now we have:\n $$ \\color{red}{\\nabla_\\theta V^\\pi(s)} = \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\sum_{s'} P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\Big) $$ This equation has a nice recursive form (see the red parts!) and the future state value function $V^\\pi(s')$ can be repeated unrolled by following the same equation.\nLet\u0026rsquo;s consider the following visitation sequence and label the probability of transitioning from state s to state x with policy $\\pi_\\theta$ after k step as $\\rho^\\pi(s \\to x, k)$.\n $$ s \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s)} s' \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s')} s'' \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s'')} \\dots $$ When k = 0: $\\rho^\\pi(s \\to s, k=0) = 1$. When k = 1, we scan through all possible actions and sum up the transition probabilities to the target state: $\\rho^\\pi(s \\to s', k=1) = \\sum_a \\pi_\\theta(a \\vert s) P(s' \\vert s, a)$. Imagine that the goal is to go from state s to x after k+1 steps while following policy $\\pi_\\theta$. We can first travel from s to a middle point s' (any state can be a middle point, $s' \\in \\mathcal{S}$) after k steps and then go to the final state x during the last step. In this way, we are able to update the visitation probability recursively: $\\rho^\\pi(s \\to x, k+1) = \\sum_{s'} \\rho^\\pi(s \\to s', k) \\rho^\\pi(s' \\to x, 1)$. Then we go back to unroll the recursive representation of $\\nabla_\\theta V^\\pi(s)$! Let $\\phi(s) = \\sum_{a \\in \\mathcal{A}} \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a)$ to simplify the maths. If we keep on extending $\\nabla_\\theta V^\\pi(.)$ infinitely, it is easy to find out that we can transition from the starting state s to any state after any number of steps in this unrolling process and by summing up all the visitation probabilities, we get $\\nabla_\\theta V^\\pi(s)$!\n $$ \\begin{aligned} \u0026 \\color{red}{\\nabla_\\theta V^\\pi(s)} \\\\ =\u0026 \\phi(s) + \\sum_a \\pi_\\theta(a \\vert s) \\sum_{s'} P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\sum_a \\pi_\\theta(a \\vert s) P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{[ \\phi(s') + \\sum_{s''} \\rho^\\pi(s' \\to s'', 1) \\nabla_\\theta V^\\pi(s'')]} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\phi(s') + \\sum_{s''} \\rho^\\pi(s \\to s'', 2)\\color{red}{\\nabla_\\theta V^\\pi(s'')} \\scriptstyle{\\text{ ; Consider }s'\\text{ as the middle point for }s \\to s''}\\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\phi(s') + \\sum_{s''} \\rho^\\pi(s \\to s'', 2)\\phi(s'') + \\sum_{s'''} \\rho^\\pi(s \\to s''', 3)\\color{red}{\\nabla_\\theta V^\\pi(s''')} \\\\ =\u0026 \\dots \\scriptstyle{\\text{; Repeatedly unrolling the part of }\\nabla_\\theta V^\\pi(.)} \\\\ =\u0026 \\sum_{x\\in\\mathcal{S}}\\sum_{k=0}^\\infty \\rho^\\pi(s \\to x, k) \\phi(x) \\end{aligned} $$ The nice rewriting above allows us to exclude the derivative of Q-value function, $\\nabla_\\theta Q^\\pi(s, a)$. By plugging it into the objective function $J(\\theta)$, we are getting the following:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta V^\\pi(s_0) \u0026 \\scriptstyle{\\text{; Starting from a random state } s_0} \\\\ \u0026= \\sum_{s}\\color{blue}{\\sum_{k=0}^\\infty \\rho^\\pi(s_0 \\to s, k)} \\phi(s) \u0026\\scriptstyle{\\text{; Let }\\color{blue}{\\eta(s) = \\sum_{k=0}^\\infty \\rho^\\pi(s_0 \\to s, k)}} \\\\ \u0026= \\sum_{s}\\eta(s) \\phi(s) \u0026 \\\\ \u0026= \\Big( {\\sum_s \\eta(s)} \\Big)\\sum_{s}\\frac{\\eta(s)}{\\sum_s \\eta(s)} \\phi(s) \u0026 \\scriptstyle{\\text{; Normalize } \\eta(s), s\\in\\mathcal{S} \\text{ to be a probability distribution.}}\\\\ \u0026\\propto \\sum_s \\frac{\\eta(s)}{\\sum_s \\eta(s)} \\phi(s) \u0026 \\scriptstyle{\\sum_s \\eta(s)\\text{ is a constant}} \\\\ \u0026= \\sum_s d^\\pi(s) \\sum_a \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) \u0026 \\scriptstyle{d^\\pi(s) = \\frac{\\eta(s)}{\\sum_s \\eta(s)}\\text{ is stationary distribution.}} \\end{aligned} $$ In the episodic case, the constant of proportionality ($\\sum_s \\eta(s)$) is the average length of an episode; in the continuing case, it is 1 (Sutton \u0026amp; Barto, 2017; Sec. 13.2). The gradient can be further written as:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026\\propto \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \u0026\\\\ \u0026= \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s) Q^\\pi(s, a) \\frac{\\nabla_\\theta \\pi_\\theta(a \\vert s)}{\\pi_\\theta(a \\vert s)} \u0026\\\\ \u0026= \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] \u0026 \\scriptstyle{\\text{; Because } (\\ln x)' = 1/x} \\end{aligned} $$ Where $\\mathbb{E}_\\pi$ refers to $\\mathbb{E}_{s \\sim d_\\pi, a \\sim \\pi_\\theta}$ when both state and action distributions follow the policy $\\pi_\\theta$ (on policy).\nThe policy gradient theorem lays the theoretical foundation for various policy gradient algorithms. This vanilla policy gradient update has no bias but high variance. Many following algorithms were proposed to reduce the variance while keeping the bias unchanged.\n $$ \\nabla_\\theta J(\\theta) = \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] $$ Here is a nice summary of a general form of policy gradient methods borrowed from the GAE (general advantage estimation) paper (Schulman et al., 2016) and this post thoroughly discussed several components in GAE , highly recommended.\nFig. 1. A general form of policy gradient methods. (Image source: Schulman et al., 2016) Policy Gradient Algorithms Tons of policy gradient algorithms have been proposed during recent years and there is no way for me to exhaust them. I\u0026rsquo;m introducing some of them that I happened to know and read about.\nREINFORCE REINFORCE (Monte-Carlo policy gradient) relies on an estimated return by Monte-Carlo methods using episode samples to update the policy parameter $\\theta$. REINFORCE works because the expectation of the sample gradient is equal to the actual gradient:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] \u0026 \\\\ \u0026= \\mathbb{E}_\\pi [G_t \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)] \u0026 \\scriptstyle{\\text{; Because } Q^\\pi(S_t, A_t) = \\mathbb{E}_\\pi[G_t \\vert S_t, A_t]} \\end{aligned} $$ Therefore we are able to measure $G_t$ from real sample trajectories and use that to update our policy gradient. It relies on a full trajectory and that\u0026rsquo;s why it is a Monte-Carlo method.\nThe process is pretty straightforward:\n Initialize the policy parameter $\\theta$ at random. Generate one trajectory on policy $\\pi_\\theta$: $S_1, A_1, R_2, S_2, A_2, \\dots, S_T$. For t=1, 2, \u0026hellip; , T: Estimate the the return $G_t$; Update policy parameters: $\\theta \\leftarrow \\theta + \\alpha \\gamma^t G_t \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)$ A widely used variation of REINFORCE is to subtract a baseline value from the return $G_t$ to reduce the variance of gradient estimation while keeping the bias unchanged (Remember we always want to do this when possible). For example, a common baseline is to subtract state-value from action-value, and if applied, we would use advantage $A(s, a) = Q(s, a) - V(s)$ in the gradient ascent update. This post nicely explained why a baseline works for reducing the variance, in addition to a set of fundamentals of policy gradient.\nActor-Critic Two main components in policy gradient are the policy model and the value function. It makes a lot of sense to learn the value function in addition to the policy, since knowing the value function can assist the policy update, such as by reducing gradient variance in vanilla policy gradients, and that is exactly what the Actor-Critic method does.\nActor-critic methods consist of two models, which may optionally share parameters:\n Critic updates the value function parameters w and depending on the algorithm it could be action-value $Q_w(a \\vert s)$ or state-value $V_w(s)$. Actor updates the policy parameters $\\theta$ for $\\pi_\\theta(a \\vert s)$, in the direction suggested by the critic. Let\u0026rsquo;s see how it works in a simple action-value actor-critic algorithm.\n Initialize $s, \\theta, w$ at random; sample $a \\sim \\pi_\\theta(a \\vert s)$. For $t = 1 \\dots T$: Sample reward $r_t \\sim R(s, a)$ and next state $s' \\sim P(s' \\vert s, a)$; Then sample the next action $a' \\sim \\pi_\\theta(a' \\vert s')$; Update the policy parameters: $\\theta \\leftarrow \\theta + \\alpha_\\theta Q_w(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)$; Compute the correction (TD error) for action-value at time t: $\\delta_t = r_t + \\gamma Q_w(s', a') - Q_w(s, a)$ and use it to update the parameters of action-value function: $w \\leftarrow w + \\alpha_w \\delta_t \\nabla_w Q_w(s, a)$ Update $a \\leftarrow a'$ and $s \\leftarrow s'$. Two learning rates, $\\alpha_\\theta$ and $\\alpha_w$, are predefined for policy and value function parameter updates respectively.\nOff-Policy Policy Gradient Both REINFORCE and the vanilla version of actor-critic method are on-policy: training samples are collected according to the target policy \u0026mdash; the very same policy that we try to optimize for. Off policy methods, however, result in several additional advantages:\n The off-policy approach does not require full trajectories and can reuse any past episodes (“experience replay”) for much better sample efficiency. The sample collection follows a behavior policy different from the target policy, bringing better exploration. Now let\u0026rsquo;s see how off-policy policy gradient is computed. The behavior policy for collecting samples is a known policy (predefined just like a hyperparameter), labelled as $\\beta(a \\vert s)$. The objective function sums up the reward over the state distribution defined by this behavior policy:\n $$ J(\\theta) = \\sum_{s \\in \\mathcal{S}} d^\\beta(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) = \\mathbb{E}_{s \\sim d^\\beta} \\big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\big] $$ where $d^\\beta(s)$ is the stationary distribution of the behavior policy $\\beta$; recall that $d^\\beta(s) = \\lim_{t \\to \\infty} P(S_t = s \\vert S_0, \\beta)$; and $Q^\\pi$ is the action-value function estimated with regard to the target policy $\\pi$ (not the behavior policy!).\nGiven that the training observations are sampled by $a \\sim \\beta(a \\vert s)$, we can rewrite the gradient as:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\Big] \u0026 \\\\ \u0026= \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} \\big( Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) + \\color{red}{\\pi_\\theta(a \\vert s) \\nabla_\\theta Q^\\pi(s, a)} \\big) \\Big] \u0026 \\scriptstyle{\\text{; Derivative product rule.}}\\\\ \u0026\\stackrel{(i)}{\\approx} \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \\Big] \u0026 \\scriptstyle{\\text{; Ignore the red part: } \\color{red}{\\pi_\\theta(a \\vert s) \\nabla_\\theta Q^\\pi(s, a)}}. \\\\ \u0026= \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} \\beta(a \\vert s) \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} Q^\\pi(s, a) \\frac{\\nabla_\\theta \\pi_\\theta(a \\vert s)}{\\pi_\\theta(a \\vert s)} \\Big] \u0026 \\\\ \u0026= \\mathbb{E}_\\beta \\Big[\\frac{\\color{blue}{\\pi_\\theta(a \\vert s)}}{\\color{blue}{\\beta(a \\vert s)}} Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s) \\Big] \u0026 \\scriptstyle{\\text{; The blue part is the importance weight.}} \\end{aligned} $$ where $\\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)}$ is the importance weight. Because $Q^\\pi$ is a function of the target policy and thus a function of policy parameter $\\theta$, we should take the derivative of $\\nabla_\\theta Q^\\pi(s, a)$ as well according to the product rule. However, it is super hard to compute $\\nabla_\\theta Q^\\pi(s, a)$ in reality. Fortunately if we use an approximated gradient with the gradient of Q ignored, we still guarantee the policy improvement and eventually achieve the true local minimum. This is justified in the proof here (Degris, White \u0026amp; Sutton, 2012).\nIn summary, when applying policy gradient in the off-policy setting, we can simple adjust it with a weighted sum and the weight is the ratio of the target policy to the behavior policy, $\\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)}$.\nA3C [paper|code]\nAsynchronous Advantage Actor-Critic (Mnih et al., 2016), short for A3C, is a classic policy gradient method with a special focus on parallel training.\nIn A3C, the critics learn the value function while multiple actors are trained in parallel and get synced with global parameters from time to time. Hence, A3C is designed to work well for parallel training.\nLet\u0026rsquo;s use the state-value function as an example. The loss function for state value is to minimize the mean squared error, $J_v(w) = (G_t - V_w(s))^2$ and gradient descent can be applied to find the optimal w. This state-value function is used as the baseline in the policy gradient update.\nHere is the algorithm outline:\n We have global parameters, $\\theta$ and $w$; similar thread-specific parameters, $\\theta'$ and $w'$.\n Initialize the time step $t = 1$\n While $T \\leq T_\\text{MAX}$:\n Reset gradient: $\\mathrm{d}\\theta = 0$ and $\\mathrm{d}w = 0$. Synchronize thread-specific parameters with global ones: $\\theta' = \\theta$ and $w' = w$. $t_\\text{start}$ = t and sample a starting state $s_t$. While ($s_t$ != TERMINAL) and $t - t_\\text{start} \\leq t_\\text{max}$: Pick the action $A_t \\sim \\pi_{\\theta'}(A_t \\vert S_t)$ and receive a new reward $R_t$ and a new state $s_{t+1}$. Update $t = t + 1$ and $T = T + 1$ Initialize the variable that holds the return estimation $$ R = \\begin{cases} 0 \u0026 \\text{if } s_t \\text{ is TERMINAL} \\\\ V_{w'}(s_t) \u0026 \\text{otherwise} \\end{cases} $$ 6. For $i = t-1, \\dots, t\\_\\text{start}$: 1. $R \\leftarrow \\gamma R + R\\_i$; here R is a MC measure of $G\\_i$. 2. Accumulate gradients w.r.t. $\\theta'$: $d\\theta \\leftarrow d\\theta + \\nabla\\_{\\theta'} \\log \\pi\\_{\\theta'}(a\\_i \\vert s\\_i)(R - V\\_{w'}(s\\_i))$;Accumulate gradients w.r.t. w': $dw \\leftarrow dw + 2 (R - V\\_{w'}(s\\_i)) \\nabla\\_{w'} (R - V\\_{w'}(s\\_i))$. Update asynchronously $\\theta$ using $\\mathrm{d}\\theta$, and $w$ using $\\mathrm{d}w$. A3C enables the parallelism in multiple agent training. The gradient accumulation step (6.2) can be considered as a parallelized reformation of minibatch-based stochastic gradient update: the values of $w$ or $\\theta$ get corrected by a little bit in the direction of each training thread independently.\nA2C [paper|code]\nA2C is a synchronous, deterministic version of A3C; that\u0026rsquo;s why it is named as “A2C” with the first “A” (“asynchronous”) removed. In A3C each agent talks to the global parameters independently, so it is possible sometimes the thread-specific agents would be playing with policies of different versions and therefore the aggregated update would not be optimal. To resolve the inconsistency, a coordinator in A2C waits for all the parallel actors to finish their work before updating the global parameters and then in the next iteration parallel actors starts from the same policy. The synchronized gradient update keeps the training more cohesive and potentially to make convergence faster.\nA2C has been shown to be able to utilize GPUs more efficiently and work better with large batch sizes while achieving same or better performance than A3C.\nFig. 2. The architecture of A3C versus A2C. DPG [paper|code]\nIn methods described above, the policy function $\\pi(. \\vert s)$ is always modeled as a probability distribution over actions $\\mathcal{A}$ given the current state and thus it is stochastic. Deterministic policy gradient (DPG) instead models the policy as a deterministic decision: $a = \\mu(s)$. It may look bizarre \u0026mdash; how can you calculate the gradient of the action probability when it outputs a single action? Let\u0026rsquo;s look into it step by step.\nRefresh on a few notations to facilitate the discussion:\n $\\rho_0(s)$: The initial distribution over states $\\rho^\\mu(s \\to s', k)$: Starting from state s, the visitation probability density at state s' after moving k steps by policy $\\mu$. $\\rho^\\mu(s')$: Discounted state distribution, defined as $\\rho^\\mu(s') = \\int_\\mathcal{S} \\sum_{k=1}^\\infty \\gamma^{k-1} \\rho_0(s) \\rho^\\mu(s \\to s', k) ds$. The objective function to optimize for is listed as follows:\n $$ J(\\theta) = \\int_\\mathcal{S} \\rho^\\mu(s) Q(s, \\mu_\\theta(s)) ds $$ Deterministic policy gradient theorem: Now it is the time to compute the gradient! According to the chain rule, we first take the gradient of Q w.r.t. the action a and then take the gradient of the deterministic policy function $\\mu$ w.r.t. $\\theta$:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\int_\\mathcal{S} \\rho^\\mu(s) \\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)} ds \\\\ \u0026= \\mathbb{E}_{s \\sim \\rho^\\mu} [\\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \\end{aligned} $$ We can consider the deterministic policy as a special case of the stochastic one, when the probability distribution contains only one extreme non-zero value over one action. Actually, in the DPG paper, the authors have shown that if the stochastic policy $\\pi_{\\mu_\\theta, \\sigma}$ is re-parameterized by a deterministic policy $\\mu_\\theta$ and a variation variable $\\sigma$, the stochastic policy is eventually equivalent to the deterministic case when $\\sigma=0$. Compared to the deterministic policy, we expect the stochastic policy to require more samples as it integrates the data over the whole state and action space.\nThe deterministic policy gradient theorem can be plugged into common policy gradient frameworks.\nLet\u0026rsquo;s consider an example of on-policy actor-critic algorithm to showcase the procedure. In each iteration of on-policy actor-critic, two actions are taken deterministically $a = \\mu_\\theta(s)$ and the SARSA update on policy parameters relies on the new gradient that we just computed above:\n $$ \\begin{aligned} \\delta_t \u0026= R_t + \\gamma Q_w(s_{t+1}, a_{t+1}) - Q_w(s_t, a_t) \u0026 \\small{\\text{; TD error in SARSA}}\\\\ w_{t+1} \u0026= w_t + \\alpha_w \\delta_t \\nabla_w Q_w(s_t, a_t) \u0026 \\\\ \\theta_{t+1} \u0026= \\theta_t + \\alpha_\\theta \\color{red}{\\nabla_a Q_w(s_t, a_t) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}} \u0026 \\small{\\text{; Deterministic policy gradient theorem}} \\end{aligned} $$ However, unless there is sufficient noise in the environment, it is very hard to guarantee enough exploration due to the determinacy of the policy. We can either add noise into the policy (ironically this makes it nondeterministic!) or learn it off-policy-ly by following a different stochastic behavior policy to collect samples.\nSay, in the off-policy approach, the training trajectories are generated by a stochastic policy $\\beta(a \\vert s)$ and thus the state distribution follows the corresponding discounted state density $\\rho^\\beta$:\n $$ \\begin{aligned} J_\\beta(\\theta) \u0026= \\int_\\mathcal{S} \\rho^\\beta Q^\\mu(s, \\mu_\\theta(s)) ds \\\\ \\nabla_\\theta J_\\beta(\\theta) \u0026= \\mathbb{E}_{s \\sim \\rho^\\beta} [\\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)} ] \\end{aligned} $$ Note that because the policy is deterministic, we only need $Q^\\mu(s, \\mu_\\theta(s))$ rather than $\\sum_a \\pi(a \\vert s) Q^\\pi(s, a)$ as the estimated reward of a given state s. In the off-policy approach with a stochastic policy, importance sampling is often used to correct the mismatch between behavior and target policies, as what we have described above. However, because the deterministic policy gradient removes the integral over actions, we can avoid importance sampling.\nDDPG [paper|code]\nDDPG (Lillicrap, et al., 2015), short for Deep Deterministic Policy Gradient, is a model-free off-policy actor-critic algorithm, combining DPG with DQN. Recall that DQN (Deep Q-Network) stabilizes the learning of Q-function by experience replay and the frozen target network. The original DQN works in discrete space, and DDPG extends it to continuous space with the actor-critic framework while learning a deterministic policy.\nIn order to do better exploration, an exploration policy $\\mu'$ is constructed by adding noise $\\mathcal{N}$:\n $$ \\mu'(s) = \\mu_\\theta(s) + \\mathcal{N} $$ In addition, DDPG does soft updates (\u0026ldquo;conservative policy iteration\u0026rdquo;) on the parameters of both actor and critic, with $\\tau \\ll 1$: $\\theta' \\leftarrow \\tau \\theta + (1 - \\tau) \\theta'$. In this way, the target network values are constrained to change slowly, different from the design in DQN that the target network stays frozen for some period of time.\nOne detail in the paper that is particularly useful in robotics is on how to normalize the different physical units of low dimensional features. For example, a model is designed to learn a policy with the robot\u0026rsquo;s positions and velocities as input; these physical statistics are different by nature and even statistics of the same type may vary a lot across multiple robots. Batch normalization is applied to fix it by normalizing every dimension across samples in one minibatch.\nFig 3. DDPG Algorithm. (Image source: Lillicrap, et al., 2015) D4PG [paper|code (Search “github d4pg” and you will see a few.)]\nDistributed Distributional DDPG (D4PG) applies a set of improvements on DDPG to make it run in the distributional fashion.\n(1) Distributional Critic: The critic estimates the expected Q value as a random variable ~ a distribution $Z_w$ parameterized by $w$ and therefore $Q_w(s, a) = \\mathbb{E} Z_w(x, a)$. The loss for learning the distribution parameter is to minimize some measure of the distance between two distributions \u0026mdash; distributional TD error: $L(w) = \\mathbb{E}[d(\\mathcal{T}_{\\mu_\\theta}, Z_{w'}(s, a), Z_w(s, a)]$, where $\\mathcal{T}_{\\mu_\\theta}$ is the Bellman operator.\nThe deterministic policy gradient update becomes:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026\\approx \\mathbb{E}_{\\rho^\\mu} [\\nabla_a Q_w(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \u0026 \\scriptstyle{\\text{; gradient update in DPG}} \\\\ \u0026= \\mathbb{E}_{\\rho^\\mu} [\\mathbb{E}[\\nabla_a Z_w(s, a)] \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \u0026 \\scriptstyle{\\text{; expectation of the Q-value distribution.}} \\end{aligned} $$ (2) $N$-step returns: When calculating the TD error, D4PG computes $N$-step TD target rather than one-step to incorporate rewards in more future steps. Thus the new TD target is:\n $$ r(s_0, a_0) + \\mathbb{E}[\\sum_{n=1}^{N-1} r(s_n, a_n) + \\gamma^N Q(s_N, \\mu_\\theta(s_N)) \\vert s_0, a_0 ] $$ (3) Multiple Distributed Parallel Actors: D4PG utilizes $K$ independent actors, gathering experience in parallel and feeding data into the same replay buffer.\n(4) Prioritized Experience Replay (PER): The last piece of modification is to do sampling from the replay buffer of size $R$ with an non-uniform probability $p_i$. In this way, a sample $i$ has the probability $(Rp_i)^{-1}$ to be selected and thus the importance weight is $(Rp_i)^{-1}$.\nFig. 4. D4PG algorithm (Image source: Barth-Maron, et al. 2018); Note that in the original paper, the variable letters are chosen slightly differently from what in the post; i.e. I use $\\mu(.)$ for representing a deterministic policy instead of $\\pi(.)$. MADDPG [paper|code]\nMulti-agent DDPG (MADDPG) (Lowe et al., 2017) extends DDPG to an environment where multiple agents are coordinating to complete tasks with only local information. In the viewpoint of one agent, the environment is non-stationary as policies of other agents are quickly upgraded and remain unknown. MADDPG is an actor-critic model redesigned particularly for handling such a changing environment and interactions between agents.\nThe problem can be formalized in the multi-agent version of MDP, also known as Markov games. MADDPG is proposed for partially observable Markov games. Say, there are N agents in total with a set of states $\\mathcal{S}$. Each agent owns a set of possible action, $\\mathcal{A}_1, \\dots, \\mathcal{A}_N$, and a set of observation, $\\mathcal{O}_1, \\dots, \\mathcal{O}_N$. The state transition function involves all states, action and observation spaces $\\mathcal{T}: \\mathcal{S} \\times \\mathcal{A}_1 \\times \\dots \\mathcal{A}_N \\mapsto \\mathcal{S}$. Each agent\u0026rsquo;s stochastic policy only involves its own state and action: $\\pi_{\\theta_i}: \\mathcal{O}_i \\times \\mathcal{A}_i \\mapsto [0, 1]$, a probability distribution over actions given its own observation, or a deterministic policy: $\\mu_{\\theta_i}: \\mathcal{O}_i \\mapsto \\mathcal{A}_i$.\nLet $\\vec{o} = {o_1, \\dots, o_N}$, $\\vec{\\mu} = {\\mu_1, \\dots, \\mu_N}$ and the policies are parameterized by $\\vec{\\theta} = {\\theta_1, \\dots, \\theta_N}$.\nThe critic in MADDPG learns a centralized action-value function $Q^\\vec{\\mu}_i(\\vec{o}, a_1, \\dots, a_N)$ for the i-th agent, where $a_1 \\in \\mathcal{A}_1, \\dots, a_N \\in \\mathcal{A}_N$ are actions of all agents. Each $Q^\\vec{\\mu}_i$ is learned separately for $i=1, \\dots, N$ and therefore multiple agents can have arbitrary reward structures, including conflicting rewards in a competitive setting. Meanwhile, multiple actors, one for each agent, are exploring and upgrading the policy parameters $\\theta_i$ on their own.\nActor update:\n $$ \\nabla_{\\theta_i} J(\\theta_i) = \\mathbb{E}_{\\vec{o}, a \\sim \\mathcal{D}} [\\nabla_{a_i} Q^{\\vec{\\mu}}_i (\\vec{o}, a_1, \\dots, a_N) \\nabla_{\\theta_i} \\mu_{\\theta_i}(o_i) \\rvert_{a_i=\\mu_{\\theta_i}(o_i)} ] $$ Where $\\mathcal{D}$ is the memory buffer for experience replay, containing multiple episode samples $(\\vec{o}, a_1, \\dots, a_N, r_1, \\dots, r_N, \\vec{o}')$ \u0026mdash; given current observation $\\vec{o}$, agents take action $a_1, \\dots, a_N$ and get rewards $r_1, \\dots, r_N$, leading to the new observation $\\vec{o}'$.\nCritic update:\n $$ \\begin{aligned} \\mathcal{L}(\\theta_i) \u0026= \\mathbb{E}_{\\vec{o}, a_1, \\dots, a_N, r_1, \\dots, r_N, \\vec{o}'}[ (Q^{\\vec{\\mu}}_i(\\vec{o}, a_1, \\dots, a_N) - y)^2 ] \u0026 \\\\ \\text{where } y \u0026= r_i + \\gamma Q^{\\vec{\\mu}'}_i (\\vec{o}', a'_1, \\dots, a'_N) \\rvert_{a'_j = \\mu'_{\\theta_j}} \u0026 \\scriptstyle{\\text{; TD target!}} \\end{aligned} $$ where $\\vec{\\mu}'$ are the target policies with delayed softly-updated parameters.\nIf the policies $\\vec{\\mu}$ are unknown during the critic update, we can ask each agent to learn and evolve its own approximation of others' policies. Using the approximated policies, MADDPG still can learn efficiently although the inferred policies might not be accurate.\nTo mitigate the high variance triggered by the interaction between competing or collaborating agents in the environment, MADDPG proposed one more element - policy ensembles:\n Train K policies for one agent; Pick a random policy for episode rollouts; Take an ensemble of these K policies to do gradient update. In summary, MADDPG added three additional ingredients on top of DDPG to make it adapt to the multi-agent environment:\n Centralized critic + decentralized actors; Actors are able to use estimated policies of other agents for learning; Policy ensembling is good for reducing variance. Fig. 5. The architecture design of MADDPG. (Image source: Lowe et al., 2017) TRPO [paper|code]\nTo improve training stability, we should avoid parameter updates that change the policy too much at one step. Trust region policy optimization (TRPO) (Schulman, et al., 2015) carries out this idea by enforcing a KL divergence constraint on the size of policy update at each iteration.\nConsider the case when we are doing off-policy RL, the policy $\\beta$ used for collecting trajectories on rollout workers is different from the policy $\\pi$ to optimize for. The objective function in an off-policy model measures the total advantage over the state visitation distribution and actions, while the mismatch between the training data distribution and the true policy state distribution is compensated by importance sampling estimator:\n $$ \\begin{aligned} J(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} \\rho^{\\pi_{\\theta_\\text{old}}} \\sum_{a \\in \\mathcal{A}} \\big( \\pi_\\theta(a \\vert s) \\hat{A}_{\\theta_\\text{old}}(s, a) \\big) \u0026 \\\\ \u0026= \\sum_{s \\in \\mathcal{S}} \\rho^{\\pi_{\\theta_\\text{old}}} \\sum_{a \\in \\mathcal{A}} \\big( \\beta(a \\vert s) \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big) \u0026 \\scriptstyle{\\text{; Importance sampling}} \\\\ \u0026= \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}, a \\sim \\beta} \\big[ \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big] \u0026 \\end{aligned} $$ where $\\theta_\\text{old}$ is the policy parameters before the update and thus known to us; $\\rho^{\\pi_{\\theta_\\text{old}}}$ is defined in the same way as above; $\\beta(a \\vert s)$ is the behavior policy for collecting trajectories. Noted that we use an estimated advantage $\\hat{A}(.)$ rather than the true advantage function $A(.)$ because the true rewards are usually unknown.\nWhen training on policy, theoretically the policy for collecting data is same as the policy that we want to optimize. However, when rollout workers and optimizers are running in parallel asynchronously, the behavior policy can get stale. TRPO considers this subtle difference: It labels the behavior policy as $\\pi_{\\theta_\\text{old}}(a \\vert s)$ and thus the objective function becomes:\n $$ J(\\theta) = \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}, a \\sim \\pi_{\\theta_\\text{old}}} \\big[ \\frac{\\pi_\\theta(a \\vert s)}{\\pi_{\\theta_\\text{old}}(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big] $$ TRPO aims to maximize the objective function $J(\\theta)$ subject to, trust region constraint which enforces the distance between old and new policies measured by KL-divergence to be small enough, within a parameter δ:\n $$ \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}} [D_\\text{KL}(\\pi_{\\theta_\\text{old}}(.\\vert s) \\| \\pi_\\theta(.\\vert s)] \\leq \\delta $$ In this way, the old and new policies would not diverge too much when this hard constraint is met. While still, TRPO can guarantee a monotonic improvement over policy iteration (Neat, right?). Please read the proof in the paper if interested :)\nPPO [paper|code]\nGiven that TRPO is relatively complicated and we still want to implement a similar constraint, proximal policy optimization (PPO) simplifies it by using a clipped surrogate objective while retaining similar performance.\nFirst, let\u0026rsquo;s denote the probability ratio between old and new policies as:\n $$ r(\\theta) = \\frac{\\pi_\\theta(a \\vert s)}{\\pi_{\\theta_\\text{old}}(a \\vert s)} $$ Then, the objective function of TRPO (on policy) becomes:\n $$ J^\\text{TRPO} (\\theta) = \\mathbb{E} [ r(\\theta) \\hat{A}_{\\theta_\\text{old}}(s, a) ] $$ Without a limitation on the distance between $\\theta_\\text{old}$ and $\\theta$, to maximize $J^\\text{TRPO} (\\theta)$ would lead to instability with extremely large parameter updates and big policy ratios. PPO imposes the constraint by forcing $r(\\theta)$ to stay within a small interval around 1, precisely $[1-\\epsilon, 1+\\epsilon]$, where $\\epsilon$ is a hyperparameter.\n $$ J^\\text{CLIP} (\\theta) = \\mathbb{E} [ \\min( r(\\theta) \\hat{A}_{\\theta_\\text{old}}(s, a), \\text{clip}(r(\\theta), 1 - \\epsilon, 1 + \\epsilon) \\hat{A}_{\\theta_\\text{old}}(s, a))] $$ The function $\\text{clip}(r(\\theta), 1 - \\epsilon, 1 + \\epsilon)$ clips the ratio to be no more than $1+\\epsilon$ and no less than $1-\\epsilon$. The objective function of PPO takes the minimum one between the original value and the clipped version and therefore we lose the motivation for increasing the policy update to extremes for better rewards.\nWhen applying PPO on the network architecture with shared parameters for both policy (actor) and value (critic) functions, in addition to the clipped reward, the objective function is augmented with an error term on the value estimation (formula in red) and an entropy term (formula in blue) to encourage sufficient exploration.\n $$ J^\\text{CLIP'} (\\theta) = \\mathbb{E} [ J^\\text{CLIP} (\\theta) - \\color{red}{c_1 (V_\\theta(s) - V_\\text{target})^2} + \\color{blue}{c_2 H(s, \\pi_\\theta(.))} ] $$ where Both $c_1$ and $c_2$ are two hyperparameter constants.\nPPO has been tested on a set of benchmark tasks and proved to produce awesome results with much greater simplicity.\nIn a later paper by Hsu et al., 2020, two common design choices in PPO are revisited, precisely (1) clipped probability ratio for policy regularization and (2) parameterize policy action space by continuous Gaussian or discrete softmax distribution. They first identified three failure modes in PPO and proposed replacements for these two designs.\nThe failure modes are:\n On continuous action spaces, standard PPO is unstable when rewards vanish outside bounded support. On discrete action spaces with sparse high rewards, standard PPO often gets stuck at suboptimal actions. The policy is sensitive to initialization when there are locally optimal actions close to initialization. Discretizing the action space or use Beta distribution helps avoid failure mode 1\u0026amp;3 associated with Gaussian policy. Using KL regularization (same motivation as in TRPO) as an alternative surrogate model helps resolve failure mode 1\u0026amp;2.\nPPG [paper|code]\nSharing parameters between policy and value networks have pros and cons. It allows policy and value functions to share the learned features with each other, but it may cause conflicts between competing objectives and demands the same data for training two networks at the same time. Phasic policy gradient (PPG; Cobbe, et al 2020) modifies the traditional on-policy actor-critic policy gradient algorithm. precisely PPO, to have separate training phases for policy and value functions. In two alternating phases:\n The policy phase: updates the policy network by optimizing the PPO objective $L^\\text{CLIP} (\\theta)$; The auxiliary phase: optimizes an auxiliary objective alongside a behavioral cloning loss. In the paper, value function error is the sole auxiliary objective, but it can be quite general and includes any other additional auxiliary losses. $$ \\begin{aligned} L^\\text{joint} \u0026= L^\\text{aux} + \\beta_\\text{clone} \\cdot \\mathbb{E}_t[\\text{KL}[\\pi_{\\theta_\\text{old}}(\\cdot\\mid s_t), \\pi_\\theta(\\cdot\\mid s_t)]] \\\\ L^\\text{aux} \u0026= L^\\text{value} = \\mathbb{E}_t \\big[\\frac{1}{2}\\big( V_w(s_t) - \\hat{V}_t^\\text{targ} \\big)^2\\big] \\end{aligned} $$ where $\\beta_\\text{clone}$ is a hyperparameter for controlling how much we would like to keep the policy not diverge too much from its original behavior while optimizing the auxiliary objectives.\nFig. 6. The algorithm of PPG. (Image source: Cobbe, et al 2020) where\n $N_\\pi$ is the number of policy update iterations in the policy phase. Note that the policy phase performs multiple iterations of updates per single auxiliary phase. $E_\\pi$ and $E_V$ control the sample reuse (i.e. the number of training epochs performed across data in the reply buffer) for the policy and value functions, respectively. Note that this happens within the policy phase and thus $E_V$ affects the learning of true value function not the auxiliary value function. $E_\\text{aux}$ defines the sample reuse in the auxiliary phrase. In PPG, value function optimization can tolerate a much higher level sample reuse; for example, in the experiments of the paper, $E_\\text{aux} = 6$ while $E_\\pi = E_V = 1$. PPG leads to a significant improvement on sample efficiency compared to PPO.\nFig. 7. The mean normalized performance of PPG vs PPO on the Procgen benchmark. (Image source: Cobbe, et al 2020) ACER [paper|code]\nACER, short for actor-critic with experience replay (Wang, et al., 2017), is an off-policy actor-critic model with experience replay, greatly increasing the sample efficiency and decreasing the data correlation. A3C builds up the foundation for ACER, but it is on policy; ACER is A3C\u0026rsquo;s off-policy counterpart. The major obstacle to making A3C off policy is how to control the stability of the off-policy estimator. ACER proposes three designs to overcome it:\n Use Retrace Q-value estimation; Truncate the importance weights with bias correction; Apply efficient TRPO. Retrace Q-value Estimation\nRetrace is an off-policy return-based Q-value estimation algorithm with a nice guarantee for convergence for any target and behavior policy pair $(\\pi, \\beta)$, plus good data efficiency.\nRecall how TD learning works for prediction:\n Compute TD error: $\\delta_t = R_t + \\gamma \\mathbb{E}_{a \\sim \\pi} Q(S_{t+1}, a) - Q(S_t, A_t)$; the term $r_t + \\gamma \\mathbb{E}_{a \\sim \\pi} Q(s_{t+1}, a) $ is known as “TD target”. The expectation $\\mathbb{E}_{a \\sim \\pi}$ is used because for the future step the best estimation we can make is what the return would be if we follow the current policy $\\pi$. Update the value by correcting the error to move toward the goal: $Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha \\delta_t$. In other words, the incremental update on Q is proportional to the TD error: $\\Delta Q(S_t, A_t) = \\alpha \\delta_t$. When the rollout is off policy, we need to apply importance sampling on the Q update:\n $$ \\Delta Q^\\text{imp}(S_t, A_t) = \\gamma^t \\prod_{1 \\leq \\tau \\leq t} \\frac{\\pi(A_\\tau \\vert S_\\tau)}{\\beta(A_\\tau \\vert S_\\tau)} \\delta_t $$ The product of importance weights looks pretty scary when we start imagining how it can cause super high variance and even explode. Retrace Q-value estimation method modifies $\\Delta Q$ to have importance weights truncated by no more than a constant $c$:\n $$ \\Delta Q^\\text{ret}(S_t, A_t) = \\gamma^t \\prod_{1 \\leq \\tau \\leq t} \\min(c, \\frac{\\pi(A_\\tau \\vert S_\\tau)}{\\beta(A_\\tau \\vert S_\\tau)}) \\delta_t $$ ACER uses $Q^\\text{ret}$ as the target to train the critic by minimizing the L2 error term: $(Q^\\text{ret}(s, a) - Q(s, a))^2$.\nImportance weights truncation\nTo reduce the high variance of the policy gradient $\\hat{g}$, ACER truncates the importance weights by a constant c, plus a correction term. The label $\\hat{g}_t^\\text{acer}$ is the ACER policy gradient at time t.\n $$ \\begin{aligned} \\hat{g}_t^\\text{acer} = \u0026 \\omega_t \\big( Q^\\text{ret}(S_t, A_t) - V_{\\theta_v}(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t) \u0026 \\scriptstyle{\\text{; Let }\\omega_t=\\frac{\\pi(A_t \\vert S_t)}{\\beta(A_t \\vert S_t)}} \\\\ = \u0026 \\color{blue}{\\min(c, \\omega_t) \\big( Q^\\text{ret}(S_t, A_t) - V_w(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)} \\\\ \u0026 + \\color{red}{\\mathbb{E}_{a \\sim \\pi} \\big[ \\max(0, \\frac{\\omega_t(a) - c}{\\omega_t(a)}) \\big( Q_w(S_t, a) - V_w(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert S_t) \\big]} \u0026 \\scriptstyle{\\text{; Let }\\omega_t (a) =\\frac{\\pi(a \\vert S_t)}{\\beta(a \\vert S_t)}} \\end{aligned} $$ where $Q_w(.)$ and $V_w(.)$ are value functions predicted by the critic with parameter w. The first term (blue) contains the clipped important weight. The clipping helps reduce the variance, in addition to subtracting state value function $V_w(.)$ as a baseline. The second term (red) makes a correction to achieve unbiased estimation.\nEfficient TRPO\nFurthermore, ACER adopts the idea of TRPO but with a small adjustment to make it more computationally efficient: rather than measuring the KL divergence between policies before and after one update, ACER maintains a running average of past policies and forces the updated policy to not deviate far from this average.\nThe ACER paper is pretty dense with many equations. Hopefully, with the prior knowledge on TD learning, Q-learning, importance sampling and TRPO, you will find the paper slightly easier to follow :)\nACTKR [paper|code]\nACKTR (actor-critic using Kronecker-factored trust region) (Yuhuai Wu, et al., 2017) proposed to use Kronecker-factored approximation curvature (K-FAC) to do the gradient update for both the critic and actor. K-FAC made an improvement on the computation of natural gradient, which is quite different from our standard gradient. Here is a nice, intuitive explanation of natural gradient. One sentence summary is probably:\n “we first consider all combinations of parameters that result in a new network a constant KL divergence away from the old network. This constant value can be viewed as the step size or learning rate. Out of all these possible combinations, we choose the one that minimizes our loss function.”\n I listed ACTKR here mainly for the completeness of this post, but I would not dive into details, as it involves a lot of theoretical knowledge on natural gradient and optimization methods. If interested, check these papers/posts, before reading the ACKTR paper:\n Amari. Natural Gradient Works Efficiently in Learning. 1998 Kakade. A Natural Policy Gradient. 2002 A intuitive explanation of natural gradient descent Wiki: Kronecker product Martens \u0026amp; Grosse. Optimizing neural networks with kronecker-factored approximate curvature. 2015. Here is a high level summary from the K-FAC paper:\n \u0026ldquo;This approximation is built in two stages. In the first, the rows and columns of the Fisher are divided into groups, each of which corresponds to all the weights in a given layer, and this gives rise to a block-partitioning of the matrix. These blocks are then approximated as Kronecker products between much smaller matrices, which we show is equivalent to making certain approximating assumptions regarding the statistics of the network\u0026rsquo;s gradients.\n In the second stage, this matrix is further approximated as having an inverse which is either block-diagonal or block-tridiagonal. We justify this approximation through a careful examination of the relationships between inverse covariances, tree-structured graphical models, and linear regression. Notably, this justification doesn\u0026rsquo;t apply to the Fisher itself, and our experiments confirm that while the inverse Fisher does indeed possess this structure (approximately), the Fisher itself does not.\u0026rdquo;\n SAC [paper|code]\nSoft Actor-Critic (SAC) (Haarnoja et al. 2018) incorporates the entropy measure of the policy into the reward to encourage exploration: we expect to learn a policy that acts as randomly as possible while it is still able to succeed at the task. It is an off-policy actor-critic model following the maximum entropy reinforcement learning framework. A precedent work is Soft Q-learning.\nThree key components in SAC:\n An actor-critic architecture with separate policy and value function networks; An off-policy formulation that enables reuse of previously collected data for efficiency; Entropy maximization to enable stability and exploration. The policy is trained with the objective to maximize the expected return and the entropy at the same time:\n $$ J(\\theta) = \\sum_{t=1}^T \\mathbb{E}_{(s_t, a_t) \\sim \\rho_{\\pi_\\theta}} [r(s_t, a_t) + \\alpha \\mathcal{H}(\\pi_\\theta(.\\vert s_t))] $$ where $\\mathcal{H}(.)$ is the entropy measure and $\\alpha$ controls how important the entropy term is, known as temperature parameter. The entropy maximization leads to policies that can (1) explore more and (2) capture multiple modes of near-optimal strategies (i.e., if there exist multiple options that seem to be equally good, the policy should assign each with an equal probability to be chosen).\nPrecisely, SAC aims to learn three functions:\n The policy with parameter $\\theta$, $\\pi_\\theta$. Soft Q-value function parameterized by $w$, $Q_w$. Soft state value function parameterized by $\\psi$, $V_\\psi$; theoretically we can infer $V$ by knowing $Q$ and $\\pi$, but in practice, it helps stabilize the training. Soft Q-value and soft state value are defined as:\n $$ \\begin{aligned} Q(s_t, a_t) \u0026= r(s_t, a_t) + \\gamma \\mathbb{E}_{s_{t+1} \\sim \\rho_{\\pi}(s)} [V(s_{t+1})] \u0026 \\text{; according to Bellman equation.}\\\\ \\text{where }V(s_t) \u0026= \\mathbb{E}_{a_t \\sim \\pi} [Q(s_t, a_t) - \\alpha \\log \\pi(a_t \\vert s_t)] \u0026 \\text{; soft state value function.} \\end{aligned} $$ $$ \\text{Thus, } Q(s_t, a_t) = r(s_t, a_t) + \\gamma \\mathbb{E}_{(s_{t+1}, a_{t+1}) \\sim \\rho_{\\pi}} [Q(s_{t+1}, a_{t+1}) - \\alpha \\log \\pi(a_{t+1} \\vert s_{t+1})] $$ $\\rho_\\pi(s)$ and $\\rho_\\pi(s, a)$ denote the state and the state-action marginals of the state distribution induced by the policy $\\pi(a \\vert s)$; see the similar definitions in DPG section.\nThe soft state value function is trained to minimize the mean squared error:\n $$ \\begin{aligned} J_V(\\psi) \u0026= \\mathbb{E}_{s_t \\sim \\mathcal{D}} [\\frac{1}{2} \\big(V_\\psi(s_t) - \\mathbb{E}[Q_w(s_t, a_t) - \\log \\pi_\\theta(a_t \\vert s_t)] \\big)^2] \\\\ \\text{with gradient: }\\nabla_\\psi J_V(\\psi) \u0026= \\nabla_\\psi V_\\psi(s_t)\\big( V_\\psi(s_t) - Q_w(s_t, a_t) + \\log \\pi_\\theta (a_t \\vert s_t) \\big) \\end{aligned} $$ where $\\mathcal{D}$ is the replay buffer.\nThe soft Q function is trained to minimize the soft Bellman residual:\n $$ \\begin{aligned} J_Q(w) \u0026= \\mathbb{E}_{(s_t, a_t) \\sim \\mathcal{D}} [\\frac{1}{2}\\big( Q_w(s_t, a_t) - (r(s_t, a_t) + \\gamma \\mathbb{E}_{s_{t+1} \\sim \\rho_\\pi(s)}[V_{\\bar{\\psi}}(s_{t+1})]) \\big)^2] \\\\ \\text{with gradient: } \\nabla_w J_Q(w) \u0026= \\nabla_w Q_w(s_t, a_t) \\big( Q_w(s_t, a_t) - r(s_t, a_t) - \\gamma V_{\\bar{\\psi}}(s_{t+1})\\big) \\end{aligned} $$ where $\\bar{\\psi}$ is the target value function which is the exponential moving average (or only gets updated periodically in a “hard” way), just like how the parameter of the target Q network is treated in DQN to stabilize the training.\nSAC updates the policy to minimize the KL-divergence:\n $$ \\begin{aligned} \\pi_\\text{new} \u0026= \\arg\\min_{\\pi' \\in \\Pi} D_\\text{KL} \\Big( \\pi'(.\\vert s_t) \\| \\frac{\\exp(Q^{\\pi_\\text{old}}(s_t, .))}{Z^{\\pi_\\text{old}}(s_t)} \\Big) \\\\[6pt] \u0026= \\arg\\min_{\\pi' \\in \\Pi} D_\\text{KL} \\big( \\pi'(.\\vert s_t) \\| \\exp(Q^{\\pi_\\text{old}}(s_t, .) - \\log Z^{\\pi_\\text{old}}(s_t)) \\big) \\\\[6pt] \\text{objective for update: } J_\\pi(\\theta) \u0026= \\nabla_\\theta D_\\text{KL} \\big( \\pi_\\theta(. \\vert s_t) \\| \\exp(Q_w(s_t, .) - \\log Z_w(s_t)) \\big) \\\\[6pt] \u0026= \\mathbb{E}_{a_t\\sim\\pi} \\Big[ - \\log \\big( \\frac{\\exp(Q_w(s_t, a_t) - \\log Z_w(s_t))}{\\pi_\\theta(a_t \\vert s_t)} \\big) \\Big] \\\\[6pt] \u0026= \\mathbb{E}_{a_t\\sim\\pi} [ \\log \\pi_\\theta(a_t \\vert s_t) - Q_w(s_t, a_t) + \\log Z_w(s_t) ] \\end{aligned} $$ where $\\Pi$ is the set of potential policies that we can model our policy as to keep them tractable; for example, $\\Pi$ can be the family of Gaussian mixture distributions, expensive to model but highly expressive and still tractable. $Z^{\\pi_\\text{old}}(s_t)$ is the partition function to normalize the distribution. It is usually intractable but does not contribute to the gradient. How to minimize $J_\\pi(\\theta)$ depends our choice of $\\Pi$.\nThis update guarantees that $Q^{\\pi_\\text{new}}(s_t, a_t) \\geq Q^{\\pi_\\text{old}}(s_t, a_t)$, please check the proof on this lemma in the Appendix B.2 in the original paper.\nOnce we have defined the objective functions and gradients for soft action-state value, soft state value and the policy network, the soft actor-critic algorithm is straightforward:\nFig. 8. The soft actor-critic algorithm. (Image source: original paper) SAC with Automatically Adjusted Temperature [paper|code]\nSAC is brittle with respect to the temperature parameter. Unfortunately it is difficult to adjust temperature, because the entropy can vary unpredictably both across tasks and during training as the policy becomes better. An improvement on SAC formulates a constrained optimization problem: while maximizing the expected return, the policy should satisfy a minimum entropy constraint:\n $$ \\max_{\\pi_0, \\dots, \\pi_T} \\mathbb{E} \\Big[ \\sum_{t=0}^T r(s_t, a_t)\\Big] \\text{s.t. } \\forall t\\text{, } \\mathcal{H}(\\pi_t) \\geq \\mathcal{H}_0 $$ where $\\mathcal{H}_0$ is a predefined minimum policy entropy threshold.\nThe expected return $\\mathbb{E} \\Big[ \\sum_{t=0}^T r(s_t, a_t)\\Big]$ can be decomposed into a sum of rewards at all the time steps. Because the policy $\\pi_t$ at time t has no effect on the policy at the earlier time step, $\\pi_{t-1}$, we can maximize the return at different steps backward in time \u0026mdash; this is essentially DP.\n $$ \\underbrace{\\max_{\\pi_0} \\Big( \\mathbb{E}[r(s_0, a_0)]+ \\underbrace{\\max_{\\pi_1} \\Big(\\mathbb{E}[...] + \\underbrace{\\max_{\\pi_T} \\mathbb{E}[r(s_T, a_T)]}_\\text{1st maximization} \\Big)}_\\text{second but last maximization} \\Big)}_\\text{last maximization} $$ where we consider $\\gamma=1$.\nSo we start the optimization from the last timestep $T$:\n $$ \\text{maximize } \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ] \\text{ s.t. } \\mathcal{H}(\\pi_T) - \\mathcal{H}_0 \\geq 0 $$ First, let us define the following functions:\n $$ \\begin{aligned} h(\\pi_T) \u0026= \\mathcal{H}(\\pi_T) - \\mathcal{H}_0 = \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [-\\log \\pi_T(a_T\\vert s_T)] - \\mathcal{H}_0\\\\ f(\\pi_T) \u0026= \\begin{cases} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ], \u0026 \\text{if }h(\\pi_T) \\geq 0 \\\\ -\\infty, \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ And the optimization becomes:\n $$ \\text{maximize } f(\\pi_T) \\text{ s.t. } h(\\pi_T) \\geq 0 $$ To solve the maximization optimization with inequality constraint, we can construct a Lagrangian expression with a Lagrange multiplier (also known as \u0026ldquo;dual variable\u0026rdquo;), $\\alpha_T$:\n $$ L(\\pi_T, \\alpha_T) = f(\\pi_T) + \\alpha_T h(\\pi_T) $$ Considering the case when we try to minimize $L(\\pi_T, \\alpha_T)$ with respect to $\\alpha_T$ - given a particular value $\\pi_T$,\n If the constraint is satisfied, $h(\\pi_T) \\geq 0$, at best we can set $\\alpha_T=0$ since we have no control over the value of $f(\\pi_T)$. Thus, $L(\\pi_T, 0) = f(\\pi_T)$. If the constraint is invalidated, $h(\\pi_T) \u0026lt; 0$, we can achieve $L(\\pi_T, \\alpha_T) \\to -\\infty$ by taking $\\alpha_T \\to \\infty$. Thus, $L(\\pi_T, \\infty) = -\\infty = f(\\pi_T)$. In either case, we can recover the following equation,\n $$ f(\\pi_T) = \\min_{\\alpha_T \\geq 0} L(\\pi_T, \\alpha_T) $$ At the same time, we want to maximize $f(\\pi_T)$,\n $$ \\max_{\\pi_T} f(\\pi_T) = \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} L(\\pi_T, \\alpha_T) $$ Therefore, to maximize $f(\\pi_T)$, the dual problem is listed as below. Note that to make sure $\\max_{\\pi_T} f(\\pi_T)$ is properly maximized and would not become $-\\infty$, the constraint has to be satisfied.\n $$ \\begin{aligned} \\max_{\\pi_T} \\mathbb{E}[ r(s_T, a_T) ] \u0026= \\max_{\\pi_T} f(\\pi_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} L(\\pi_T, \\alpha_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} f(\\pi_T) + \\alpha_T h(\\pi_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ] + \\alpha_T ( \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [-\\log \\pi_T(a_T\\vert s_T)] - \\mathcal{H}_0) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) - \\alpha_T \\log \\pi_T(a_T\\vert s_T)] - \\alpha_T \\mathcal{H}_0 \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) + \\alpha_T \\mathcal{H}(\\pi_T) - \\alpha_T \\mathcal{H}_0 ] \\end{aligned} $$ We could compute the optimal $\\pi_T$ and $\\alpha_T$ iteratively. First given the current $\\alpha_T$, get the best policy $\\pi_T^{*}$ that maximizes $L(\\pi_T^{*}, \\alpha_T)$. Then plug in $\\pi_T^{*}$ and compute $\\alpha_T^{*}$ that minimizes $L(\\pi_T^{*}, \\alpha_T)$. Assuming we have one neural network for policy and one network for temperature parameter, the iterative update process is more aligned with how we update network parameters during training.\n $$ \\begin{aligned} \\pi^{*}_T \u0026= \\arg\\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) + \\alpha_T \\mathcal{H}(\\pi_T) - \\alpha_T \\mathcal{H}_0 ] \\\\ \\color{blue}{\\alpha^{*}_T} \u0026\\color{blue}{=} \\color{blue}{\\arg\\min_{\\alpha_T \\geq 0} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi^{*}}} [\\alpha_T \\mathcal{H}(\\pi^{*}_T) - \\alpha_T \\mathcal{H}_0 ]} \\end{aligned} $$ $$ \\text{Thus, }\\max_{\\pi_T} \\mathbb{E} [ r(s_T, a_T) ] = \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi^{*}}} [ r(s_T, a_T) + \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) - \\alpha^{*}_T \\mathcal{H}_0 ] $$ Now let\u0026rsquo;s go back to the soft Q value function:\n $$ \\begin{aligned} Q_{T-1}(s_{T-1}, a_{T-1}) \u0026= r(s_{T-1}, a_{T-1}) + \\mathbb{E} [Q(s_T, a_T) - \\alpha_T \\log \\pi(a_T \\vert s_T)] \\\\ \u0026= r(s_{T-1}, a_{T-1}) + \\mathbb{E} [r(s_T, a_T)] + \\alpha_T \\mathcal{H}(\\pi_T) \\\\ Q_{T-1}^{*}(s_{T-1}, a_{T-1}) \u0026= r(s_{T-1}, a_{T-1}) + \\max_{\\pi_T} \\mathbb{E} [r(s_T, a_T)] + \\alpha_T \\mathcal{H}(\\pi^{*}_T) \u0026 \\text{; plug in the optimal }\\pi_T^{*} \\end{aligned} $$ Therefore the expected return is as follows, when we take one step further back to the time step $T-1$:\n $$ \\begin{aligned} \u0026\\max_{\\pi_{T-1}}\\Big(\\mathbb{E}[r(s_{T-1}, a_{T-1})] + \\max_{\\pi_T} \\mathbb{E}[r(s_T, a_T] \\Big) \\\\ \u0026= \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) \\Big) \u0026 \\text{; should s.t. } \\mathcal{H}(\\pi_{T-1}) - \\mathcal{H}_0 \\geq 0 \\\\ \u0026= \\min_{\\alpha_{T-1} \\geq 0} \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) + \\alpha_{T-1} \\big( \\mathcal{H}(\\pi_{T-1}) - \\mathcal{H}_0 \\big) \\Big) \u0026 \\text{; dual problem w/ Lagrangian.} \\\\ \u0026= \\min_{\\alpha_{T-1} \\geq 0} \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) + \\alpha_{T-1} \\mathcal{H}(\\pi_{T-1}) - \\alpha_{T-1}\\mathcal{H}_0 \\Big) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) \\end{aligned} $$ Similar to the previous step,\n $$ \\begin{aligned} \\pi^{*}_{T-1} \u0026= \\arg\\max_{\\pi_{T-1}} \\mathbb{E}_{(s_{T-1}, a_{T-1}) \\sim \\rho_\\pi} [Q^{*}_{T-1}(s_{T-1}, a_{T-1}) + \\alpha_{T-1} \\mathcal{H}(\\pi_{T-1}) - \\alpha_{T-1} \\mathcal{H}_0 ] \\\\ \\color{green}{\\alpha^{*}_{T-1}} \u0026\\color{green}{=} \\color{green}{\\arg\\min_{\\alpha_{T-1} \\geq 0} \\mathbb{E}_{(s_{T-1}, a_{T-1}) \\sim \\rho_{\\pi^{*}}} [ \\alpha_{T-1} \\mathcal{H}(\\pi^{*}_{T-1}) - \\alpha_{T-1}\\mathcal{H}_0 ]} \\end{aligned} $$ The equation for updating $\\alpha_{T-1}$ in green has the same format as the equation for updating $\\alpha_{T-1}$ in blue above. By repeating this process, we can learn the optimal temperature parameter in every step by minimizing the same objective function:\n $$ J(\\alpha) = \\mathbb{E}_{a_t \\sim \\pi_t} [-\\alpha \\log \\pi_t(a_t \\mid s_t) - \\alpha \\mathcal{H}_0] $$ The final algorithm is same as SAC except for learning $\\alpha$ explicitly with respect to the objective $J(\\alpha)$ (see Fig. 7):\nFig. 9. The soft actor-critic algorithm with automatically adjusted temperature. (Image source: original paper) TD3 [paper|code]\nThe Q-learning algorithm is commonly known to suffer from the overestimation of the value function. This overestimation can propagate through the training iterations and negatively affect the policy. This property directly motivated Double Q-learning and Double DQN: the action selection and Q-value update are decoupled by using two value networks.\nTwin Delayed Deep Deterministic (short for TD3; Fujimoto et al., 2018) applied a couple of tricks on DDPG to prevent the overestimation of the value function:\n(1) Clipped Double Q-learning: In Double Q-Learning, the action selection and Q-value estimation are made by two networks separately. In the DDPG setting, given two deterministic actors $(\\mu_{\\theta_1}, \\mu_{\\theta_2})$ with two corresponding critics $(Q_{w_1}, Q_{w_2})$, the Double Q-learning Bellman targets look like:\n $$ \\begin{aligned} y_1 \u0026= r + \\gamma Q_{w_2}(s', \\mu_{\\theta_1}(s'))\\\\ y_2 \u0026= r + \\gamma Q_{w_1}(s', \\mu_{\\theta_2}(s')) \\end{aligned} $$ However, due to the slow changing policy, these two networks could be too similar to make independent decisions. The Clipped Double Q-learning instead uses the minimum estimation among two so as to favor underestimation bias which is hard to propagate through training:\n $$ \\begin{aligned} y_1 \u0026= r + \\gamma \\min_{i=1,2}Q_{w_i}(s', \\mu_{\\theta_1}(s'))\\\\ y_2 \u0026= r + \\gamma \\min_{i=1,2} Q_{w_i}(s', \\mu_{\\theta_2}(s')) \\end{aligned} $$ (2) Delayed update of Target and Policy Networks: In the actor-critic model, policy and value updates are deeply coupled: Value estimates diverge through overestimation when the policy is poor, and the policy will become poor if the value estimate itself is inaccurate.\nTo reduce the variance, TD3 updates the policy at a lower frequency than the Q-function. The policy network stays the same until the value error is small enough after several updates. The idea is similar to how the periodically-updated target network stay as a stable objective in DQN.\n(3) Target Policy Smoothing: Given a concern with deterministic policies that they can overfit to narrow peaks in the value function, TD3 introduced a smoothing regularization strategy on the value function: adding a small amount of clipped random noises to the selected action and averaging over mini-batches.\n $$ \\begin{aligned} y \u0026= r + \\gamma Q_w (s', \\mu_{\\theta}(s') + \\epsilon) \u0026 \\\\ \\epsilon \u0026\\sim \\text{clip}(\\mathcal{N}(0, \\sigma), -c, +c) \u0026 \\scriptstyle{\\text{ ; clipped random noises.}} \\end{aligned} $$ This approach mimics the idea of SARSA update and enforces that similar actions should have similar values.\nHere is the final algorithm:\nFig. 10. TD3 Algorithm. (Image source: Fujimoto et al., 2018) SVPG [paper|code for SVPG]\nStein Variational Policy Gradient (SVPG; Liu et al, 2017) applies the Stein variational gradient descent (SVGD; Liu and Wang, 2016) algorithm to update the policy parameter $\\theta$.\nIn the setup of maximum entropy policy optimization, $\\theta$ is considered as a random variable $\\theta \\sim q(\\theta)$ and the model is expected to learn this distribution $q(\\theta)$. Assuming we know a prior on how $q$ might look like, $q_0$, and we would like to guide the learning process to not make $\\theta$ too far away from $q_0$ by optimizing the following objective function:\n $$ \\hat{J}(\\theta) = \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) $$ where $\\mathbb{E}_{\\theta \\sim q} [R(\\theta)]$ is the expected reward when $\\theta \\sim q(\\theta)$ and $D_\\text{KL}$ is the KL divergence.\nIf we don\u0026rsquo;t have any prior information, we might set $q_0$ as a uniform distribution and set $q_0(\\theta)$ to a constant. Then the above objective function becomes SAC, where the entropy term encourages exploration:\n $$ \\begin{aligned} \\hat{J}(\\theta) \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) \\\\ \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha \\mathbb{E}_{\\theta \\sim q} [\\log q(\\theta) - \\log q_0(\\theta)] \\\\ \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] + \\alpha H(q(\\theta)) \\end{aligned} $$ Let\u0026rsquo;s take the derivative of $\\hat{J}(\\theta) = \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q|q_0)$ w.r.t. $q$:\n $$ \\begin{aligned} \\nabla_q \\hat{J}(\\theta) \u0026= \\nabla_q \\big( \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) \\big) \\\\ \u0026= \\nabla_q \\int_\\theta \\big( q(\\theta) J(\\theta) - \\alpha q(\\theta)\\log q(\\theta) + \\alpha q(\\theta) \\log q_0(\\theta) \\big) \\\\ \u0026= \\int_\\theta \\big( J(\\theta) - \\alpha \\log q(\\theta) -\\alpha + \\alpha \\log q_0(\\theta) \\big) \\\\ \u0026= 0 \\end{aligned} $$ The optimal distribution is:\n $$ \\log q^{*}(\\theta) = \\frac{1}{\\alpha} J(\\theta) + \\log q_0(\\theta) - 1 \\text{ thus } \\underbrace{ q^{*}(\\theta) }_\\textrm{\"posterior\"} \\propto \\underbrace{\\exp ( J(\\theta) / \\alpha )}_\\textrm{\"likelihood\"} \\underbrace{q_0(\\theta)}_\\textrm{prior} $$ The temperature $\\alpha$ decides a tradeoff between exploitation and exploration. When $\\alpha \\rightarrow 0$, $\\theta$ is updated only according to the expected return $J(\\theta)$. When $\\alpha \\rightarrow \\infty$, $\\theta$ always follows the prior belief.\nWhen using the SVGD method to estimate the target posterior distribution $q(\\theta)$, it relies on a set of particle $\\{\\theta_i\\}_{i=1}^n$ (independently trained policy agents) and each is updated:\n $$ \\theta_i \\gets \\theta_i + \\epsilon \\phi^{*}(\\theta_i) \\text{ where } \\phi^{*} = \\max_{\\phi \\in \\mathcal{H}} \\{ - \\nabla_\\epsilon D_\\text{KL} (q'_{[\\theta + \\epsilon \\phi(\\theta)]} \\| q) \\text{ s.t. } \\|\\phi\\|_{\\mathcal{H}} \\leq 1\\} $$ where $\\epsilon$ is a learning rate and $\\phi^{*}$ is the unit ball of a RKHS (reproducing kernel Hilbert space) $\\mathcal{H}$ of $\\theta$-shaped value vectors that maximally decreases the KL divergence between the particles and the target distribution. $q'(.)$ is the distribution of $\\theta + \\epsilon \\phi(\\theta)$.\nComparing different gradient-based update methods:\n Method Update space Plain gradient $\\Delta \\theta$ on the parameter space Natural gradient $\\Delta \\theta$ on the search distribution space SVGD $\\Delta \\theta$ on the kernel function space (edited) One estimation of $\\phi^{*}$ has the following form. A positive definite kernel $k(\\vartheta, \\theta)$, i.e. a Gaussian radial basis function, measures the similarity between particles.\n $$ \\begin{aligned} \\phi^{*}(\\theta_i) \u0026= \\mathbb{E}_{\\vartheta \\sim q'} [\\nabla_\\vartheta \\log q(\\vartheta) k(\\vartheta, \\theta_i) + \\nabla_\\vartheta k(\\vartheta, \\theta_i)]\\\\ \u0026= \\frac{1}{n} \\sum_{j=1}^n [\\color{red}{\\nabla_{\\theta_j} \\log q(\\theta_j) k(\\theta_j, \\theta_i)} + \\color{green}{\\nabla_{\\theta_j} k(\\theta_j, \\theta_i)}] \u0026 \\scriptstyle{\\text{;approximate }q'\\text{ with current particle values}} \\end{aligned} $$ The first term in red encourages $\\theta_i$ learning towards the high probability regions of $q$ that is shared across similar particles. =\u0026gt; to be similar to other particles The second term in green pushes particles away from each other and therefore diversifies the policy. =\u0026gt; to be dissimilar to other particles Usually the temperature $\\alpha$ follows an annealing scheme so that the training process does more exploration at the beginning but more exploitation at a later stage.\nIMPALA [paper|code]\nIn order to scale up RL training to achieve a very high throughput, IMPALA (\u0026ldquo;Importance Weighted Actor-Learner Architecture\u0026rdquo;) framework decouples acting from learning on top of basic actor-critic setup and learns from all experience trajectories with V-trace off-policy correction.\nMultiple actors generate experience in parallel, while the learner optimizes both policy and value function parameters using all the generated experience. Actors update their parameters with the latest policy from the learner periodically. Because acting and learning are decoupled, we can add many more actor machines to generate a lot more trajectories per time unit. As the training policy and the behavior policy are not totally synchronized, there is a gap between them and thus we need off-policy corrections.\nLet the value function $V_\\theta$ parameterized by $\\theta$ and the policy $\\pi_\\phi$ parameterized by $\\phi$. Also we know the trajectories in the replay buffer are collected by a slightly older policy $\\mu$.\nAt the training time $t$, given $(s_t, a_t, s_{t+1}, r_t)$, the value function parameter $\\theta$ is learned through an L2 loss between the current value and a V-trace value target. The $n$-step V-trace target is defined as:\n $$ \\begin{aligned} v_t \u0026= V_\\theta(s_t) + \\sum_{i=t}^{t+n-1} \\gamma^{i-t} \\big(\\prod_{j=t}^{i-1} c_j\\big) \\color{red}{\\delta_i V} \\\\ \u0026= V_\\theta(s_t) + \\sum_{i=t}^{t+n-1} \\gamma^{i-t} \\big(\\prod_{j=t}^{i-1} c_j\\big) \\color{red}{\\rho_i (r_i + \\gamma V_\\theta(s_{i+1}) - V_\\theta(s_i))} \\end{aligned} $$ where the red part $\\delta_i V$ is a temporal difference for $V$. $\\rho_i = \\min\\big(\\bar{\\rho}, \\frac{\\pi(a_i \\vert s_i)}{\\mu(a_i \\vert s_i)}\\big)$ and $c_j = \\min\\big(\\bar{c}, \\frac{\\pi(a_j \\vert s_j)}{\\mu(a_j \\vert s_j)}\\big)$ are truncated importance sampling (IS) weights. The product of $c_t, \\dots, c_{i-1}$ measures how much a temporal difference $\\delta_i V$ observed at time $i$ impacts the update of the value function at a previous time $t$. In the on-policy case, we have $\\rho_i=1$ and $c_j=1$ (assuming $\\bar{c} \\geq 1$) and therefore the V-trace target becomes on-policy $n$-step Bellman target.\n$\\bar{\\rho}$ and $\\bar{c}$ are two truncation constants with $\\bar{\\rho} \\geq \\bar{c}$. $\\bar{\\rho}$ impacts the fixed-point of the value function we converge to and $\\bar{c}$ impacts the speed of convergence. When $\\bar{\\rho} =\\infty$ (untruncated), we converge to the value function of the target policy $V^\\pi$; when $\\bar{\\rho}$ is close to 0, we evaluate the value function of the behavior policy $V^\\mu$; when in-between, we evaluate a policy between $\\pi$ and $\\mu$.\nThe value function parameter is therefore updated in the direction of:\n $$ \\Delta\\theta = (v_t - V_\\theta(s_t))\\nabla_\\theta V_\\theta(s_t) $$ The policy parameter $\\phi$ is updated through policy gradient,\n $$ \\begin{aligned} \\Delta \\phi \u0026= \\rho_t \\nabla_\\phi \\log \\pi_\\phi(a_t \\vert s_t) \\big(r_t + \\gamma v_{t+1} - V_\\theta(s_t)\\big) + \\nabla_\\phi H(\\pi_\\phi)\\\\ \u0026= \\rho_t \\nabla_\\phi \\log \\pi_\\phi(a_t \\vert s_t) \\big(r_t + \\gamma v_{t+1} - V_\\theta(s_t)\\big) - \\nabla_\\phi \\sum_a \\pi_\\phi(a\\vert s_t)\\log \\pi_\\phi(a\\vert s_t) \\end{aligned} $$ where $r_t + \\gamma v_{t+1}$ is the estimated Q value, from which a state-dependent baseline $V_\\theta(s_t)$ is subtracted. $H(\\pi_\\phi)$ is an entropy bonus to encourage exploration.\nIn the experiments, IMPALA is used to train one agent over multiple tasks. Two different model architectures are involved, a shallow model (left) and a deep residual model (right).\nQuick Summary After reading through all the algorithms above, I list a few building blocks or principles that seem to be common among them:\n Try to reduce the variance and keep the bias unchanged to stabilize learning. Off-policy gives us better exploration and helps us use data samples more efficiently. Experience replay (training data sampled from a replay memory buffer); Target network that is either frozen periodically or updated slower than the actively learned policy network; Batch normalization; Entropy-regularized reward; The critic and actor can share lower layer parameters of the network and two output heads for policy and value functions. It is possible to learn with deterministic policy rather than stochastic one. Put constraint on the divergence between policy updates. New optimization methods (such as K-FAC). Entropy maximization of the policy helps encourage exploration. Try not to overestimate the value function. Think twice whether the policy and value network should share parameters. TBA more. Cited as:\n@article{weng2018PG, title = \u0026quot;Policy Gradient Algorithms\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-04-08-policy-gradient/\u0026quot; } References [1] jeremykun.com Markov Chain Monte Carlo Without all the Bullshit\n[2] Richard S. Sutton and Andrew G. Barto. Reinforcement Learning: An Introduction; 2nd Edition. 2017.\n[3] John Schulman, et al. \u0026ldquo;High-dimensional continuous control using generalized advantage estimation.\u0026quot; ICLR 2016.\n[4] Thomas Degris, Martha White, and Richard S. Sutton. \u0026ldquo;Off-policy actor-critic.\u0026quot; ICML 2012.\n[5] timvieira.github.io Importance sampling\n[6] Mnih, Volodymyr, et al. \u0026ldquo;Asynchronous methods for deep reinforcement learning.\u0026quot; ICML. 2016.\n[7] David Silver, et al. \u0026ldquo;Deterministic policy gradient algorithms.\u0026quot; ICML. 2014.\n[8] Timothy P. Lillicrap, et al. \u0026ldquo;Continuous control with deep reinforcement learning.\u0026quot; arXiv preprint arXiv:1509.02971 (2015).\n[9] Ryan Lowe, et al. \u0026ldquo;Multi-agent actor-critic for mixed cooperative-competitive environments.\u0026quot; NIPS. 2017.\n[10] John Schulman, et al. \u0026ldquo;Trust region policy optimization.\u0026quot; ICML. 2015.\n[11] Ziyu Wang, et al. \u0026ldquo;Sample efficient actor-critic with experience replay.\u0026quot; ICLR 2017.\n[12] Rémi Munos, Tom Stepleton, Anna Harutyunyan, and Marc Bellemare. \u0026ldquo;Safe and efficient off-policy reinforcement learning\u0026rdquo; NIPS. 2016.\n[13] Yuhuai Wu, et al. \u0026ldquo;Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation.\u0026quot; NIPS. 2017.\n[14] kvfrans.com A intuitive explanation of natural gradient descent\n[15] Sham Kakade. \u0026ldquo;A Natural Policy Gradient.\u0026quot;. NIPS. 2002.\n[16] \u0026ldquo;Going Deeper Into Reinforcement Learning: Fundamentals of Policy Gradients.\u0026quot; - Seita\u0026rsquo;s Place, Mar 2017.\n[17] \u0026ldquo;Notes on the Generalized Advantage Estimation Paper.\u0026quot; - Seita\u0026rsquo;s Place, Apr, 2017.\n[18] Gabriel Barth-Maron, et al. \u0026ldquo;Distributed Distributional Deterministic Policy Gradients.\u0026quot; ICLR 2018 poster.\n[19] Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. \u0026ldquo;Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.\u0026quot; arXiv preprint arXiv:1801.01290 (2018).\n[20] Scott Fujimoto, Herke van Hoof, and Dave Meger. \u0026ldquo;Addressing Function Approximation Error in Actor-Critic Methods.\u0026quot; arXiv preprint arXiv:1802.09477 (2018).\n[21] Tuomas Haarnoja, et al. \u0026ldquo;Soft Actor-Critic Algorithms and Applications.\u0026quot; arXiv preprint arXiv:1812.05905 (2018).\n[22] David Knowles. \u0026ldquo;Lagrangian Duality for Dummies\u0026rdquo; Nov 13, 2010.\n[23] Yang Liu, et al. \u0026ldquo;Stein variational policy gradient.\u0026quot; arXiv preprint arXiv:1704.02399 (2017).\n[24] Qiang Liu and Dilin Wang. \u0026ldquo;Stein variational gradient descent: A general purpose bayesian inference algorithm.\u0026quot; NIPS. 2016.\n[25] Lasse Espeholt, et al. \u0026ldquo;IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures\u0026rdquo; arXiv preprint 1802.01561 (2018).\n[26] Karl Cobbe, et al. \u0026ldquo;Phasic Policy Gradient.\u0026quot; arXiv preprint arXiv:2009.04416 (2020).\n[27] Chloe Ching-Yun Hsu, et al. \u0026ldquo;Revisiting Design Choices in Proximal Policy Optimization.\u0026quot; arXiv preprint arXiv:2009.10897 (2020).\n","permalink":"https://lilianweng.github.io/posts/2018-04-08-policy-gradient/","summary":"[Updated on 2018-06-30: add two new policy gradient methods, SAC and D4PG.] [Updated on 2018-09-30: add a new policy gradient method, TD3.] [Updated on 2019-02-09: add SAC with automatically adjusted temperature]. [Updated on 2019-06-26: Thanks to Chanseok, we have a version of this post in Korean]. [Updated on 2019-09-12: add a new policy gradient method SVPG.] [Updated on 2019-12-22: add a new policy gradient method IMPALA.","title":"Policy Gradient Algorithms"},{"content":"[Updated on 2020-09-03: Updated the algorithm of SARSA and Q-learning so that the difference is more pronounced. [Updated on 2021-09-19: Thanks to 爱吃猫的鱼, we have this post in Chinese].\nA couple of exciting news in Artificial Intelligence (AI) has just happened in recent years. AlphaGo defeated the best professional human player in the game of Go. Very soon the extended algorithm AlphaGo Zero beat AlphaGo by 100-0 without supervised learning on human knowledge. Top professional game players lost to the bot developed by OpenAI on DOTA2 1v1 competition. After knowing these, it is pretty hard not to be curious about the magic behind these algorithms \u0026mdash; Reinforcement Learning (RL). I\u0026rsquo;m writing this post to briefly go over the field. We will first introduce several fundamental concepts and then dive into classic approaches to solving RL problems. Hopefully, this post could be a good starting point for newbies, bridging the future study on the cutting-edge research.\nWhat is Reinforcement Learning? Say, we have an agent in an unknown environment and this agent can obtain some rewards by interacting with the environment. The agent ought to take actions so as to maximize cumulative rewards. In reality, the scenario could be a bot playing a game to achieve high scores, or a robot trying to complete physical tasks with physical items; and not just limited to these.\nFig. 1. An agent interacts with the environment, trying to take smart actions to maximize cumulative rewards. The goal of Reinforcement Learning (RL) is to learn a good strategy for the agent from experimental trials and relative simple feedback received. With the optimal strategy, the agent is capable to actively adapt to the environment to maximize future rewards.\nKey Concepts Now Let\u0026rsquo;s formally define a set of key concepts in RL.\nThe agent is acting in an environment. How the environment reacts to certain actions is defined by a model which we may or may not know. The agent can stay in one of many states ($s \\in \\mathcal{S}$) of the environment, and choose to take one of many actions ($a \\in \\mathcal{A}$) to switch from one state to another. Which state the agent will arrive in is decided by transition probabilities between states ($P$). Once an action is taken, the environment delivers a reward ($r \\in \\mathcal{R}$) as feedback.\nThe model defines the reward function and transition probabilities. We may or may not know how the model works and this differentiate two circumstances:\n Know the model: planning with perfect information; do model-based RL. When we fully know the environment, we can find the optimal solution by Dynamic Programming (DP). Do you still remember \u0026ldquo;longest increasing subsequence\u0026rdquo; or \u0026ldquo;traveling salesmen problem\u0026rdquo; from your Algorithms 101 class? LOL. This is not the focus of this post though. Does not know the model: learning with incomplete information; do model-free RL or try to learn the model explicitly as part of the algorithm. Most of the following content serves the scenarios when the model is unknown. The agent\u0026rsquo;s policy $\\pi(s)$ provides the guideline on what is the optimal action to take in a certain state with the goal to maximize the total rewards. Each state is associated with a value function $V(s)$ predicting the expected amount of future rewards we are able to receive in this state by acting the corresponding policy. In other words, the value function quantifies how good a state is. Both policy and value functions are what we try to learn in reinforcement learning.\nFig. 2. Summary of approaches in RL based on whether we want to model the value, policy, or the environment. (Image source: reproduced from David Silver's RL course lecture 1.) The interaction between the agent and the environment involves a sequence of actions and observed rewards in time, $t=1, 2, \\dots, T$. During the process, the agent accumulates the knowledge about the environment, learns the optimal policy, and makes decisions on which action to take next so as to efficiently learn the best policy. Let\u0026rsquo;s label the state, action, and reward at time step t as $S_t$, $A_t$, and $R_t$, respectively. Thus the interaction sequence is fully described by one episode (also known as \u0026ldquo;trial\u0026rdquo; or \u0026ldquo;trajectory\u0026rdquo;) and the sequence ends at the terminal state $S_T$:\n $$ S_1, A_1, R_2, S_2, A_2, \\dots, S_T $$ Terms you will encounter a lot when diving into different categories of RL algorithms:\n Model-based: Rely on the model of the environment; either the model is known or the algorithm learns it explicitly. Model-free: No dependency on the model during learning. On-policy: Use the deterministic outcomes or samples from the target policy to train the algorithm. Off-policy: Training on a distribution of transitions or episodes produced by a different behavior policy rather than that produced by the target policy. Model: Transition and Reward The model is a descriptor of the environment. With the model, we can learn or infer how the environment would interact with and provide feedback to the agent. The model has two major parts, transition probability function $P$ and reward function $R$.\nLet\u0026rsquo;s say when we are in state s, we decide to take action a to arrive in the next state s' and obtain reward r. This is known as one transition step, represented by a tuple (s, a, s', r).\nThe transition function P records the probability of transitioning from state s to s' after taking action a while obtaining reward r. We use $\\mathbb{P}$ as a symbol of \u0026ldquo;probability\u0026rdquo;.\n $$ P(s', r \\vert s, a) = \\mathbb{P} [S_{t+1} = s', R_{t+1} = r \\vert S_t = s, A_t = a] $$ Thus the state-transition function can be defined as a function of $P(s', r \\vert s, a)$:\n $$ P_{ss'}^a = P(s' \\vert s, a) = \\mathbb{P} [S_{t+1} = s' \\vert S_t = s, A_t = a] = \\sum_{r \\in \\mathcal{R}} P(s', r \\vert s, a) $$ The reward function R predicts the next reward triggered by one action:\n $$ R(s, a) = \\mathbb{E} [R_{t+1} \\vert S_t = s, A_t = a] = \\sum_{r\\in\\mathcal{R}} r \\sum_{s' \\in \\mathcal{S}} P(s', r \\vert s, a) $$ Policy Policy, as the agent\u0026rsquo;s behavior function $\\pi$, tells us which action to take in state s. It is a mapping from state s to action a and can be either deterministic or stochastic:\n Deterministic: $\\pi(s) = a$. Stochastic: $\\pi(a \\vert s) = \\mathbb{P}_\\pi [A=a \\vert S=s]$. Value Function Value function measures the goodness of a state or how rewarding a state or an action is by a prediction of future reward. The future reward, also known as return, is a total sum of discounted rewards going forward. Let\u0026rsquo;s compute the return $G_t$ starting from time t:\n $$ G_t = R_{t+1} + \\gamma R_{t+2} + \\dots = \\sum_{k=0}^{\\infty} \\gamma^k R_{t+k+1} $$ The discounting factor $\\gamma \\in [0, 1]$ penalize the rewards in the future, because:\n The future rewards may have higher uncertainty; i.e. stock market. The future rewards do not provide immediate benefits; i.e. As human beings, we might prefer to have fun today rather than 5 years later ;). Discounting provides mathematical convenience; i.e., we don\u0026rsquo;t need to track future steps forever to compute return. We don\u0026rsquo;t need to worry about the infinite loops in the state transition graph. The state-value of a state s is the expected return if we are in this state at time t, $S_t = s$:\n $$ V_{\\pi}(s) = \\mathbb{E}_{\\pi}[G_t \\vert S_t = s] $$ Similarly, we define the action-value (\u0026ldquo;Q-value\u0026rdquo;; Q as \u0026ldquo;Quality\u0026rdquo; I believe?) of a state-action pair as:\n $$ Q_{\\pi}(s, a) = \\mathbb{E}_{\\pi}[G_t \\vert S_t = s, A_t = a] $$ Additionally, since we follow the target policy $\\pi$, we can make use of the probility distribution over possible actions and the Q-values to recover the state-value:\n $$ V_{\\pi}(s) = \\sum_{a \\in \\mathcal{A}} Q_{\\pi}(s, a) \\pi(a \\vert s) $$ The difference between action-value and state-value is the action advantage function (\u0026ldquo;A-value\u0026rdquo;):\n $$ A_{\\pi}(s, a) = Q_{\\pi}(s, a) - V_{\\pi}(s) $$ Optimal Value and Policy The optimal value function produces the maximum return:\n $$ V_{*}(s) = \\max_{\\pi} V_{\\pi}(s), Q_{*}(s, a) = \\max_{\\pi} Q_{\\pi}(s, a) $$ The optimal policy achieves optimal value functions:\n $$ \\pi_{*} = \\arg\\max_{\\pi} V_{\\pi}(s), \\pi_{*} = \\arg\\max_{\\pi} Q_{\\pi}(s, a) $$ And of course, we have $V_{\\pi_{*}}(s)=V_{*}(s)$ and $Q_{\\pi_{*}}(s, a) = Q_{*}(s, a)$.\nMarkov Decision Processes In more formal terms, almost all the RL problems can be framed as Markov Decision Processes (MDPs). All states in MDP has \u0026ldquo;Markov\u0026rdquo; property, referring to the fact that the future only depends on the current state, not the history:\n $$ \\mathbb{P}[ S_{t+1} \\vert S_t ] = \\mathbb{P} [S_{t+1} \\vert S_1, \\dots, S_t] $$ Or in other words, the future and the past are conditionally independent given the present, as the current state encapsulates all the statistics we need to decide the future.\nFig. 3. The agent-environment interaction in a Markov decision process. (Image source: Sec. 3.1 Sutton \u0026 Barto (2017).) A Markov deicison process consists of five elements $\\mathcal{M} = \\langle \\mathcal{S}, \\mathcal{A}, P, R, \\gamma \\rangle$, where the symbols carry the same meanings as key concepts in the previous section, well aligned with RL problem settings:\n $\\mathcal{S}$ - a set of states; $\\mathcal{A}$ - a set of actions; $P$ - transition probability function; $R$ - reward function; $\\gamma$ - discounting factor for future rewards. In an unknown environment, we do not have perfect knowledge about $P$ and $R$. Fig. 4. A fun example of Markov decision process: a typical work day. (Image source: randomant.net/reinforcement-learning-concepts) Bellman Equations Bellman equations refer to a set of equations that decompose the value function into the immediate reward plus the discounted future values.\n $$ \\begin{aligned} V(s) \u0026= \\mathbb{E}[G_t \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma R_{t+2} + \\gamma^2 R_{t+3} + \\dots \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma (R_{t+2} + \\gamma R_{t+3} + \\dots) \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma G_{t+1} \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma V(S_{t+1}) \\vert S_t = s] \\end{aligned} $$ Similarly for Q-value,\n $$ \\begin{aligned} Q(s, a) \u0026= \\mathbb{E} [R_{t+1} + \\gamma V(S_{t+1}) \\mid S_t = s, A_t = a] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma \\mathbb{E}_{a\\sim\\pi} Q(S_{t+1}, a) \\mid S_t = s, A_t = a] \\end{aligned} $$ Bellman Expectation Equations The recursive update process can be further decomposed to be equations built on both state-value and action-value functions. As we go further in future action steps, we extend V and Q alternatively by following the policy $\\pi$.\nFig. 5. Illustration of how Bellman expection equations update state-value and action-value functions. $$ \\begin{aligned} V_{\\pi}(s) \u0026= \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s) Q_{\\pi}(s, a) \\\\ Q_{\\pi}(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_{\\pi} (s') \\\\ V_{\\pi}(s) \u0026= \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s) \\big( R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_{\\pi} (s') \\big) \\\\ Q_{\\pi}(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a \\sum_{a' \\in \\mathcal{A}} \\pi(a' \\vert s') Q_{\\pi} (s', a') \\end{aligned} $$ Bellman Optimality Equations If we are only interested in the optimal values, rather than computing the expectation following a policy, we could jump right into the maximum returns during the alternative updates without using a policy. RECAP: the optimal values $V_*$ and $Q_*$ are the best returns we can obtain, defined here.\n $$ \\begin{aligned} V_*(s) \u0026= \\max_{a \\in \\mathcal{A}} Q_*(s,a)\\\\ Q_*(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_*(s') \\\\ V_*(s) \u0026= \\max_{a \\in \\mathcal{A}} \\big( R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_*(s') \\big) \\\\ Q_*(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a \\max_{a' \\in \\mathcal{A}} Q_*(s', a') \\end{aligned} $$ Unsurprisingly they look very similar to Bellman expectation equations.\nIf we have complete information of the environment, this turns into a planning problem, solvable by DP. Unfortunately, in most scenarios, we do not know $P_{ss'}^a$ or $R(s, a)$, so we cannot solve MDPs by directly applying Bellmen equations, but it lays the theoretical foundation for many RL algorithms.\nCommon Approaches Now it is the time to go through the major approaches and classic algorithms for solving RL problems. In future posts, I plan to dive into each approach further.\nDynamic Programming When the model is fully known, following Bellman equations, we can use Dynamic Programming (DP) to iteratively evaluate value functions and improve policy.\nPolicy Evaluation Policy Evaluation is to compute the state-value $V_\\pi$ for a given policy $\\pi$:\n $$ V_{t+1}(s) = \\mathbb{E}_\\pi [r + \\gamma V_t(s') | S_t = s] = \\sum_a \\pi(a \\vert s) \\sum_{s', r} P(s', r \\vert s, a) (r + \\gamma V_t(s')) $$ Policy Improvement Based on the value functions, Policy Improvement generates a better policy $\\pi' \\geq \\pi$ by acting greedily.\n $$ Q_\\pi(s, a) = \\mathbb{E} [R_{t+1} + \\gamma V_\\pi(S_{t+1}) \\vert S_t=s, A_t=a] = \\sum_{s', r} P(s', r \\vert s, a) (r + \\gamma V_\\pi(s')) $$ Policy Iteration The Generalized Policy Iteration (GPI) algorithm refers to an iterative procedure to improve the policy when combining policy evaluation and improvement.\n $$ \\pi_0 \\xrightarrow[]{\\text{evaluation}} V_{\\pi_0} \\xrightarrow[]{\\text{improve}} \\pi_1 \\xrightarrow[]{\\text{evaluation}} V_{\\pi_1} \\xrightarrow[]{\\text{improve}} \\pi_2 \\xrightarrow[]{\\text{evaluation}} \\dots \\xrightarrow[]{\\text{improve}} \\pi_* \\xrightarrow[]{\\text{evaluation}} V_* $$ In GPI, the value function is approximated repeatedly to be closer to the true value of the current policy and in the meantime, the policy is improved repeatedly to approach optimality. This policy iteration process works and always converges to the optimality, but why this is the case?\nSay, we have a policy $\\pi$ and then generate an improved version $\\pi'$ by greedily taking actions, $\\pi'(s) = \\arg\\max_{a \\in \\mathcal{A}} Q_\\pi(s, a)$. The value of this improved $\\pi'$ is guaranteed to be better because:\n $$ \\begin{aligned} Q_\\pi(s, \\pi'(s)) \u0026= Q_\\pi(s, \\arg\\max_{a \\in \\mathcal{A}} Q_\\pi(s, a)) \\\\ \u0026= \\max_{a \\in \\mathcal{A}} Q_\\pi(s, a) \\geq Q_\\pi(s, \\pi(s)) = V_\\pi(s) \\end{aligned} $$ Monte-Carlo Methods First, let\u0026rsquo;s recall that $V(s) = \\mathbb{E}[ G_t \\vert S_t=s]$. Monte-Carlo (MC) methods uses a simple idea: It learns from episodes of raw experience without modeling the environmental dynamics and computes the observed mean return as an approximation of the expected return. To compute the empirical return $G_t$, MC methods need to learn from complete episodes $S_1, A_1, R_2, \\dots, S_T$ to compute $G_t = \\sum_{k=0}^{T-t-1} \\gamma^k R_{t+k+1}$ and all the episodes must eventually terminate.\nThe empirical mean return for state s is:\n $$ V(s) = \\frac{\\sum_{t=1}^T \\mathbb{1}[S_t = s] G_t}{\\sum_{t=1}^T \\mathbb{1}[S_t = s]} $$ where $\\mathbb{1}[S_t = s]$ is a binary indicator function. We may count the visit of state s every time so that there could exist multiple visits of one state in one episode (\u0026ldquo;every-visit\u0026rdquo;), or only count it the first time we encounter a state in one episode (\u0026ldquo;first-visit\u0026rdquo;). This way of approximation can be easily extended to action-value functions by counting (s, a) pair.\n $$ Q(s, a) = \\frac{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a] G_t}{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a]} $$ To learn the optimal policy by MC, we iterate it by following a similar idea to GPI.\n Improve the policy greedily with respect to the current value function: $\\pi(s) = \\arg\\max_{a \\in \\mathcal{A}} Q(s, a)$. Generate a new episode with the new policy $\\pi$ (i.e. using algorithms like ε-greedy helps us balance between exploitation and exploration.) Estimate Q using the new episode: $q_\\pi(s, a) = \\frac{\\sum_{t=1}^T \\big( \\mathbb{1}[S_t = s, A_t = a] \\sum_{k=0}^{T-t-1} \\gamma^k R_{t+k+1} \\big)}{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a]}$ Temporal-Difference Learning Similar to Monte-Carlo methods, Temporal-Difference (TD) Learning is model-free and learns from episodes of experience. However, TD learning can learn from incomplete episodes and hence we don\u0026rsquo;t need to track the episode up to termination. TD learning is so important that Sutton \u0026amp; Barto (2017) in their RL book describes it as \u0026ldquo;one idea … central and novel to reinforcement learning\u0026rdquo;.\nBootstrapping TD learning methods update targets with regard to existing estimates rather than exclusively relying on actual rewards and complete returns as in MC methods. This approach is known as bootstrapping.\nValue Estimation The key idea in TD learning is to update the value function $V(S_t)$ towards an estimated return $R_{t+1} + \\gamma V(S_{t+1})$ (known as \u0026ldquo;TD target\u0026quot;). To what extent we want to update the value function is controlled by the learning rate hyperparameter α:\n $$ \\begin{aligned} V(S_t) \u0026\\leftarrow (1- \\alpha) V(S_t) + \\alpha G_t \\\\ V(S_t) \u0026\\leftarrow V(S_t) + \\alpha (G_t - V(S_t)) \\\\ V(S_t) \u0026\\leftarrow V(S_t) + \\alpha (R_{t+1} + \\gamma V(S_{t+1}) - V(S_t)) \\end{aligned} $$ Similarly, for action-value estimation:\n $$ Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)) $$ Next, let\u0026rsquo;s dig into the fun part on how to learn optimal policy in TD learning (aka \u0026ldquo;TD control\u0026rdquo;). Be prepared, you are gonna see many famous names of classic algorithms in this section.\nSARSA: On-Policy TD control \u0026ldquo;SARSA\u0026rdquo; refers to the procedure of updaing Q-value by following a sequence of $\\dots, S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}, \\dots$. The idea follows the same route of GPI. Within one episode, it works as follows:\n Initialize $t=0$. Start with $S_0$ and choose action $A_0 = \\arg\\max_{a \\in \\mathcal{A}} Q(S_0, a)$, where $\\epsilon$-greedy is commonly applied. At time $t$, after applying action $A_t$, we observe reward $R_{t+1}$ and get into the next state $S_{t+1}$. Then pick the next action in the same way as in step 2: $A_{t+1} = \\arg\\max_{a \\in \\mathcal{A}} Q(S_{t+1}, a)$. Update the Q-value function: $ Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)) $. Set $t = t+1$ and repeat from step 3. In each step of SARSA, we need to choose the next action according to the current policy.\nQ-Learning: Off-policy TD control The development of Q-learning (Watkins \u0026amp; Dayan, 1992) is a big breakout in the early days of Reinforcement Learning. Within one episode, it works as follows:\n Initialize $t=0$. Starts with $S_0$. At time step $t$, we pick the action according to Q values, $A_t = \\arg\\max_{a \\in \\mathcal{A}} Q(S_t, a)$ and $\\epsilon$-greedy is commonly applied. After applying action $A_t$, we observe reward $R_{t+1}$ and get into the next state $S_{t+1}$. Update the Q-value function: $Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma \\max_{a \\in \\mathcal{A}} Q(S_{t+1}, a) - Q(S_t, A_t))$. $t = t+1$ and repeat from step 3. The key difference from SARSA is that Q-learning does not follow the current policy to pick the second action $A_{t+1}$. It estimates $Q^*$ out of the best Q values, but which action (denoted as $a^*$) leads to this maximal Q does not matter and in the next step Q-learning may not follow $a^*$.\nFig. 6. The backup diagrams for Q-learning and SARSA. (Image source: Replotted based on Figure 6.5 in Sutton \u0026 Barto (2017)) Deep Q-Network Theoretically, we can memorize $Q_*(.)$ for all state-action pairs in Q-learning, like in a gigantic table. However, it quickly becomes computationally infeasible when the state and action space are large. Thus people use functions (i.e. a machine learning model) to approximate Q values and this is called function approximation. For example, if we use a function with parameter $\\theta$ to calculate Q values, we can label Q value function as $Q(s, a; \\theta)$.\nUnfortunately Q-learning may suffer from instability and divergence when combined with an nonlinear Q-value function approximation and bootstrapping (See Problems #2).\nDeep Q-Network (\u0026ldquo;DQN\u0026rdquo;; Mnih et al. 2015) aims to greatly improve and stabilize the training procedure of Q-learning by two innovative mechanisms:\n Experience Replay: All the episode steps $e_t = (S_t, A_t, R_t, S_{t+1})$ are stored in one replay memory $D_t = \\{ e_1, \\dots, e_t \\}$. $D_t$ has experience tuples over many episodes. During Q-learning updates, samples are drawn at random from the replay memory and thus one sample could be used multiple times. Experience replay improves data efficiency, removes correlations in the observation sequences, and smooths over changes in the data distribution. Periodically Updated Target: Q is optimized towards target values that are only periodically updated. The Q network is cloned and kept frozen as the optimization target every C steps (C is a hyperparameter). This modification makes the training more stable as it overcomes the short-term oscillations. The loss function looks like this:\n $$ \\mathcal{L}(\\theta) = \\mathbb{E}_{(s, a, r, s') \\sim U(D)} \\Big[ \\big( r + \\gamma \\max_{a'} Q(s', a'; \\theta^{-}) - Q(s, a; \\theta) \\big)^2 \\Big] $$ where $U(D)$ is a uniform distribution over the replay memory D; $\\theta^{-}$ is the parameters of the frozen target Q-network.\nIn addition, it is also found to be helpful to clip the error term to be between [-1, 1]. (I always get mixed feeling with parameter clipping, as many studies have shown that it works empirically but it makes the math much less pretty. :/)\nFig. 7. Algorithm for DQN with experience replay and occasionally frozen optimization target. The prepossessed sequence is the output of some processes running on the input images of Atari games. Don't worry too much about it; just consider them as input feature vectors. (Image source: Mnih et al. 2015) There are many extensions of DQN to improve the original design, such as DQN with dueling architecture (Wang et al. 2016) which estimates state-value function V(s) and advantage function A(s, a) with shared network parameters.\nCombining TD and MC Learning In the previous section on value estimation in TD learning, we only trace one step further down the action chain when calculating the TD target. One can easily extend it to take multiple steps to estimate the return.\nLet\u0026rsquo;s label the estimated return following n steps as $G_t^{(n)}, n=1, \\dots, \\infty$, then:\n $n$ $G_t$ Notes $n=1$ $G_t^{(1)} = R_{t+1} + \\gamma V(S_{t+1})$ TD learning $n=2$ $G_t^{(2)} = R_{t+1} + \\gamma R_{t+2} + \\gamma^2 V(S_{t+2})$ \u0026hellip; $n=n$ $ G_t^{(n)} = R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{n-1} R_{t+n} + \\gamma^n V(S_{t+n}) $ \u0026hellip; $n=\\infty$ $G_t^{(\\infty)} = R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{T-t-1} R_T + \\gamma^{T-t} V(S_T) $ MC estimation The generalized n-step TD learning still has the same form for updating the value function:\n $$ V(S_t) \\leftarrow V(S_t) + \\alpha (G_t^{(n)} - V(S_t)) $$ We are free to pick any $n$ in TD learning as we like. Now the question becomes what is the best $n$? Which $G_t^{(n)}$ gives us the best return approximation? A common yet smart solution is to apply a weighted sum of all possible n-step TD targets rather than to pick a single best n. The weights decay by a factor λ with n, $\\lambda^{n-1}$; the intuition is similar to why we want to discount future rewards when computing the return: the more future we look into the less confident we would be. To make all the weight (n → ∞) sum up to 1, we multiply every weight by (1-λ), because:\n $$ \\begin{aligned} \\text{let } S \u0026= 1 + \\lambda + \\lambda^2 + \\dots \\\\ S \u0026= 1 + \\lambda(1 + \\lambda + \\lambda^2 + \\dots) \\\\ S \u0026= 1 + \\lambda S \\\\ S \u0026= 1 / (1-\\lambda) \\end{aligned} $$ This weighted sum of many n-step returns is called λ-return $G_t^{\\lambda} = (1-\\lambda) \\sum_{n=1}^{\\infty} \\lambda^{n-1} G_t^{(n)}$. TD learning that adopts λ-return for value updating is labeled as TD(λ). The original version we introduced above is equivalent to TD(0).\nFig. 8. Comparison of the backup diagrams of Monte-Carlo, Temporal-Difference learning, and Dynamic Programming for state value functions. (Image source: David Silver's RL course lecture 4: \"Model-Free Prediction\") Policy Gradient All the methods we have introduced above aim to learn the state/action value function and then to select actions accordingly. Policy Gradient methods instead learn the policy directly with a parameterized function respect to $\\theta$, $\\pi(a \\vert s; \\theta)$. Let\u0026rsquo;s define the reward function (opposite of loss function) as the expected return and train the algorithm with the goal to maximize the reward function. My next post described why the policy gradient theorem works (proof) and introduced a number of policy gradient algorithms.\nIn discrete space:\n $$ \\mathcal{J}(\\theta) = V_{\\pi_\\theta}(S_1) = \\mathbb{E}_{\\pi_\\theta}[V_1] $$ where $S_1$ is the initial starting state.\nOr in continuous space:\n $$ \\mathcal{J}(\\theta) = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) V_{\\pi_\\theta}(s) = \\sum_{s \\in \\mathcal{S}} \\Big( d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s, \\theta) Q_\\pi(s, a) \\Big) $$ where $d_{\\pi_\\theta}(s)$ is stationary distribution of Markov chain for $\\pi_\\theta$. If you are unfamiliar with the definition of a \u0026ldquo;stationary distribution,\u0026rdquo; please check this reference.\nUsing gradient ascent we can find the best θ that produces the highest return. It is natural to expect policy-based methods are more useful in continuous space, because there is an infinite number of actions and/or states to estimate the values for in continuous space and hence value-based approaches are computationally much more expensive.\nPolicy Gradient Theorem Computing the gradient numerically can be done by perturbing θ by a small amount ε in the k-th dimension. It works even when $J(\\theta)$ is not differentiable (nice!), but unsurprisingly very slow.\n $$ \\frac{\\partial \\mathcal{J}(\\theta)}{\\partial \\theta_k} \\approx \\frac{\\mathcal{J}(\\theta + \\epsilon u_k) - \\mathcal{J}(\\theta)}{\\epsilon} $$ Or analytically,\n $$ \\mathcal{J}(\\theta) = \\mathbb{E}_{\\pi_\\theta} [r] = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) R(s, a) $$ Actually we have nice theoretical support for (replacing $d(.)$ with $d_\\pi(.)$):\n $$ \\mathcal{J}(\\theta) = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\propto \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) $$ Check Sec 13.1 in Sutton \u0026amp; Barto (2017) for why this is the case.\nThen,\n $$ \\begin{aligned} \\mathcal{J}(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \\nabla \\mathcal{J}(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\nabla \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) \\frac{\\nabla \\pi(a \\vert s; \\theta)}{\\pi(a \\vert s; \\theta)} Q_\\pi(s, a) \\\\ \u0026 = \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) \\nabla \\ln \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \u0026 = \\mathbb{E}_{\\pi_\\theta} [\\nabla \\ln \\pi(a \\vert s; \\theta) Q_\\pi(s, a)] \\end{aligned} $$ This result is named \u0026ldquo;Policy Gradient Theorem\u0026rdquo; which lays the theoretical foundation for various policy gradient algorithms:\n $$ \\nabla \\mathcal{J}(\\theta) = \\mathbb{E}_{\\pi_\\theta} [\\nabla \\ln \\pi(a \\vert s, \\theta) Q_\\pi(s, a)] $$ REINFORCE REINFORCE, also known as Monte-Carlo policy gradient, relies on $Q_\\pi(s, a)$, an estimated return by MC methods using episode samples, to update the policy parameter $\\theta$.\nA commonly used variation of REINFORCE is to subtract a baseline value from the return $G_t$ to reduce the variance of gradient estimation while keeping the bias unchanged. For example, a common baseline is state-value, and if applied, we would use $A(s, a) = Q(s, a) - V(s)$ in the gradient ascent update.\n Initialize θ at random Generate one episode $S_1, A_1, R_2, S_2, A_2, \\dots, S_T$ For t=1, 2, \u0026hellip; , T: Estimate the the return G_t since the time step t. $\\theta \\leftarrow \\theta + \\alpha \\gamma^t G_t \\nabla \\ln \\pi(A_t \\vert S_t, \\theta)$. Actor-Critic If the value function is learned in addition to the policy, we would get Actor-Critic algorithm.\n Critic: updates value function parameters w and depending on the algorithm it could be action-value $Q(a \\vert s; w)$ or state-value $V(s; w)$. Actor: updates policy parameters θ, in the direction suggested by the critic, $\\pi(a \\vert s; \\theta)$. Let\u0026rsquo;s see how it works in an action-value actor-critic algorithm.\n Initialize s, θ, w at random; sample $a \\sim \\pi(a \\vert s; \\theta)$. For t = 1… T: Sample reward $r_t \\sim R(s, a)$ and next state $s' \\sim P(s' \\vert s, a)$. Then sample the next action $a' \\sim \\pi(s', a'; \\theta)$. Update policy parameters: $\\theta \\leftarrow \\theta + \\alpha_\\theta Q(s, a; w) \\nabla_\\theta \\ln \\pi(a \\vert s; \\theta)$. Compute the correction for action-value at time t: $G_{t:t+1} = r_t + \\gamma Q(s', a'; w) - Q(s, a; w)$ and use it to update value function parameters: $w \\leftarrow w + \\alpha_w G_{t:t+1} \\nabla_w Q(s, a; w) $. Update $a \\leftarrow a'$ and $s \\leftarrow s'$. $\\alpha_\\theta$ and $\\alpha_w$ are two learning rates for policy and value function parameter updates, respectively.\nA3C Asynchronous Advantage Actor-Critic (Mnih et al., 2016), short for A3C, is a classic policy gradient method with the special focus on parallel training.\nIn A3C, the critics learn the state-value function, $V(s; w)$, while multiple actors are trained in parallel and get synced with global parameters from time to time. Hence, A3C is good for parallel training by default, i.e. on one machine with multi-core CPU.\nThe loss function for state-value is to minimize the mean squared error, $\\mathcal{J}_v (w) = (G_t - V(s; w))^2$ and we use gradient descent to find the optimal w. This state-value function is used as the baseline in the policy gradient update.\nHere is the algorithm outline:\n We have global parameters, θ and w; similar thread-specific parameters, θ' and w'. Initialize the time step t = 1 While T \u0026lt;= T_MAX: Reset gradient: dθ = 0 and dw = 0. Synchronize thread-specific parameters with global ones: θ' = θ and w' = w. $t_\\text{start}$ = t and get $s_t$. While ($s_t \\neq \\text{TERMINAL}$) and ($t - t_\\text{start} \u0026lt;= t_\\text{max}$): Pick the action $a_t \\sim \\pi(a_t \\vert s_t; \\theta')$ and receive a new reward $r_t$ and a new state $s_{t+1}$. Update t = t + 1 and T = T + 1. Initialize the variable that holds the return estimation $$R = \\begin{cases} 0 \u0026amp; \\text{if } s_t \\text{ is TERMINAL} \\ V(s_t; w') \u0026amp; \\text{otherwise} \\end{cases}$$. For $i = t-1, \\dots, t_\\text{start}$: $R \\leftarrow r_i + \\gamma R$; here R is a MC measure of $G_i$. Accumulate gradients w.r.t. θ': $d\\theta \\leftarrow d\\theta + \\nabla_{\\theta'} \\log \\pi(a_i \\vert s_i; \\theta')(R - V(s_i; w'))$; Accumulate gradients w.r.t. w': $dw \\leftarrow dw + \\nabla_{w'} (R - V(s_i; w'))^2$. Update synchronously θ using dθ, and w using dw. A3C enables the parallelism in multiple agent training. The gradient accumulation step (6.2) can be considered as a reformation of minibatch-based stochastic gradient update: the values of w or θ get corrected by a little bit in the direction of each training thread independently.\nEvolution Strategies Evolution Strategies (ES) is a type of model-agnostic optimization approach. It learns the optimal solution by imitating Darwin\u0026rsquo;s theory of the evolution of species by natural selection. Two prerequisites for applying ES: (1) our solutions can freely interact with the environment and see whether they can solve the problem; (2) we are able to compute a fitness score of how good each solution is. We don\u0026rsquo;t have to know the environment configuration to solve the problem.\nSay, we start with a population of random solutions. All of them are capable of interacting with the environment and only candidates with high fitness scores can survive (only the fittest can survive in a competition for limited resources). A new generation is then created by recombining the settings (gene mutation) of high-fitness survivors. This process is repeated until the new solutions are good enough.\nVery different from the popular MDP-based approaches as what we have introduced above, ES aims to learn the policy parameter $\\theta$ without value approximation. Let\u0026rsquo;s assume the distribution over the parameter $\\theta$ is an isotropic multivariate Gaussian with mean $\\mu$ and fixed covariance $\\sigma^2I$. The gradient of $F(\\theta)$ is calculated:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} F(\\theta) \\\\ =\u0026 \\nabla_\\theta \\int_\\theta F(\\theta) \\Pr(\\theta) \u0026\u0026 \\text{Pr(.) is the Gaussian density function.} \\\\ =\u0026 \\int_\\theta F(\\theta) \\Pr(\\theta) \\frac{\\nabla_\\theta \\Pr(\\theta)}{\\Pr(\\theta)} \\\\ =\u0026 \\int_\\theta F(\\theta) \\Pr(\\theta) \\nabla_\\theta \\log \\Pr(\\theta) \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} [F(\\theta) \\nabla_\\theta \\log \\Pr(\\theta)] \u0026\u0026 \\text{Similar to how we do policy gradient update.} \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\nabla_\\theta \\log \\Big( \\frac{1}{\\sqrt{2\\pi\\sigma^2}} e^{-\\frac{(\\theta - \\mu)^2}{2 \\sigma^2 }} \\Big) \\Big] \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\nabla_\\theta \\Big( -\\log \\sqrt{2\\pi\\sigma^2} - \\frac{(\\theta - \\mu)^2}{2 \\sigma^2} \\Big) \\Big] \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\frac{\\theta - \\mu}{\\sigma^2} \\Big] \\end{aligned} $$ We can rewrite this formula in terms of a \u0026ldquo;mean\u0026rdquo; parameter $\\theta$ (different from the $\\theta$ above; this $\\theta$ is the base gene for further mutation), $\\epsilon \\sim N(0, I)$ and therefore $\\theta + \\epsilon \\sigma \\sim N(\\theta, \\sigma^2)$. $\\epsilon$ controls how much Gaussian noises should be added to create mutation:\n $$ \\nabla_\\theta \\mathbb{E}_{\\epsilon \\sim N(0, I)} F(\\theta + \\sigma \\epsilon) = \\frac{1}{\\sigma} \\mathbb{E}_{\\epsilon \\sim N(0, I)} [F(\\theta + \\sigma \\epsilon) \\epsilon] $$ Fig. 9. A simple parallel evolution-strategies-based RL algorithm. Parallel workers share the random seeds so that they can reconstruct the Gaussian noises with tiny communication bandwidth. (Image source: Salimans et al. 2017.) ES, as a black-box optimization algorithm, is another approach to RL problems (In my original writing, I used the phrase \u0026ldquo;a nice alternative\u0026rdquo;; Seita pointed me to this discussion and thus I updated my wording.). It has a couple of good characteristics (Salimans et al., 2017) keeping it fast and easy to train:\n ES does not need value function approximation; ES does not perform gradient back-propagation; ES is invariant to delayed or long-term rewards; ES is highly parallelizable with very little data communication. Known Problems Exploration-Exploitation Dilemma The problem of exploration vs exploitation dilemma has been discussed in my previous post. When the RL problem faces an unknown environment, this issue is especially a key to finding a good solution: without enough exploration, we cannot learn the environment well enough; without enough exploitation, we cannot complete our reward optimization task.\nDifferent RL algorithms balance between exploration and exploitation in different ways. In MC methods, Q-learning or many on-policy algorithms, the exploration is commonly implemented by ε-greedy; In ES, the exploration is captured by the policy parameter perturbation. Please keep this into consideration when develop a new RL algorithm.\nDeadly Triad Issue We do seek the efficiency and flexibility of TD methods that involve bootstrapping. However, when off-policy, nonlinear function approximation, and bootstrapping are combined in one RL algorithm, the training could be unstable and hard to converge. This issue is known as the deadly triad (Sutton \u0026amp; Barto, 2017). Many architectures using deep learning models were proposed to resolve the problem, including DQN to stabilize the training with experience replay and occasionally frozen target network.\nCase Study: AlphaGo Zero The game of Go has been an extremely hard problem in the field of Artificial Intelligence for decades until recent years. AlphaGo and AlphaGo Zero are two programs developed by a team at DeepMind. Both involve deep Convolutional Neural Networks (CNN) and Monte Carlo Tree Search (MCTS) and both have been approved to achieve the level of professional human Go players. Different from AlphaGo that relied on supervised learning from expert human moves, AlphaGo Zero used only reinforcement learning and self-play without human knowledge beyond the basic rules.\nFig. 10. The board of Go. Two players play black and white stones alternatively on the vacant intersections of a board with 19 x 19 lines. A group of stones must have at least one open point (an intersection, called a \"liberty\") to remain on the board and must have at least two or more enclosed liberties (called \"eyes\") to stay \"alive\". No stone shall repeat a previous position. With all the knowledge of RL above, let\u0026rsquo;s take a look at how AlphaGo Zero works. The main component is a deep CNN over the game board configuration (precisely, a ResNet with batch normalization and ReLU). This network outputs two values:\n $$ (p, v) = f_\\theta(s) $$ $s$: the game board configuration, 19 x 19 x 17 stacked feature planes; 17 features for each position, 8 past configurations (including current) for the current player + 8 past configurations for the opponent + 1 feature indicating the color (1=black, 0=white). We need to code the color specifically because the network is playing with itself and the colors of current player and opponents are switching between steps. $p$: the probability of selecting a move over 19^2 + 1 candidates (19^2 positions on the board, in addition to passing). $v$: the winning probability given the current setting. During self-play, MCTS further improves the action probability distribution $\\pi \\sim p(.)$ and then the action $a_t$ is sampled from this improved policy. The reward $z_t$ is a binary value indicating whether the current player eventually wins the game. Each move generates an episode tuple $(s_t, \\pi_t, z_t)$ and it is saved into the replay memory. The details on MCTS are skipped for the sake of space in this post; please read the original paper if you are interested.\nFig. 11. AlphaGo Zero is trained by self-play while MCTS improves the output policy further in every step. (Image source: Figure 1a in Silver et al., 2017). The network is trained with the samples in the replay memory to minimize the loss:\n $$ \\mathcal{L} = (z - v)^2 - \\pi^\\top \\log p + c \\| \\theta \\|^2 $$ where $c$ is a hyperparameter controlling the intensity of L2 penalty to avoid overfitting.\nAlphaGo Zero simplified AlphaGo by removing supervised learning and merging separated policy and value networks into one. It turns out that AlphaGo Zero achieved largely improved performance with a much shorter training time! I strongly recommend reading these two papers side by side and compare the difference, super fun.\nI know this is a long read, but hopefully worth it. If you notice mistakes and errors in this post, don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com]. See you in the next post! :)\n Cited as:\n@article{weng2018bandit, title = \u0026quot;A (Long) Peek into Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-02-19-rl-overview/\u0026quot; } References [1] Yuxi Li. Deep reinforcement learning: An overview. arXiv preprint arXiv:1701.07274. 2017.\n[2] Richard S. Sutton and Andrew G. Barto. Reinforcement Learning: An Introduction; 2nd Edition. 2017.\n[3] Volodymyr Mnih, et al. Asynchronous methods for deep reinforcement learning. ICML. 2016.\n[4] Tim Salimans, et al. Evolution strategies as a scalable alternative to reinforcement learning. arXiv preprint arXiv:1703.03864 (2017).\n[5] David Silver, et al. Mastering the game of go without human knowledge. Nature 550.7676 (2017): 354.\n[6] David Silver, et al. Mastering the game of Go with deep neural networks and tree search. Nature 529.7587 (2016): 484-489.\n[7] Volodymyr Mnih, et al. Human-level control through deep reinforcement learning. Nature 518.7540 (2015): 529.\n[8] Ziyu Wang, et al. Dueling network architectures for deep reinforcement learning. ICML. 2016.\n[9] Reinforcement Learning lectures by David Silver on YouTube.\n[10] OpenAI Blog: Evolution Strategies as a Scalable Alternative to Reinforcement Learning\n[11] Frank Sehnke, et al. Parameter-exploring policy gradients. Neural Networks 23.4 (2010): 551-559.\n[12] Csaba Szepesvári. Algorithms for reinforcement learning. 1st Edition. Synthesis lectures on artificial intelligence and machine learning 4.1 (2010): 1-103.\n If you notice mistakes and errors in this post, please don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be super happy to correct them right away!\n","permalink":"https://lilianweng.github.io/posts/2018-02-19-rl-overview/","summary":"[Updated on 2020-09-03: Updated the algorithm of SARSA and Q-learning so that the difference is more pronounced. [Updated on 2021-09-19: Thanks to 爱吃猫的鱼, we have this post in Chinese].\nA couple of exciting news in Artificial Intelligence (AI) has just happened in recent years. AlphaGo defeated the best professional human player in the game of Go. Very soon the extended algorithm AlphaGo Zero beat AlphaGo by 100-0 without supervised learning on human knowledge.","title":"A (Long) Peek into Reinforcement Learning"},{"content":"The algorithms are implemented for Bernoulli bandit in lilianweng/multi-armed-bandit.\nExploitation vs Exploration The exploration vs exploitation dilemma exists in many aspects of our life. Say, your favorite restaurant is right around the corner. If you go there every day, you would be confident of what you will get, but miss the chances of discovering an even better option. If you try new places all the time, very likely you are gonna have to eat unpleasant food from time to time. Similarly, online advisors try to balance between the known most attractive ads and the new ads that might be even more successful.\nFig. 1. A real-life example of the exploration vs exploitation dilemma: where to eat? (Image source: UC Berkeley AI course slide, lecture 11.) If we have learned all the information about the environment, we are able to find the best strategy by even just simulating brute-force, let alone many other smart approaches. The dilemma comes from the incomplete information: we need to gather enough information to make best overall decisions while keeping the risk under control. With exploitation, we take advantage of the best option we know. With exploration, we take some risk to collect information about unknown options. The best long-term strategy may involve short-term sacrifices. For example, one exploration trial could be a total failure, but it warns us of not taking that action too often in the future.\nWhat is Multi-Armed Bandit? The multi-armed bandit problem is a classic problem that well demonstrates the exploration vs exploitation dilemma. Imagine you are in a casino facing multiple slot machines and each is configured with an unknown probability of how likely you can get a reward at one play. The question is: What is the best strategy to achieve highest long-term rewards?\nIn this post, we will only discuss the setting of having an infinite number of trials. The restriction on a finite number of trials introduces a new type of exploration problem. For instance, if the number of trials is smaller than the number of slot machines, we cannot even try every machine to estimate the reward probability (!) and hence we have to behave smartly w.r.t. a limited set of knowledge and resources (i.e. time).\nFig. 2. An illustration of how a Bernoulli multi-armed bandit works. The reward probabilities are **unknown** to the player. A naive approach can be that you continue to playing with one machine for many many rounds so as to eventually estimate the \u0026ldquo;true\u0026rdquo; reward probability according to the law of large numbers. However, this is quite wasteful and surely does not guarantee the best long-term reward.\nDefinition Now let\u0026rsquo;s give it a scientific definition.\nA Bernoulli multi-armed bandit can be described as a tuple of $\\langle \\mathcal{A}, \\mathcal{R} \\rangle$, where:\n We have $K$ machines with reward probabilities, $\\{ \\theta_1, \\dots, \\theta_K \\}$. At each time step t, we take an action a on one slot machine and receive a reward r. $\\mathcal{A}$ is a set of actions, each referring to the interaction with one slot machine. The value of action a is the expected reward, $Q(a) = \\mathbb{E} [r \\vert a] = \\theta$. If action $a_t$ at the time step t is on the i-th machine, then $Q(a_t) = \\theta_i$. $\\mathcal{R}$ is a reward function. In the case of Bernoulli bandit, we observe a reward r in a stochastic fashion. At the time step t, $r_t = \\mathcal{R}(a_t)$ may return reward 1 with a probability $Q(a_t)$ or 0 otherwise. It is a simplified version of Markov decision process, as there is no state $\\mathcal{S}$.\nThe goal is to maximize the cumulative reward $\\sum_{t=1}^T r_t$. If we know the optimal action with the best reward, then the goal is same as to minimize the potential regret or loss by not picking the optimal action.\nThe optimal reward probability $\\theta^{*}$ of the optimal action $a^{*}$ is:\n $$ \\theta^{*}=Q(a^{*})=\\max_{a \\in \\mathcal{A}} Q(a) = \\max_{1 \\leq i \\leq K} \\theta_i $$ Our loss function is the total regret we might have by not selecting the optimal action up to the time step T:\n $$ \\mathcal{L}_T = \\mathbb{E} \\Big[ \\sum_{t=1}^T \\big( \\theta^{*} - Q(a_t) \\big) \\Big] $$ Bandit Strategies Based on how we do exploration, there several ways to solve the multi-armed bandit.\n No exploration: the most naive approach and a bad one. Exploration at random Exploration smartly with preference to uncertainty ε-Greedy Algorithm The ε-greedy algorithm takes the best action most of the time, but does random exploration occasionally. The action value is estimated according to the past experience by averaging the rewards associated with the target action a that we have observed so far (up to the current time step t):\n $$ \\hat{Q}_t(a) = \\frac{1}{N_t(a)} \\sum_{\\tau=1}^t r_\\tau \\mathbb{1}[a_\\tau = a] $$ where $\\mathbb{1}$ is a binary indicator function and $N_t(a)$ is how many times the action a has been selected so far, $N_t(a) = \\sum_{\\tau=1}^t \\mathbb{1}[a_\\tau = a]$.\nAccording to the ε-greedy algorithm, with a small probability $\\epsilon$ we take a random action, but otherwise (which should be the most of the time, probability 1-$\\epsilon$) we pick the best action that we have learnt so far: $\\hat{a}^{*}_t = \\arg\\max_{a \\in \\mathcal{A}} \\hat{Q}_t(a)$.\nCheck my toy implementation here.\nUpper Confidence Bounds Random exploration gives us an opportunity to try out options that we have not known much about. However, due to the randomness, it is possible we end up exploring a bad action which we have confirmed in the past (bad luck!). To avoid such inefficient exploration, one approach is to decrease the parameter ε in time and the other is to be optimistic about options with high uncertainty and thus to prefer actions for which we haven\u0026rsquo;t had a confident value estimation yet. Or in other words, we favor exploration of actions with a strong potential to have a optimal value.\nThe Upper Confidence Bounds (UCB) algorithm measures this potential by an upper confidence bound of the reward value, $\\hat{U}_t(a)$, so that the true value is below with bound $Q(a) \\leq \\hat{Q}_t(a) + \\hat{U}_t(a)$ with high probability. The upper bound $\\hat{U}_t(a)$ is a function of $N_t(a)$; a larger number of trials $N_t(a)$ should give us a smaller bound $\\hat{U}_t(a)$.\nIn UCB algorithm, we always select the greediest action to maximize the upper confidence bound:\n $$ a^{UCB}_t = argmax_{a \\in \\mathcal{A}} \\hat{Q}_t(a) + \\hat{U}_t(a) $$ Now, the question is how to estimate the upper confidence bound.\nHoeffding\u0026rsquo;s Inequality If we do not want to assign any prior knowledge on how the distribution looks like, we can get help from \u0026ldquo;Hoeffding\u0026rsquo;s Inequality\u0026rdquo; \u0026mdash; a theorem applicable to any bounded distribution.\nLet $X_1, \\dots, X_t$ be i.i.d. (independent and identically distributed) random variables and they are all bounded by the interval [0, 1]. The sample mean is $\\overline{X}_t = \\frac{1}{t}\\sum_{\\tau=1}^t X_\\tau$. Then for u \u0026gt; 0, we have:\n $$ \\mathbb{P} [ \\mathbb{E}[X] \\overline{X}_t + u] \\leq e^{-2tu^2} $$ Given one target action a, let us consider:\n $r_t(a)$ as the random variables, $Q(a)$ as the true mean, $\\hat{Q}_t(a)$ as the sample mean, And $u$ as the upper confidence bound, $u = U_t(a)$ Then we have,\n $$ \\mathbb{P} [ Q(a) \\hat{Q}_t(a) + U_t(a)] \\leq e^{-2t{U_t(a)}^2} $$ We want to pick a bound so that with high chances the true mean is blow the sample mean + the upper confidence bound. Thus $e^{-2t U_t(a)^2}$ should be a small probability. Let\u0026rsquo;s say we are ok with a tiny threshold p:\n $$ e^{-2t U_t(a)^2} = p \\text{ Thus, } U_t(a) = \\sqrt{\\frac{-\\log p}{2 N_t(a)}} $$ UCB1 One heuristic is to reduce the threshold p in time, as we want to make more confident bound estimation with more rewards observed. Set $p=t^{-4}$ we get UCB1 algorithm:\n $$ U_t(a) = \\sqrt{\\frac{2 \\log t}{N_t(a)}} \\text{ and } a^{UCB1}_t = \\arg\\max_{a \\in \\mathcal{A}} Q(a) + \\sqrt{\\frac{2 \\log t}{N_t(a)}} $$ Bayesian UCB In UCB or UCB1 algorithm, we do not assume any prior on the reward distribution and therefore we have to rely on the Hoeffding\u0026rsquo;s Inequality for a very generalize estimation. If we are able to know the distribution upfront, we would be able to make better bound estimation.\nFor example, if we expect the mean reward of every slot machine to be Gaussian as in Fig 2, we can set the upper bound as 95% confidence interval by setting $\\hat{U}_t(a)$ to be twice the standard deviation.\nFig. 3. When the expected reward has a Gaussian distribution. $\\sigma(a\\_i)$ is the standard deviation and $c\\sigma(a\\_i)$ is the upper confidence bound. The constant $c$ is a adjustable hyperparameter. (Image source: UCL RL course lecture 9's slides) Check my toy implementation of UCB1 and Bayesian UCB with Beta prior on θ.\nThompson Sampling Thompson sampling has a simple idea but it works great for solving the multi-armed bandit problem.\nFig. 4. Oops, I guess not this Thompson? (Credit goes to Ben Taborsky; he has a full theorem of how Thompson invented while pondering over who to pass the ball. Yes I stole his joke.) At each time step, we want to select action a according to the probability that a is optimal:\n $$ \\begin{aligned} \\pi(a \\; \\vert \\; h_t) \u0026= \\mathbb{P} [ Q(a) Q(a'), \\forall a' \\neq a \\; \\vert \\; h_t] \\\\ \u0026= \\mathbb{E}_{\\mathcal{R} \\vert h_t} [ \\mathbb{1}(a = \\arg\\max_{a \\in \\mathcal{A}} Q(a)) ] \\end{aligned} $$ where $\\pi(a ; \\vert ; h_t)$ is the probability of taking action a given the history $h_t$.\nFor the Bernoulli bandit, it is natural to assume that $Q(a)$ follows a Beta distribution, as $Q(a)$ is essentially the success probability θ in Bernoulli distribution. The value of $\\text{Beta}(\\alpha, \\beta)$ is within the interval [0, 1]; α and β correspond to the counts when we succeeded or failed to get a reward respectively.\nFirst, let us initialize the Beta parameters α and β based on some prior knowledge or belief for every action. For example,\n α = 1 and β = 1; we expect the reward probability to be 50% but we are not very confident. α = 1000 and β = 9000; we strongly believe that the reward probability is 10%. At each time t, we sample an expected reward, $\\tilde{Q}(a)$, from the prior distribution $\\text{Beta}(\\alpha_i, \\beta_i)$ for every action. The best action is selected among samples: $a^{TS}_t = \\arg\\max_{a \\in \\mathcal{A}} \\tilde{Q}(a)$. After the true reward is observed, we can update the Beta distribution accordingly, which is essentially doing Bayesian inference to compute the posterior with the known prior and the likelihood of getting the sampled data.\n $$ \\begin{aligned} \\alpha_i \u0026 \\leftarrow \\alpha_i + r_t \\mathbb{1}[a^{TS}_t = a_i] \\\\ \\beta_i \u0026 \\leftarrow \\beta_i + (1-r_t) \\mathbb{1}[a^{TS}_t = a_i] \\end{aligned} $$ Thompson sampling implements the idea of probability matching. Because its reward estimations $\\tilde{Q}$ are sampled from posterior distributions, each of these probabilities is equivalent to the probability that the corresponding action is optimal, conditioned on observed history.\nHowever, for many practical and complex problems, it can be computationally intractable to estimate the posterior distributions with observed true rewards using Bayesian inference. Thompson sampling still can work out if we are able to approximate the posterior distributions using methods like Gibbs sampling, Laplace approximate, and the bootstraps. This tutorial presents a comprehensive review; strongly recommend it if you want to learn more about Thompson sampling.\nCase Study I implemented the above algorithms in lilianweng/multi-armed-bandit. A BernoulliBandit object can be constructed with a list of random or predefined reward probabilities. The bandit algorithms are implemented as subclasses of Solver, taking a Bandit object as the target problem. The cumulative regrets are tracked in time.\nFig. 4. The result of a small experiment on solving a Bernoulli bandit with K = 10 slot machines with reward probabilities, {0.0, 0.1, 0.2, ..., 0.9}. Each solver runs 10000 steps. (Left) The plot of time step vs the cumulative regrets. (Middle) The plot of true reward probability vs estimated probability. (Right) The fraction of each action is picked during the 10000-step run.* Summary We need exploration because information is valuable. In terms of the exploration strategies, we can do no exploration at all, focusing on the short-term returns. Or we occasionally explore at random. Or even further, we explore and we are picky about which options to explore \u0026mdash; actions with higher uncertainty are favored because they can provide higher information gain.\n Cited as:\n@article{weng2018bandit, title = \u0026quot;The Multi-Armed Bandit Problem and Its Solutions\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/\u0026quot; } References [1] CS229 Supplemental Lecture notes: Hoeffding\u0026rsquo;s inequality.\n[2] RL Course by David Silver - Lecture 9: Exploration and Exploitation\n[3] Olivier Chapelle and Lihong Li. \u0026ldquo;An empirical evaluation of thompson sampling.\u0026quot; NIPS. 2011.\n[4] Russo, Daniel, et al. \u0026ldquo;A Tutorial on Thompson Sampling.\u0026quot; arXiv:1707.02038 (2017).\n","permalink":"https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/","summary":"The algorithms are implemented for Bernoulli bandit in lilianweng/multi-armed-bandit.\nExploitation vs Exploration The exploration vs exploitation dilemma exists in many aspects of our life. Say, your favorite restaurant is right around the corner. If you go there every day, you would be confident of what you will get, but miss the chances of discovering an even better option. If you try new places all the time, very likely you are gonna have to eat unpleasant food from time to time.","title":"The Multi-Armed Bandit Problem and Its Solutions"},{"content":"[Updated on 2018-12-20: Remove YOLO here. Part 4 will cover multiple fast object detection algorithms, including YOLO.] [Updated on 2018-12-27: Add bbox regression and tricks sections for R-CNN.]\nIn the series of \u0026ldquo;Object Detection for Dummies\u0026rdquo;, we started with basic concepts in image processing, such as gradient vectors and HOG, in Part 1. Then we introduced classic convolutional neural network architecture designs for classification and pioneer models for object recognition, Overfeat and DPM, in Part 2. In the third post of this series, we are about to review a set of models in the R-CNN (\u0026ldquo;Region-based CNN\u0026rdquo;) family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nHere is a list of papers covered in this post ;)\n Model Goal Resources R-CNN Object recognition [paper][code] Fast R-CNN Object recognition [paper][code] Faster R-CNN Object recognition [paper][code] Mask R-CNN Image segmentation [paper][code] R-CNN R-CNN (Girshick et al., 2014) is short for \u0026ldquo;Region-based Convolutional Neural Networks\u0026rdquo;. The main idea is composed of two steps. First, using selective search, it identifies a manageable number of bounding-box object region candidates (\u0026ldquo;region of interest\u0026rdquo; or \u0026ldquo;RoI\u0026rdquo;). And then it extracts CNN features from each region independently for classification.\nFig. 1. The architecture of R-CNN. (Image source: Girshick et al., 2014) Model Workflow How R-CNN works can be summarized as follows:\n Pre-train a CNN network on image classification tasks; for example, VGG or ResNet trained on ImageNet dataset. The classification task involves N classes. NOTE: You can find a pre-trained AlexNet in Caffe Model Zoo. I don’t think you can find it in Tensorflow, but Tensorflow-slim model library provides pre-trained ResNet, VGG, and others.\n Propose category-independent regions of interest by selective search (~2k candidates per image). Those regions may contain target objects and they are of different sizes. Region candidates are warped to have a fixed size as required by CNN. Continue fine-tuning the CNN on warped proposal regions for K + 1 classes; The additional one class refers to the background (no object of interest). In the fine-tuning stage, we should use a much smaller learning rate and the mini-batch oversamples the positive cases because most proposed regions are just background. Given every image region, one forward propagation through the CNN generates a feature vector. This feature vector is then consumed by a binary SVM trained for each class independently. The positive samples are proposed regions with IoU (intersection over union) overlap threshold \u0026gt;= 0.3, and negative samples are irrelevant others. To reduce the localization errors, a regression model is trained to correct the predicted detection window on bounding box correction offset using CNN features. Bounding Box Regression Given a predicted bounding box coordinate $\\mathbf{p} = (p_x, p_y, p_w, p_h)$ (center coordinate, width, height) and its corresponding ground truth box coordinates $\\mathbf{g} = (g_x, g_y, g_w, g_h)$ , the regressor is configured to learn scale-invariant transformation between two centers and log-scale transformation between widths and heights. All the transformation functions take $\\mathbf{p}$ as input.\n $$ \\begin{aligned} \\hat{g}_x \u0026= p_w d_x(\\mathbf{p}) + p_x \\\\ \\hat{g}_y \u0026= p_h d_y(\\mathbf{p}) + p_y \\\\ \\hat{g}_w \u0026= p_w \\exp({d_w(\\mathbf{p})}) \\\\ \\hat{g}_h \u0026= p_h \\exp({d_h(\\mathbf{p})}) \\end{aligned} $$ Fig. 2. Illustration of transformation between predicted and ground truth bounding boxes. An obvious benefit of applying such transformation is that all the bounding box correction functions, $d_i(\\mathbf{p})$ where $i \\in \\{ x, y, w, h \\}$, can take any value between [-∞, +∞]. The targets for them to learn are:\n $$ \\begin{aligned} t_x \u0026= (g_x - p_x) / p_w \\\\ t_y \u0026= (g_y - p_y) / p_h \\\\ t_w \u0026= \\log(g_w/p_w) \\\\ t_h \u0026= \\log(g_h/p_h) \\end{aligned} $$ A standard regression model can solve the problem by minimizing the SSE loss with regularization:\n $$ \\mathcal{L}_\\text{reg} = \\sum_{i \\in \\{x, y, w, h\\}} (t_i - d_i(\\mathbf{p}))^2 + \\lambda \\|\\mathbf{w}\\|^2 $$ The regularization term is critical here and RCNN paper picked the best λ by cross validation. It is also noteworthy that not all the predicted bounding boxes have corresponding ground truth boxes. For example, if there is no overlap, it does not make sense to run bbox regression. Here, only a predicted box with a nearby ground truth box with at least 0.6 IoU is kept for training the bbox regression model.\nCommon Tricks Several tricks are commonly used in RCNN and other detection models.\nNon-Maximum Suppression\nLikely the model is able to find multiple bounding boxes for the same object. Non-max suppression helps avoid repeated detection of the same instance. After we get a set of matched bounding boxes for the same object category: Sort all the bounding boxes by confidence score. Discard boxes with low confidence scores. While there is any remaining bounding box, repeat the following: Greedily select the one with the highest score. Skip the remaining boxes with high IoU (i.e. \u0026gt; 0.5) with previously selected one.\nFig. 3. Multiple bounding boxes detect the car in the image. After non-maximum suppression, only the best remains and the rest are ignored as they have large overlaps with the selected one. (Image source: DPM paper) Hard Negative Mining\nWe consider bounding boxes without objects as negative examples. Not all the negative examples are equally hard to be identified. For example, if it holds pure empty background, it is likely an “easy negative”; but if the box contains weird noisy texture or partial object, it could be hard to be recognized and these are “hard negative”.\nThe hard negative examples are easily misclassified. We can explicitly find those false positive samples during the training loops and include them in the training data so as to improve the classifier.\nSpeed Bottleneck Looking through the R-CNN learning steps, you could easily find out that training an R-CNN model is expensive and slow, as the following steps involve a lot of work:\n Running selective search to propose 2000 region candidates for every image; Generating the CNN feature vector for every image region (N images * 2000). The whole process involves three models separately without much shared computation: the convolutional neural network for image classification and feature extraction; the top SVM classifier for identifying target objects; and the regression model for tightening region bounding boxes. Fast R-CNN To make R-CNN faster, Girshick (2015) improved the training procedure by unifying three independent models into one jointly trained framework and increasing shared computation results, named Fast R-CNN. Instead of extracting CNN feature vectors independently for each region proposal, this model aggregates them into one CNN forward pass over the entire image and the region proposals share this feature matrix. Then the same feature matrix is branched out to be used for learning the object classifier and the bounding-box regressor. In conclusion, computation sharing speeds up R-CNN.\nFig. 4. The architecture of Fast R-CNN. (Image source: Girshick, 2015) RoI Pooling It is a type of max pooling to convert features in the projected region of the image of any size, h x w, into a small fixed window, H x W. The input region is divided into H x W grids, approximately every subwindow of size h/H x w/W. Then apply max-pooling in each grid.\nFig. 5. RoI pooling (Image source: Stanford CS231n slides.) Model Workflow How Fast R-CNN works is summarized as follows; many steps are same as in R-CNN:\n First, pre-train a convolutional neural network on image classification tasks. Propose regions by selective search (~2k candidates per image). Alter the pre-trained CNN: Replace the last max pooling layer of the pre-trained CNN with a RoI pooling layer. The RoI pooling layer outputs fixed-length feature vectors of region proposals. Sharing the CNN computation makes a lot of sense, as many region proposals of the same images are highly overlapped. Replace the last fully connected layer and the last softmax layer (K classes) with a fully connected layer and softmax over K + 1 classes. Finally the model branches into two output layers: A softmax estimator of K + 1 classes (same as in R-CNN, +1 is the \u0026ldquo;background\u0026rdquo; class), outputting a discrete probability distribution per RoI. A bounding-box regression model which predicts offsets relative to the original RoI for each of K classes. Loss Function The model is optimized for a loss combining two tasks (classification + localization):\n| Symbol | Explanation | | $u$ | True class label, $ u \\in 0, 1, \\dots, K$; by convention, the catch-all background class has $u = 0$. | | $p$ | Discrete probability distribution (per RoI) over K + 1 classes: $p = (p_0, \\dots, p_K)$, computed by a softmax over the K + 1 outputs of a fully connected layer. | | $v$ | True bounding box $ v = (v_x, v_y, v_w, v_h) $. | | $t^u$ | Predicted bounding box correction, $t^u = (t^u_x, t^u_y, t^u_w, t^u_h)$. See above. | {:.info}\nThe loss function sums up the cost of classification and bounding box prediction: $\\mathcal{L} = \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box}$. For \u0026ldquo;background\u0026rdquo; RoI, $\\mathcal{L}_\\text{box}$ is ignored by the indicator function $\\mathbb{1} [u \\geq 1]$, defined as:\n $$ \\mathbb{1} [u = 1] = \\begin{cases} 1 \u0026 \\text{if } u \\geq 1\\\\ 0 \u0026 \\text{otherwise} \\end{cases} $$ The overall loss function is:\n $$ \\begin{align*} \\mathcal{L}(p, u, t^u, v) \u0026= \\mathcal{L}_\\text{cls} (p, u) + \\mathbb{1} [u \\geq 1] \\mathcal{L}_\\text{box}(t^u, v) \\\\ \\mathcal{L}_\\text{cls}(p, u) \u0026= -\\log p_u \\\\ \\mathcal{L}_\\text{box}(t^u, v) \u0026= \\sum_{i \\in \\{x, y, w, h\\}} L_1^\\text{smooth} (t^u_i - v_i) \\end{align*} $$ The bounding box loss $\\mathcal{L}_{box}$ should measure the difference between $t^u_i$ and $v_i$ using a robust loss function. The smooth L1 loss is adopted here and it is claimed to be less sensitive to outliers.\n $$ L_1^\\text{smooth}(x) = \\begin{cases} 0.5 x^2 \u0026 \\text{if } \\vert x \\vert Fig. 6. The plot of smooth L1 loss, $y = L\\_1^\\text{smooth}(x)$. (Image source: link) Speed Bottleneck Fast R-CNN is much faster in both training and testing time. However, the improvement is not dramatic because the region proposals are generated separately by another model and that is very expensive.\nFaster R-CNN An intuitive speedup solution is to integrate the region proposal algorithm into the CNN model. Faster R-CNN (Ren et al., 2016) is doing exactly this: construct a single, unified model composed of RPN (region proposal network) and fast R-CNN with shared convolutional feature layers.\nFig. 7. An illustration of Faster R-CNN model. (Image source: Ren et al., 2016) Model Workflow Pre-train a CNN network on image classification tasks. Fine-tune the RPN (region proposal network) end-to-end for the region proposal task, which is initialized by the pre-train image classifier. Positive samples have IoU (intersection-over-union) \u0026gt; 0.7, while negative samples have IoU \u0026lt; 0.3. Slide a small n x n spatial window over the conv feature map of the entire image. At the center of each sliding window, we predict multiple regions of various scales and ratios simultaneously. An anchor is a combination of (sliding window center, scale, ratio). For example, 3 scales + 3 ratios =\u0026gt; k=9 anchors at each sliding position. Train a Fast R-CNN object detection model using the proposals generated by the current RPN Then use the Fast R-CNN network to initialize RPN training. While keeping the shared convolutional layers, only fine-tune the RPN-specific layers. At this stage, RPN and the detection network have shared convolutional layers! Finally fine-tune the unique layers of Fast R-CNN Step 4-5 can be repeated to train RPN and Fast R-CNN alternatively if needed. Loss Function Faster R-CNN is optimized for a multi-task loss function, similar to fast R-CNN.\n| Symbol | Explanation | | $p_i$ | Predicted probability of anchor i being an object. | | $p^*_i$ | Ground truth label (binary) of whether anchor i is an object. | | $t_i$ | Predicted four parameterized coordinates. | | $t^*_i$ | Ground truth coordinates. | | $N_\\text{cls}$ | Normalization term, set to be mini-batch size (~256) in the paper. | | $N_\\text{box}$ | Normalization term, set to the number of anchor locations (~2400) in the paper. | | $\\lambda$ | A balancing parameter, set to be ~10 in the paper (so that both $\\mathcal{L}_\\text{cls}$ and $\\mathcal{L}_\\text{box}$ terms are roughly equally weighted). | {:.info}\nThe multi-task loss function combines the losses of classification and bounding box regression:\n $$ \\begin{align*} \\mathcal{L} \u0026= \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box} \\\\ \\mathcal{L}(\\{p_i\\}, \\{t_i\\}) \u0026= \\frac{1}{N_\\text{cls}} \\sum_i \\mathcal{L}_\\text{cls} (p_i, p^*_i) + \\frac{\\lambda}{N_\\text{box}} \\sum_i p^*_i \\cdot L_1^\\text{smooth}(t_i - t^*_i) \\\\ \\end{align*} $$ where $\\mathcal{L}_\\text{cls}$ is the log loss function over two classes, as we can easily translate a multi-class classification into a binary classification by predicting a sample being a target object versus not. $L_1^\\text{smooth}$ is the smooth L1 loss.\n $$ \\mathcal{L}_\\text{cls} (p_i, p^*_i) = - p^*_i \\log p_i - (1 - p^*_i) \\log (1 - p_i) $$ Mask R-CNN Mask R-CNN (He et al., 2017) extends Faster R-CNN to pixel-level image segmentation. The key point is to decouple the classification and the pixel-level mask prediction tasks. Based on the framework of Faster R-CNN, it added a third branch for predicting an object mask in parallel with the existing branches for classification and localization. The mask branch is a small fully-connected network applied to each RoI, predicting a segmentation mask in a pixel-to-pixel manner.\nFig. 8. Mask R-CNN is Faster R-CNN model with image segmentation. (Image source: He et al., 2017) Because pixel-level segmentation requires much more fine-grained alignment than bounding boxes, mask R-CNN improves the RoI pooling layer (named \u0026ldquo;RoIAlign layer\u0026rdquo;) so that RoI can be better and more precisely mapped to the regions of the original image.\nFig. 9. Predictions by Mask R-CNN on COCO test set. (Image source: He et al., 2017) RoIAlign The RoIAlign layer is designed to fix the location misalignment caused by quantization in the RoI pooling. RoIAlign removes the hash quantization, for example, by using x/16 instead of [x/16], so that the extracted features can be properly aligned with the input pixels. Bilinear interpolation is used for computing the floating-point location values in the input.\nFig. 10. A region of interest is mapped **accurately** from the original image onto the feature map without rounding up to integers. (Image source: link) Loss Function The multi-task loss function of Mask R-CNN combines the loss of classification, localization and segmentation mask: $ \\mathcal{L} = \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box} + \\mathcal{L}_\\text{mask}$, where $\\mathcal{L}_\\text{cls}$ and $\\mathcal{L}_\\text{box}$ are same as in Faster R-CNN.\nThe mask branch generates a mask of dimension m x m for each RoI and each class; K classes in total. Thus, the total output is of size $K \\cdot m^2$. Because the model is trying to learn a mask for each class, there is no competition among classes for generating masks.\n$\\mathcal{L}_\\text{mask}$ is defined as the average binary cross-entropy loss, only including k-th mask if the region is associated with the ground truth class k.\n $$ \\mathcal{L}_\\text{mask} = - \\frac{1}{m^2} \\sum_{1 \\leq i, j \\leq m} \\big[ y_{ij} \\log \\hat{y}^k_{ij} + (1-y_{ij}) \\log (1- \\hat{y}^k_{ij}) \\big] $$ where $y_{ij}$ is the label of a cell (i, j) in the true mask for the region of size m x m; $\\hat{y}_{ij}^k$ is the predicted value of the same cell in the mask learned for the ground-truth class k.\nSummary of Models in the R-CNN family Here I illustrate model designs of R-CNN, Fast R-CNN, Faster R-CNN and Mask R-CNN. You can track how one model evolves to the next version by comparing the small differences.\n Cited as:\n@article{weng2017detection3, title = \u0026quot;Object Detection for Dummies Part 3: R-CNN Family\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-12-31-object-recognition-part-3/\u0026quot; } Reference [1] Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik. \u0026ldquo;Rich feature hierarchies for accurate object detection and semantic segmentation.\u0026quot; In Proc. IEEE Conf. on computer vision and pattern recognition (CVPR), pp. 580-587. 2014.\n[2] Ross Girshick. \u0026ldquo;Fast R-CNN.\u0026quot; In Proc. IEEE Intl. Conf. on computer vision, pp. 1440-1448. 2015.\n[3] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. \u0026ldquo;Faster R-CNN: Towards real-time object detection with region proposal networks.\u0026quot; In Advances in neural information processing systems (NIPS), pp. 91-99. 2015.\n[4] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. \u0026ldquo;Mask R-CNN.\u0026quot; arXiv preprint arXiv:1703.06870, 2017.\n[5] Joseph Redmon, Santosh Divvala, Ross Girshick, and Ali Farhadi. \u0026ldquo;You only look once: Unified, real-time object detection.\u0026quot; In Proc. IEEE Conf. on computer vision and pattern recognition (CVPR), pp. 779-788. 2016.\n[6] \u0026ldquo;A Brief History of CNNs in Image Segmentation: From R-CNN to Mask R-CNN\u0026rdquo; by Athelas.\n[7] Smooth L1 Loss: https://github.com/rbgirshick/py-faster-rcnn/files/764206/SmoothL1Loss.1.pdf\n","permalink":"https://lilianweng.github.io/posts/2017-12-31-object-recognition-part-3/","summary":"[Updated on 2018-12-20: Remove YOLO here. Part 4 will cover multiple fast object detection algorithms, including YOLO.] [Updated on 2018-12-27: Add bbox regression and tricks sections for R-CNN.]\nIn the series of \u0026ldquo;Object Detection for Dummies\u0026rdquo;, we started with basic concepts in image processing, such as gradient vectors and HOG, in Part 1. Then we introduced classic convolutional neural network architecture designs for classification and pioneer models for object recognition, Overfeat and DPM, in Part 2.","title":"Object Detection for Dummies Part 3: R-CNN Family"},{"content":"Part 1 of the \u0026ldquo;Object Detection for Dummies\u0026rdquo; series introduced: (1) the concept of image gradient vector and how HOG algorithm summarizes the information across all the gradient vectors in one image; (2) how the image segmentation algorithm works to detect regions that potentially contain objects; (3) how the Selective Search algorithm refines the outcomes of image segmentation for better region proposal.\nIn Part 2, we are about to find out more on the classic convolution neural network architectures for image classification. They lay the foundation for further progress on the deep learning models for object detection. Go check Part 3 if you want to learn more on R-CNN and related models.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nCNN for Image Classification CNN, short for \u0026ldquo;Convolutional Neural Network\u0026rdquo;, is the go-to solution for computer vision problems in the deep learning world. It was, to some extent, inspired by how human visual cortex system works.\nConvolution Operation I strongly recommend this guide to convolution arithmetic, which provides a clean and solid explanation with tons of visualizations and examples. Here let\u0026rsquo;s focus on two-dimensional convolution as we are working with images in this post.\nIn short, convolution operation slides a predefined kernel (also called \u0026ldquo;filter\u0026rdquo;) on top of the input feature map (matrix of image pixels), multiplying and adding the values of the kernel and partial input features to generate the output. The values form an output matrix, as usually, the kernel is much smaller than the input image.\nFig. 1. An illustration of applying a kernel on the input feature map to generate the output. (Image source: River Trail documentation) Figure 2 showcases two real examples of how to convolve a 3x3 kernel over a 5x5 2D matrix of numeric values to generate a 3x3 matrix. By controlling the padding size and the stride length, we can generate an output matrix of a certain size.\nFig. 2. Two examples of 2D convolution operation: (top) no padding and 1x1 strides; (bottom) 1x1 border zeros padding and 2x2 strides. (Image source: deeplearning.net) AlexNet (Krizhevsky et al, 2012) 5 convolution [+ optional max pooling] layers + 2 MLP layers + 1 LR layer Use data augmentation techniques to expand the training dataset, such as image translations, horizontal reflections, and patch extractions. Fig. 3. The architecture of AlexNet. (Image source: link) VGG (Simonyan and Zisserman, 2014) The network is considered as \u0026ldquo;very deep\u0026rdquo; at its time; 19 layers The architecture is extremely simplified with only 3x3 convolutional layers and 2x2 pooling layers. The stacking of small filters simulates a larger filter with fewer parameters. ResNet (He et al., 2015) The network is indeed very deep; 152 layers of simple architecture. Residual Block: Some input of a certain layer can be passed to the component two layers later. Residual blocks are essential for keeping a deep network trainable and eventually work. Without residual blocks, the training loss of a plain network does not monotonically decrease as the number of layers increases due to vanishing and exploding gradients. Fig. 4. An illustration of the residual block of ResNet. In some way, we can say the design of residual blocks is inspired by V4 getting input directly from V1 in the human visual cortex system. (left image source: Wang et al., 2017) Evaluation Metrics: mAP A common evaluation metric used in many object recognition and detection tasks is \u0026ldquo;mAP\u0026rdquo;, short for \u0026ldquo;mean average precision\u0026rdquo;. It is a number from 0 to 100; higher value is better.\n Combine all detections from all test images to draw a precision-recall curve (PR curve) for each class; The \u0026ldquo;average precision\u0026rdquo; (AP) is the area under the PR curve. Given that target objects are in different classes, we first compute AP separately for each class, and then average over classes. A detection is a true positive if it has \u0026ldquo;intersection over union\u0026rdquo; (IoU) with a ground-truth box greater than some threshold (usually 0.5; if so, the metric is \u0026ldquo;mAP@0.5\u0026rdquo;) Deformable Parts Model The Deformable Parts Model (DPM) (Felzenszwalb et al., 2010) recognizes objects with a mixture graphical model (Markov random fields) of deformable parts. The model consists of three major components:\n A coarse root filter defines a detection window that approximately covers an entire object. A filter specifies weights for a region feature vector. Multiple part filters that cover smaller parts of the object. Parts filters are learned at twice resolution of the root filter. A spatial model for scoring the locations of part filters relative to the root. Fig. 5. The DPM model contains (a) a root filter, (b) multiple part filters at twice the resolution, and (c) a model for scoring the location and deformation of parts. The quality of detecting an object is measured by the score of filters minus the deformation costs. The matching score $f$, in laymen\u0026rsquo;s terms, is:\n $$ f(\\text{model}, x) = f(\\beta_\\text{root}, x) + \\sum_{\\beta_\\text{part} \\in \\text{part filters}} \\max_y [f(\\beta_\\text{part}, y) - \\text{cost}(\\beta_\\text{part}, x, y)] $$ in which,\n $x$ is an image with a specified position and scale; $y$ is a sub region of $x$. $\\beta_\\text{root}$ is the root filter. $\\beta_\\text{part}$ is one part filter. cost() measures the penalty of the part deviating from its ideal location relative to the root. The basic score model is the dot product between the filter $\\beta$ and the region feature vector $\\Phi(x)$: $f(\\beta, x) = \\beta \\cdot \\Phi(x)$. The feature set $\\Phi(x)$ can be defined by HOG or other similar algorithms.\nA root location with high score detects a region with high chances to contain an object, while the locations of the parts with high scores confirm a recognized object hypothesis. The paper adopted latent SVM to model the classifier.\nFig. 6. The matching process by DPM. (Image source: Felzenszwalb et al., 2010) The author later claimed that DPM and CNN models are not two distinct approaches to object recognition. Instead, a DPM model can be formulated as a CNN by unrolling the DPM inference algorithm and mapping each step to an equivalent CNN layer. (Check the details in Girshick et al., 2015!)\nOverfeat Overfeat [paper][code] is a pioneer model of integrating the object detection, localization and classification tasks all into one convolutional neural network. The main idea is to (i) do image classification at different locations on regions of multiple scales of the image in a sliding window fashion, and (ii) predict the bounding box locations with a regressor trained on top of the same convolution layers.\nThe Overfeat model architecture is very similar to AlexNet. It is trained as follows:\nFig. 7. The training stages of the Overfeat model. (Image source: link) Train a CNN model (similar to AlexNet) on the image classification task. Then, we replace the top classifier layers by a regression network and train it to predict object bounding boxes at each spatial location and scale. The regressor is class-specific, each generated for one image class. Input: Images with classification and bounding box. Output: $(x_\\text{left}, x_\\text{right}, y_\\text{top}, y_\\text{bottom})$, 4 values in total, representing the coordinates of the bounding box edges. Loss: The regressor is trained to minimize $l2$ norm between generated bounding box and the ground truth for each training example. At the detection time,\n Perform classification at each location using the pretrained CNN model. Predict object bounding boxes on all classified regions generated by the classifier. Merge bounding boxes with sufficient overlap from localization and sufficient confidence of being the same object from the classifier. Cited as:\n@article{weng2017detection2, title = \u0026quot;Object Detection for Dummies Part 2: CNN, DPM and Overfeat\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-12-15-object-recognition-part-2/\u0026quot; } Reference [1] Vincent Dumoulin and Francesco Visin. \u0026ldquo;A guide to convolution arithmetic for deep learning.\u0026quot; arXiv preprint arXiv:1603.07285 (2016).\n[2] Haohan Wang, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800 (2017).\n[3] Pedro F. Felzenszwalb, Ross B. Girshick, David McAllester, and Deva Ramanan. \u0026ldquo;Object detection with discriminatively trained part-based models.\u0026quot; IEEE transactions on pattern analysis and machine intelligence 32, no. 9 (2010): 1627-1645.\n[4] Ross B. Girshick, Forrest Iandola, Trevor Darrell, and Jitendra Malik. \u0026ldquo;Deformable part models are convolutional neural networks.\u0026quot; In Proc. IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 437-446. 2015.\n[5] Sermanet, Pierre, David Eigen, Xiang Zhang, Michaël Mathieu, Rob Fergus, and Yann LeCun. \u0026ldquo;OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks\u0026rdquo; arXiv preprint arXiv:1312.6229 (2013).\n","permalink":"https://lilianweng.github.io/posts/2017-12-15-object-recognition-part-2/","summary":"Part 1 of the \u0026ldquo;Object Detection for Dummies\u0026rdquo; series introduced: (1) the concept of image gradient vector and how HOG algorithm summarizes the information across all the gradient vectors in one image; (2) how the image segmentation algorithm works to detect regions that potentially contain objects; (3) how the Selective Search algorithm refines the outcomes of image segmentation for better region proposal.\nIn Part 2, we are about to find out more on the classic convolution neural network architectures for image classification.","title":"Object Detection for Dummies Part 2: CNN, DPM and Overfeat"},{"content":"I\u0026rsquo;ve never worked in the field of computer vision and has no idea how the magic could work when an autonomous car is configured to tell apart a stop sign from a pedestrian in a red hat. To motivate myself to look into the maths behind object recognition and detection algorithms, I\u0026rsquo;m writing a few posts on this topic \u0026ldquo;Object Detection for Dummies\u0026rdquo;. This post, part 1, starts with super rudimentary concepts in image processing and a few methods for image segmentation. Nothing related to deep neural networks yet. Deep learning models for object detection and recognition will be discussed in Part 2 and Part 3.\n Disclaimer: When I started, I was using \u0026ldquo;object recognition\u0026rdquo; and \u0026ldquo;object detection\u0026rdquo; interchangeably. I don\u0026rsquo;t think they are the same: the former is more about telling whether an object exists in an image while the latter needs to spot where the object is. However, they are highly related and many object recognition algorithms lay the foundation for detection.\n Links to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nImage Gradient Vector First of all, I would like to make sure we can distinguish the following terms. They are very similar, closely related, but not exactly the same.\n Derivative Directional Derivative Gradient Value type Scalar Scalar Vector Definition The rate of change of a function $f(x,y,z,\u0026hellip;)$ at a point $(x_0,y_0,z_0,\u0026hellip;)$, which is the slope of the tangent line at the point. The instantaneous rate of change of $f(x,y,z, \u0026hellip;)$ in the direction of an unit vector $\\vec{u}$. It points in the direction of the greatest rate of increase of the function, containing all the partial derivative information of a multivariable function. In the image processing, we want to know the direction of colors changing from one extreme to the other (i.e. black to white on a grayscale image). Therefore, we want to measure \u0026ldquo;gradient\u0026rdquo; on pixels of colors. The gradient on an image is discrete because each pixel is independent and cannot be further split.\nThe image gradient vector is defined as a metric for every individual pixel, containing the pixel color changes in both x-axis and y-axis. The definition is aligned with the gradient of a continuous multi-variable function, which is a vector of partial derivatives of all the variables. Suppose f(x, y) records the color of the pixel at location (x, y), the gradient vector of the pixel (x, y) is defined as follows:\n $$ \\begin{align*} \\nabla f(x, y) = \\begin{bmatrix} g_x \\\\ g_y \\end{bmatrix} = \\begin{bmatrix} \\frac{\\partial f}{\\partial x} \\\\[6pt] \\frac{\\partial f}{\\partial y} \\end{bmatrix} = \\begin{bmatrix} f(x+1, y) - f(x-1, y)\\\\ f(x, y+1) - f(x, y-1) \\end{bmatrix} \\end{align*} $$ The $\\frac{\\partial f}{\\partial x}$ term is the partial derivative on the x-direction, which is computed as the color difference between the adjacent pixels on the left and right of the target, f(x+1, y) - f(x-1, y). Similarly, the $\\frac{\\partial f}{\\partial y}$ term is the partial derivative on the y-direction, measured as f(x, y+1) - f(x, y-1), the color difference between the adjacent pixels above and below the target.\nThere are two important attributes of an image gradient:\n Magnitude is the L2-norm of the vector, $g = \\sqrt{ g_x^2 + g_y^2 }$. Direction is the arctangent of the ratio between the partial derivatives on two directions, $\\theta = \\arctan{(g_y / g_x)}$. Fig. 1. To compute the gradient vector of a target pixel at location (x, y), we need to know the colors of its four neighbors (or eight surrounding pixels depending on the kernel). The gradient vector of the example in Fig. 1. is:\n $$ \\begin{align*} \\nabla f = \\begin{bmatrix} f(x+1, y) - f(x-1, y)\\\\ f(x, y+1) - f(x, y-1) \\end{bmatrix} = \\begin{bmatrix} 55-105\\\\ 90-40 \\end{bmatrix} = \\begin{bmatrix} -50\\\\ 50 \\end{bmatrix} \\end{align*} $$ Thus,\n the magnitude is $\\sqrt{50^2 + (-50)^2} = 70.7107$, and the direction is $\\arctan{(-50/50)} = -45^{\\circ}$. Repeating the gradient computation process for every pixel iteratively is too slow. Instead, it can be well translated into applying a convolution operator on the entire image matrix, labeled as $\\mathbf{A}$ using one of the specially designed convolutional kernels.\nLet\u0026rsquo;s start with the x-direction of the example in Fig 1. using the kernel $[-1,0,1]$ sliding over the x-axis; $\\ast$ is the convolution operator:\n $$ \\begin{align*} \\mathbf{G}_x \u0026= [-1, 0, 1] \\ast [105, 255, 55] = -105 + 0 + 55 = -50 \\end{align*} $$ Similarly, on the y-direction, we adopt the kernel $[+1, 0, -1]^\\top$:\n $$ \\begin{align*} \\mathbf{G}_y \u0026= [+1, 0, -1]^\\top \\ast \\begin{bmatrix} 90\\\\ 255\\\\ 40 \\end{bmatrix} = 90 + 0 - 40 = 50 \\end{align*} $$ Try this in python:\nimport numpy as np import scipy.signal as sig data = np.array([[0, 105, 0], [40, 255, 90], [0, 55, 0]]) G_x = sig.convolve2d(data, np.array([[-1, 0, 1]]), mode=\u0026#39;valid\u0026#39;) G_y = sig.convolve2d(data, np.array([[-1], [0], [1]]), mode=\u0026#39;valid\u0026#39;) These two functions return array([[0], [-50], [0]]) and array([[0, 50, 0]]) respectively. (Note that in the numpy array representation, 40 is shown in front of 90, so -1 is listed before 1 in the kernel correspondingly.)\nCommon Image Processing Kernels Prewitt operator: Rather than only relying on four directly adjacent neighbors, the Prewitt operator utilizes eight surrounding pixels for smoother results.\n $$ \\mathbf{G}_x = \\begin{bmatrix} -1 \u0026 0 \u0026 +1 \\\\ -1 \u0026 0 \u0026 +1 \\\\ -1 \u0026 0 \u0026 +1 \\end{bmatrix} \\ast \\mathbf{A} \\text{ and } \\mathbf{G}_y = \\begin{bmatrix} +1 \u0026 +1 \u0026 +1 \\\\ 0 \u0026 0 \u0026 0 \\\\ -1 \u0026 -1 \u0026 -1 \\end{bmatrix} \\ast \\mathbf{A} $$ Sobel operator: To emphasize the impact of directly adjacent pixels more, they get assigned with higher weights.\n $$ \\mathbf{G}_x = \\begin{bmatrix} -1 \u0026 0 \u0026 +1 \\\\ -2 \u0026 0 \u0026 +2 \\\\ -1 \u0026 0 \u0026 +1 \\end{bmatrix} \\ast \\mathbf{A} \\text{ and } \\mathbf{G}_y = \\begin{bmatrix} +1 \u0026 +2 \u0026 +1 \\\\ 0 \u0026 0 \u0026 0 \\\\ -1 \u0026 -2 \u0026 -1 \\end{bmatrix} \\ast \\mathbf{A} $$ Different kernels are created for different goals, such as edge detection, blurring, sharpening and many more. Check this wiki page for more examples and references.\nExample: Manu in 2004 Let\u0026rsquo;s run a simple experiment on the photo of Manu Ginobili in 2004 [[Download Image]({{ \u0026lsquo;/assets/data/manu-2004.jpg\u0026rsquo; | relative_url }}){:target=\u0026quot;_blank\u0026quot;}] when he still had a lot of hair. For simplicity, the photo is converted to grayscale first. For colored images, we just need to repeat the same process in each color channel respectively.\nFig. 2. Manu Ginobili in 2004 with hair. (Image source: Manu Ginobili's bald spot through the years) import numpy as np import scipy import scipy.signal as sig # With mode=\u0026#34;L\u0026#34;, we force the image to be parsed in the grayscale, so it is # actually unnecessary to convert the photo color beforehand. img = scipy.misc.imread(\u0026#34;manu-2004.jpg\u0026#34;, mode=\u0026#34;L\u0026#34;) # Define the Sobel operator kernels. kernel_x = np.array([[-1, 0, 1],[-2, 0, 2],[-1, 0, 1]]) kernel_y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) G_x = sig.convolve2d(img, kernel_x, mode=\u0026#39;same\u0026#39;) G_y = sig.convolve2d(img, kernel_y, mode=\u0026#39;same\u0026#39;) # Plot them! fig = plt.figure() ax1 = fig.add_subplot(121) ax2 = fig.add_subplot(122) # Actually plt.imshow() can handle the value scale well even if I don\u0026#39;t do # the transformation (G_x + 255) / 2. ax1.imshow((G_x + 255) / 2, cmap=\u0026#39;gray\u0026#39;); ax1.set_xlabel(\u0026#34;Gx\u0026#34;) ax2.imshow((G_y + 255) / 2, cmap=\u0026#39;gray\u0026#39;); ax2.set_xlabel(\u0026#34;Gy\u0026#34;) plt.show() Fig. 3. Apply Sobel operator kernel on the example image. You might notice that most area is in gray. Because the difference between two pixel is between -255 and 255 and we need to convert them back to [0, 255] for the display purpose. A simple linear transformation ($\\mathbf{G}$ + 255)/2 would interpret all the zeros (i.e., constant colored background shows no change in gradient) as 125 (shown as gray).\nHistogram of Oriented Gradients (HOG) The Histogram of Oriented Gradients (HOG) is an efficient way to extract features out of the pixel colors for building an object recognition classifier. With the knowledge of image gradient vectors, it is not hard to understand how HOG works. Let\u0026rsquo;s start!\nHow HOG works Preprocess the image, including resizing and color normalization.\n Compute the gradient vector of every pixel, as well as its magnitude and direction.\n Divide the image into many 8x8 pixel cells. In each cell, the magnitude values of these 64 cells are binned and cumulatively added into 9 buckets of unsigned direction (no sign, so 0-180 degree rather than 0-360 degree; this is a practical choice based on empirical experiments). For better robustness, if the direction of the gradient vector of a pixel lays between two buckets, its magnitude does not all go into the closer one but proportionally split between two. For example, if a pixel\u0026rsquo;s gradient vector has magnitude 8 and degree 15, it is between two buckets for degree 0 and 20 and we would assign 2 to bucket 0 and 6 to bucket 20. This interesting configuration makes the histogram much more stable when small distortion is applied to the image.\n Fig. 4. How to split one gradient vector's magnitude if its degress is between two degree bins. (Image source: https://www.learnopencv.com/histogram-of-oriented-gradients/) Then we slide a 2x2 cells (thus 16x16 pixels) block across the image. In each block region, 4 histograms of 4 cells are concatenated into one-dimensional vector of 36 values and then normalized to have an unit weight. The final HOG feature vector is the concatenation of all the block vectors. It can be fed into a classifier like SVM for learning object recognition tasks. Example: Manu in 2004 Let\u0026rsquo;s reuse the same example image in the previous section. Remember that we have computed $\\mathbf{G}_x$ and $\\mathbf{G}_y$ for the whole image.\nN_BUCKETS = 9 CELL_SIZE = 8 # Each cell is 8x8 pixels BLOCK_SIZE = 2 # Each block is 2x2 cells def assign_bucket_vals(m, d, bucket_vals): left_bin = int(d / 20.) # Handle the case when the direction is between [160, 180) right_bin = (int(d / 20.) + 1) % N_BUCKETS assert 0 \u0026lt;= left_bin \u0026lt; right_bin \u0026lt; N_BUCKETS left_val= m * (right_bin * 20 - d) / 20 right_val = m * (d - left_bin * 20) / 20 bucket_vals[left_bin] += left_val bucket_vals[right_bin] += right_val def get_magnitude_hist_cell(loc_x, loc_y): # (loc_x, loc_y) defines the top left corner of the target cell. cell_x = G_x[loc_x:loc_x + CELL_SIZE, loc_y:loc_y + CELL_SIZE] cell_y = G_y[loc_x:loc_x + CELL_SIZE, loc_y:loc_y + CELL_SIZE] magnitudes = np.sqrt(cell_x * cell_x + cell_y * cell_y) directions = np.abs(np.arctan(cell_y / cell_x) * 180 / np.pi) buckets = np.linspace(0, 180, N_BUCKETS + 1) bucket_vals = np.zeros(N_BUCKETS) map( lambda (m, d): assign_bucket_vals(m, d, bucket_vals), zip(magnitudes.flatten(), directions.flatten()) ) return bucket_vals def get_magnitude_hist_block(loc_x, loc_y): # (loc_x, loc_y) defines the top left corner of the target block. return reduce( lambda arr1, arr2: np.concatenate((arr1, arr2)), [get_magnitude_hist_cell(x, y) for x, y in zip( [loc_x, loc_x + CELL_SIZE, loc_x, loc_x + CELL_SIZE], [loc_y, loc_y, loc_y + CELL_SIZE, loc_y + CELL_SIZE], )] ) The following code simply calls the functions to construct a histogram and plot it.\n# Random location [200, 200] as an example. loc_x = loc_y = 200 ydata = get_magnitude_hist_block(loc_x, loc_y) ydata = ydata / np.linalg.norm(ydata) xdata = range(len(ydata)) bucket_names = np.tile(np.arange(N_BUCKETS), BLOCK_SIZE * BLOCK_SIZE) assert len(ydata) == N_BUCKETS * (BLOCK_SIZE * BLOCK_SIZE) assert len(bucket_names) == len(ydata) plt.figure(figsize=(10, 3)) plt.bar(xdata, ydata, align=\u0026#39;center\u0026#39;, alpha=0.8, width=0.9) plt.xticks(xdata, bucket_names * 20, rotation=90) plt.xlabel(\u0026#39;Direction buckets\u0026#39;) plt.ylabel(\u0026#39;Magnitude\u0026#39;) plt.grid(ls=\u0026#39;--\u0026#39;, color=\u0026#39;k\u0026#39;, alpha=0.1) plt.title(\u0026#34;HOG of block at [%d, %d]\u0026#34; % (loc_x, loc_y)) plt.tight_layout() In the code above, I use the block with top left corner located at [200, 200] as an example and here is the final normalized histogram of this block. You can play with the code to change the block location to be identified by a sliding window.\nFig. 5. Demonstration of a HOG histogram for one block. The code is mostly for demonstrating the computation process. There are many off-the-shelf libraries with HOG algorithm implemented, such as OpenCV, SimpleCV and scikit-image.\nImage Segmentation (Felzenszwalb\u0026rsquo;s Algorithm) When there exist multiple objects in one image (true for almost every real-world photos), we need to identify a region that potentially contains a target object so that the classification can be executed more efficiently.\nFelzenszwalb and Huttenlocher (2004) proposed an algorithm for segmenting an image into similar regions using a graph-based approach. It is also the initialization method for Selective Search (a popular region proposal algorithm) that we are gonna discuss later.\nSay, we use a undirected graph $G=(V, E)$ to represent an input image. One vertex $v_i \\in V$ represents one pixel. One edge $e = (v_i, v_j) \\in E$ connects two vertices $v_i$ and $v_j$. Its associated weight $w(v_i, v_j)$ measures the dissimilarity between $v_i$ and $v_j$. The dissimilarity can be quantified in dimensions like color, location, intensity, etc. The higher the weight, the less similar two pixels are. A segmentation solution $S$ is a partition of $V$ into multiple connected components, $\\{C\\}$. Intuitively similar pixels should belong to the same components while dissimilar ones are assigned to different components.\nGraph Construction There are two approaches to constructing a graph out of an image.\n Grid Graph: Each pixel is only connected with surrounding neighbours (8 other cells in total). The edge weight is the absolute difference between the intensity values of the pixels. Nearest Neighbor Graph: Each pixel is a point in the feature space (x, y, r, g, b), in which (x, y) is the pixel location and (r, g, b) is the color values in RGB. The weight is the Euclidean distance between two pixels' feature vectors. Key Concepts Before we lay down the criteria for a good graph partition (aka image segmentation), let us define a couple of key concepts:\n Internal difference: $Int(C) = \\max_{e\\in MST(C, E)} w(e)$, where $MST$ is the minimum spanning tree of the components. A component $C$ can still remain connected even when we have removed all the edges with weights \u0026lt; $Int(C)$. Difference between two components: $Dif(C_1, C_2) = \\min_{v_i \\in C_1, v_j \\in C_2, (v_i, v_j) \\in E} w(v_i, v_j)$. $Dif(C_1, C_2) = \\infty$ if there is no edge in-between. Minimum internal difference: $MInt(C_1, C_2) = min(Int(C_1) + \\tau(C_1), Int(C_2) + \\tau(C_2))$, where $\\tau(C) = k / \\vert C \\vert$ helps make sure we have a meaningful threshold for the difference between components. With a higher $k$, it is more likely to result in larger components. The quality of a segmentation is assessed by a pairwise region comparison predicate defined for given two regions $C_1$ and $C_2$:\n $$ D(C_1, C_2) = \\begin{cases} \\text{True} \u0026 \\text{ if } Dif(C_1, C_2) MInt(C_1, C_2) \\\\ \\text{False} \u0026 \\text{ otherwise} \\end{cases} $$ Only when the predicate holds True, we consider them as two independent components; otherwise the segmentation is too fine and they probably should be merged.\nHow Image Segmentation Works The algorithm follows a bottom-up procedure. Given $G=(V, E)$ and $|V|=n, |E|=m$:\n Edges are sorted by weight in ascending order, labeled as $e_1, e_2, \\dots, e_m$. Initially, each pixel stays in its own component, so we start with $n$ components. Repeat for $k=1, \\dots, m$: The segmentation snapshot at the step $k$ is denoted as $S^k$. We take the k-th edge in the order, $e_k = (v_i, v_j)$. If $v_i$ and $v_j$ belong to the same component, do nothing and thus $S^k = S^{k-1}$. If $v_i$ and $v_j$ belong to two different components $C_i^{k-1}$ and $C_j^{k-1}$ as in the segmentation $S^{k-1}$, we want to merge them into one if $w(v_i, v_j) \\leq MInt(C_i^{k-1}, C_j^{k-1})$; otherwise do nothing. If you are interested in the proof of the segmentation properties and why it always exists, please refer to the paper.\nFig. 6. An indoor scene with segmentation detected by the grid graph construction in Felzenszwalb's graph-based segmentation algorithm (k=300). Example: Manu in 2013 This time I would use the photo of old Manu Ginobili in 2013 [[Image]({{ \u0026lsquo;/assets/data/manu-2013.jpg\u0026rsquo; | relative_url }})] as the example image when his bald spot has grown up strong. Still for simplicity, we use the picture in grayscale.\nFig. 7. Manu Ginobili in 2013 with bald spot. (Image source: Manu Ginobili's bald spot through the years) Rather than coding from scratch, let us apply skimage.segmentation.felzenszwalb to the image.\nimport skimage.segmentation from matplotlib import pyplot as plt img2 = scipy.misc.imread(\u0026#34;manu-2013.jpg\u0026#34;, mode=\u0026#34;L\u0026#34;) segment_mask1 = skimage.segmentation.felzenszwalb(img2, scale=100) segment_mask2 = skimage.segmentation.felzenszwalb(img2, scale=1000) fig = plt.figure(figsize=(12, 5)) ax1 = fig.add_subplot(121) ax2 = fig.add_subplot(122) ax1.imshow(segment_mask1); ax1.set_xlabel(\u0026#34;k=100\u0026#34;) ax2.imshow(segment_mask2); ax2.set_xlabel(\u0026#34;k=1000\u0026#34;) fig.suptitle(\u0026#34;Felsenszwalb\u0026#39;s efficient graph based image segmentation\u0026#34;) plt.tight_layout() plt.show() The code ran two versions of Felzenszwalb\u0026rsquo;s algorithms as shown in Fig. 8. The left k=100 generates a finer-grained segmentation with small regions where Manu\u0026rsquo;s bald spot is identified. The right one k=1000 outputs a coarser-grained segmentation where regions tend to be larger.\nFig. 8. Felsenszwalb's efficient graph-based image segmentation is applied on the photo of Manu in 2013. Selective Search Selective search is a common algorithm to provide region proposals that potentially contain objects. It is built on top of the image segmentation output and use region-based characteristics (NOTE: not just attributes of a single pixel) to do a bottom-up hierarchical grouping.\nHow Selective Search Works At the initialization stage, apply Felzenszwalb and Huttenlocher\u0026rsquo;s graph-based image segmentation algorithm to create regions to start with. Use a greedy algorithm to iteratively group regions together: First the similarities between all neighbouring regions are calculated. The two most similar regions are grouped together, and new similarities are calculated between the resulting region and its neighbours. The process of grouping the most similar regions (Step 2) is repeated until the whole image becomes a single region. Fig. 9. The detailed algorithm of Selective Search. Configuration Variations Given two regions $(r_i, r_j)$, selective search proposed four complementary similarity measures:\n Color similarity Texture: Use algorithm that works well for material recognition such as SIFT. Size: Small regions are encouraged to merge early. Shape: Ideally one region can fill the gap of the other. By (i) tuning the threshold $k$ in Felzenszwalb and Huttenlocher\u0026rsquo;s algorithm, (ii) changing the color space and (iii) picking different combinations of similarity metrics, we can produce a diverse set of Selective Search strategies. The version that produces the region proposals with best quality is configured with (i) a mixture of various initial segmentation proposals, (ii) a blend of multiple color spaces and (iii) a combination of all similarity measures. Unsurprisingly we need to balance between the quality (the model complexity) and the speed.\n Cited as:\n@article{weng2017detection1, title = \u0026quot;Object Detection for Dummies Part 1: Gradient Vector, HOG, and SS\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-10-29-object-recognition-part-1/\u0026quot; } References [1] Dalal, Navneet, and Bill Triggs. \u0026ldquo;Histograms of oriented gradients for human detection.\u0026quot; Computer Vision and Pattern Recognition (CVPR), 2005.\n[2] Pedro F. Felzenszwalb, and Daniel P. Huttenlocher. \u0026ldquo;Efficient graph-based image segmentation.\u0026quot; Intl. journal of computer vision 59.2 (2004): 167-181.\n[3] Histogram of Oriented Gradients by Satya Mallick\n[4] Gradient Vectors by Chris McCormick\n[5] HOG Person Detector Tutorial by Chris McCormick\n","permalink":"https://lilianweng.github.io/posts/2017-10-29-object-recognition-part-1/","summary":"I\u0026rsquo;ve never worked in the field of computer vision and has no idea how the magic could work when an autonomous car is configured to tell apart a stop sign from a pedestrian in a red hat. To motivate myself to look into the maths behind object recognition and detection algorithms, I\u0026rsquo;m writing a few posts on this topic \u0026ldquo;Object Detection for Dummies\u0026rdquo;. This post, part 1, starts with super rudimentary concepts in image processing and a few methods for image segmentation.","title":"Object Detection for Dummies Part 1: Gradient Vector, HOG, and SS"},{"content":"Human vocabulary comes in free text. In order to make a machine learning model understand and process the natural language, we need to transform the free-text words into numeric values. One of the simplest transformation approaches is to do a one-hot encoding in which each distinct word stands for one dimension of the resulting vector and a binary value indicates whether the word presents (1) or not (0).\nHowever, one-hot encoding is impractical computationally when dealing with the entire vocabulary, as the representation demands hundreds of thousands of dimensions. Word embedding represents words and phrases in vectors of (non-binary) numeric values with much lower and thus denser dimensions. An intuitive assumption for good word embedding is that they can approximate the similarity between words (i.e., \u0026ldquo;cat\u0026rdquo; and \u0026ldquo;kitten\u0026rdquo; are similar words, and thus they are expected to be close in the reduced vector space) or disclose hidden semantic relationships (i.e., the relationship between \u0026ldquo;cat\u0026rdquo; and \u0026ldquo;kitten\u0026rdquo; is an analogy to the one between \u0026ldquo;dog\u0026rdquo; and \u0026ldquo;puppy\u0026rdquo;). Contextual information is super useful for learning word meaning and relationship, as similar words may appear in the similar context often.\nThere are two main approaches for learning word embedding, both relying on the contextual knowledge.\n Count-based: The first one is unsupervised, based on matrix factorization of a global word co-occurrence matrix. Raw co-occurrence counts do not work well, so we want to do smart things on top. Context-based: The second approach is supervised. Given a local context, we want to design a model to predict the target words and in the meantime, this model learns the efficient word embedding representation. Count-Based Vector Space Model Count-based vector space models heavily rely on the word frequency and co-occurrence matrix with the assumption that words in the same contexts share similar or related semantic meanings. The models map count-based statistics like co-occurrences between neighboring words down to a small and dense word vectors. PCA, topic models, and neural probabilistic language models are all good examples of this category.\n Different from the count-based approaches, context-based methods build predictive models that directly target at predicting a word given its neighbors. The dense word vectors are part of the model parameters. The best vector representation of each word is learned during the model training process.\nContext-Based: Skip-Gram Model Suppose that you have a sliding window of a fixed size moving along a sentence: the word in the middle is the \u0026ldquo;target\u0026rdquo; and those on its left and right within the sliding window are the context words. The skip-gram model (Mikolov et al., 2013) is trained to predict the probabilities of a word being a context word for the given target.\nThe following example demonstrates multiple pairs of target and context words as training samples, generated by a 5-word window sliding along the sentence.\n \u0026ldquo;The man who passes the sentence should swing the sword.\u0026rdquo; \u0026ndash; Ned Stark\n Sliding window (size = 5) Target word Context [The man who] the man, who [The man who passes] man the, who, passes [The man who passes the] who the, man, passes, the [man who passes the sentence] passes man, who, the, sentence \u0026hellip; \u0026hellip; \u0026hellip; [sentence should swing the sword] swing sentence, should, the, sword [should swing the sword] the should, swing, sword [swing the sword] sword swing, the {:.info} Each context-target pair is treated as a new observation in the data. For example, the target word \u0026ldquo;swing\u0026rdquo; in the above case produces four training samples: (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;sentence\u0026rdquo;), (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;should\u0026rdquo;), (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;the\u0026rdquo;), and (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;sword\u0026rdquo;).\nFig. 1. The skip-gram model. Both the input vector $\\mathbf{x}$ and the output $\\mathbf{y}$ are one-hot encoded word representations. The hidden layer is the word embedding of size $N$. Given the vocabulary size $V$, we are about to learn word embedding vectors of size $N$. The model learns to predict one context word (output) using one target word (input) at a time.\nAccording to Fig. 1,\n Both input word $w_i$ and the output word $w_j$ are one-hot encoded into binary vectors $\\mathbf{x}$ and $\\mathbf{y}$ of size $V$. First, the multiplication of the binary vector $\\mathbf{x}$ and the word embedding matrix $W$ of size $V \\times N$ gives us the embedding vector of the input word $w_i$: the i-th row of the matrix $W$. This newly discovered embedding vector of dimension $N$ forms the hidden layer. The multiplication of the hidden layer and the word context matrix $W’$ of size $N \\times V$ produces the output one-hot encoded vector $\\mathbf{y}$. The output context matrix $W’$ encodes the meanings of words as context, different from the embedding matrix $W$. NOTE: Despite the name, $W’$ is independent of $W$, not a transpose or inverse or whatsoever. Context-Based: Continuous Bag-of-Words (CBOW) The Continuous Bag-of-Words (CBOW) is another similar model for learning word vectors. It predicts the target word (i.e. \u0026ldquo;swing\u0026rdquo;) from source context words (i.e., \u0026ldquo;sentence should the sword\u0026rdquo;).\nFig. 2. The CBOW model. Word vectors of multiple context words are averaged to get a fixed-length vector as in the hidden layer. Other symbols have the same meanings as in Fig 1. Because there are multiple contextual words, we average their corresponding word vectors, constructed by the multiplication of the input vector and the matrix $W$. Because the averaging stage smoothes over a lot of the distributional information, some people believe the CBOW model is better for small dataset.\nLoss Functions Both the skip-gram model and the CBOW model should be trained to minimize a well-designed loss/objective function. There are several loss functions we can incorporate to train these language models. In the following discussion, we will use the skip-gram model as an example to describe how the loss is computed.\nFull Softmax The skip-gram model defines the embedding vector of every word by the matrix $W$ and the context vector by the output matrix $W'$. Given an input word $w_I$, let us label the corresponding row of $W$ as vector $v_{w_I}$ (embedding vector) and its corresponding column of $W'$ as $v'_{w_I}$ (context vector). The final output layer applies softmax to compute the probability of predicting the output word $w_O$ given $w_I$, and therefore:\n $$ p(w_O \\vert w_I) = \\frac{\\exp({v'_{w_O}}^{\\top} v_{w_I})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top} v_{w_I})} $$ This is accurate as presented in Fig. 1. However, when $V$ is extremely large, calculating the denominator by going through all the words for every single sample is computationally impractical. The demand for more efficient conditional probability estimation leads to the new methods like hierarchical softmax.\nHierarchical Softmax Morin and Bengio (2005) proposed hierarchical softmax to make the sum calculation faster with the help of a binary tree structure. The hierarchical softmax encodes the language model\u0026rsquo;s output softmax layer into a tree hierarchy, where each leaf is one word and each internal node stands for relative probabilities of the children nodes.\nFig. 3. An illustration of the hierarchical softmax binary tree. The leaf nodes in white are words in the vocabulary. The gray inner nodes carry information on the probabilities of reaching its child nodes. One path starting from the root to the leaf $w\\_i$. $n(w\\_i, j)$ denotes the j-th node on this path. (Image source: word2vec Parameter Learning Explained) Each word $w_i$ has a unique path from the root down to its corresponding leaf. The probability of picking this word is equivalent to the probability of taking this path from the root down through the tree branches. Since we know the embedding vector $v_n$ of the internal node $n$, the probability of getting the word can be computed by the product of taking left or right turn at every internal node stop.\nAccording to Fig. 3, the probability of one node is ($\\sigma$ is the sigmoid function):\n $$ \\begin{align} p(\\text{turn right} \\to \\dots w_I \\vert n) \u0026= \\sigma({v'_n}^{\\top} v_{w_I})\\\\ p(\\text{turn left } \\to \\dots w_I \\vert n) \u0026= 1 - p(\\text{turn right} \\vert n) = \\sigma(-{v'_n}^{\\top} v_{w_I}) \\end{align} $$ The final probability of getting a context word $w_O$ given an input word $w_I$ is:\n $$ p(w_O \\vert w_I) = \\prod_{k=1}^{L(w_O)} \\sigma(\\mathbb{I}_{\\text{turn}}(n(w_O, k), n(w_O, k+1)) \\cdot {v'_{n(w_O, k)}}^{\\top} v_{w_I}) $$ where $L(w_O)$ is the depth of the path leading to the word $w_O$ and $\\mathbb{I}_{\\text{turn}}$ is a specially indicator function which returns 1 if $n(w_O, k+1)$ is the left child of $n(w_O, k)$ otherwise -1. The internal nodes' embeddings are learned during the model training. The tree structure helps greatly reduce the complexity of the denominator estimation from O(V) (vocabulary size) to O(log V) (the depth of the tree) at the training time. However, at the prediction time, we still to compute the probability of every word and pick the best, as we don\u0026rsquo;t know which leaf to reach for in advance.\nA good tree structure is crucial to the model performance. Several handy principles are: group words by frequency like what is implemented by Huffman tree for simple speedup; group similar words into same or close branches (i.e. use predefined word clusters, WordNet).\nCross Entropy Another approach completely steers away from the softmax framework. Instead, the loss function measures the cross entropy between the predicted probabilities $p$ and the true binary labels $\\mathbf{y}$.\nFirst, let\u0026rsquo;s recall that the cross entropy between two distributions $p$ and $q$ is measured as $ H(p, q) = -\\sum_x p(x) \\log q(x) $. In our case, the true label $y_i$ is 1 only when $w_i$ is the output word; $y_j$ is 0 otherwise. The loss function $\\mathcal{L}_\\theta$ of the model with parameter config $\\theta$ aims to minimize the cross entropy between the prediction and the ground truth, as lower cross entropy indicates high similarity between two distributions.\n $$ \\mathcal{L}_\\theta = - \\sum_{i=1}^V y_i \\log p(w_i | w_I) = - \\log p(w_O \\vert w_I) $$ Recall that,\n $$ p(w_O \\vert w_I) = \\frac{\\exp({v'_{w_O}}^{\\top} v_{w_I})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top} v_{w_I})} $$ Therefore,\n $$ \\mathcal{L}_{\\theta} = - \\log \\frac{\\exp({v'_{w_O}}^{\\top}{v_{w_I}})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top}{v_{w_I} })} = - {v'_{w_O}}^{\\top}{v_{w_I} } + \\log \\sum_{i=1}^V \\exp({v'_{w_i} }^{\\top}{v_{w_I}}) $$ To start training the model using back-propagation with SGD, we need to compute the gradient of the loss function. For simplicity, let\u0026rsquo;s label $z_{IO} = {v'_{w_O}}^{\\top}{v_{w_I}}$.\n $$ \\begin{align} \\nabla_\\theta \\mathcal{L}_{\\theta} \u0026= \\nabla_\\theta\\big( - z_{IO} + \\log \\sum_{i=1}^V e^{z_{Ii}} \\big) \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\nabla_\\theta \\big( \\log \\sum_{i=1}^V e^{z_{Ii}} \\big) \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\frac{1}{\\sum_{i=1}^V e^{z_{Ii}}} \\sum_{i=1}^V e^{z_{Ii}} \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\sum_{i=1}^V \\frac{e^{z_{Ii}}}{\\sum_{i=1}^V e^{z_{Ii}}} \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\sum_{i=1}^V p(w_i \\vert w_I) \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\mathbb{E}_{w_i \\sim Q(\\tilde{w})} \\nabla_\\theta z_{Ii} \\end{align} $$ where $Q(\\tilde{w})$ is the distribution of noise samples.\nAccording to the formula above, the correct output word has a positive reinforcement according to the first term (the larger $\\nabla_\\theta z_{IO}$ the better loss we have), while other words have a negative impact as captured by the second term.\nHow to estimate $\\mathbb{E}_{w_i \\sim Q(\\tilde{w})} \\nabla_\\theta {v'_{w_i}}^{\\top}{v_{w_I}}$ with a sample set of noise words rather than scanning through the entire vocabulary is the key of using cross-entropy-based sampling approach.\nNoise Contrastive Estimation (NCE) The Noise Contrastive Estimation (NCE) metric intends to differentiate the target word from noise samples using a logistic regression classifier (Gutmann and Hyvärinen, 2010).\nGiven an input word $w_I$, the correct output word is known as $w$. In the meantime, we sample $N$ other words from the noise sample distribution $Q$, denoted as $\\tilde{w}_1, \\tilde{w}_2, \\dots, \\tilde{w}_N \\sim Q$. Let\u0026rsquo;s label the decision of the binary classifier as $d$ and $d$$ can only take a binary value.\n $$ \\mathcal{L}_\\theta = - [ \\log p(d=1 \\vert w, w_I) + \\sum_{i=1, \\tilde{w}_i \\sim Q}^N \\log p(d=0|\\tilde{w}_i, w_I) ] $$ When $N$ is big enough, according to the Law of large numbers,\n $$ \\mathcal{L}_\\theta = - [ \\log p(d=1 \\vert w, w_I) + N\\mathbb{E}_{\\tilde{w}_i \\sim Q} \\log p(d=0|\\tilde{w}_i, w_I)] $$ To compute the probability $p(d=1 \\vert w, w_I)$, we can start with the joint probability $p(d, w \\vert w_I)$. Among $w, \\tilde{w}_1, \\tilde{w}_2, \\dots, \\tilde{w}_N$, we have 1 out of (N+1) chance to pick the true word $w$, which is sampled from the conditional probability $p(w \\vert w_I)$; meanwhile, we have N out of (N+1) chances to pick a noise word, each sampled from $q(\\tilde{w}) \\sim Q$. Thus,\n $$ p(d, w | w_I) = \\begin{cases} \\frac{1}{N+1} p(w \\vert w_I) \u0026 \\text{if } d=1 \\\\ \\frac{N}{N+1} q(\\tilde{w}) \u0026 \\text{if } d=0 \\end{cases} $$ Then we can figure out $p(d=1 \\vert w, w_I)$ and $p(d=0 \\vert w, w_I)$:\n $$ \\begin{align} p(d=1 \\vert w, w_I) \u0026= \\frac{p(d=1, w \\vert w_I)}{p(d=1, w \\vert w_I) + p(d=0, w \\vert w_I)} \u0026= \\frac{p(w \\vert w_I)}{p(w \\vert w_I) + Nq(\\tilde{w})} \\end{align} $$ $$ \\begin{align} p(d=0 \\vert w, w_I) \u0026= \\frac{p(d=0, w \\vert w_I)}{p(d=1, w \\vert w_I) + p(d=0, w \\vert w_I)} \u0026= \\frac{Nq(\\tilde{w})}{p(w \\vert w_I) + Nq(\\tilde{w})} \\end{align} $$ Finally the loss function of NCE\u0026rsquo;s binary classifier becomes:\n $$ \\begin{align} \\mathcal{L}_\\theta \u0026 = - [ \\log p(d=1 \\vert w, w_I) + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log p(d=0|\\tilde{w}_i, w_I)] \\\\ \u0026 = - [ \\log \\frac{p(w \\vert w_I)}{p(w \\vert w_I) + Nq(\\tilde{w})} + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\frac{Nq(\\tilde{w}_i)}{p(w \\vert w_I) + Nq(\\tilde{w}_i)}] \\end{align} $$ However, $p(w \\vert w_I)$ still involves summing up the entire vocabulary in the denominator. Let’s label the denominator as a partition function of the input word, $Z(w_I)$. A common assumption is $Z(w) \\approx 1$ given that we expect the softmax output layer to be normalized (Minh and Teh, 2012). Then the loss function is simplified to:\n $$ \\mathcal{L}_\\theta = - [ \\log \\frac{\\exp({v'_w}^{\\top}{v_{w_I}})}{\\exp({v'_w}^{\\top}{v_{w_I}}) + Nq(\\tilde{w})} + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\frac{Nq(\\tilde{w}_i)}{\\exp({v'_w}^{\\top}{v_{w_I}}) + Nq(\\tilde{w}_i)}] $$ The noise distribution $Q$ is a tunable parameter and we would like to design it in a way so that:\n intuitively it should be very similar to the real data distribution; and it should be easy to sample from. For example, the sampling implementation (log_uniform_candidate_sampler) of NCE loss in tensorflow assumes that such noise samples follow a log-uniform distribution, also known as Zipfian’s law. The probability of a given word in logarithm is expected to be reversely proportional to its rank, while high-frequency words are assigned with lower ranks. In this case, $q(\\tilde{w}) = \\frac{1}{ \\log V}(\\log (r_{\\tilde{w}} + 1) - \\log r_{\\tilde{w}})$, where $r_{\\tilde{w}} \\in [1, V]$ is the rank of a word by frequency in descending order.\nNegative Sampling (NEG) The Negative Sampling (NEG) proposed by Mikolov et al. (2013) is a simplified variation of NCE loss. It is especially famous for training Google\u0026rsquo;s word2vec project. Different from NCE Loss which attempts to approximately maximize the log probability of the softmax output, negative sampling did further simplification because it focuses on learning high-quality word embedding rather than modeling the word distribution in natural language.\nNEG approximates the binary classifier\u0026rsquo;s output with sigmoid functions as follows:\n $$ \\begin{align} p(d=1 \\vert w_, w_I) \u0026= \\sigma({v'_{w}}^\\top v_{w_I}) \\\\ p(d=0 \\vert w, w_I) \u0026= 1 - \\sigma({v'_{w}}^\\top v_{w_I}) = \\sigma(-{v'_{w}}^\\top v_{w_I}) \\end{align} $$ The final NCE loss function looks like:\n $$ \\mathcal{L}_\\theta = - [ \\log \\sigma({v'_{w}}^\\top v_{w_I}) + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\sigma(-{v'_{\\tilde{w}_i}}^\\top v_{w_I})] $$ Other Tips for Learning Word Embedding Mikolov et al. (2013) suggested several helpful practices that could result in good word embedding learning outcomes.\n Soft sliding window. When pairing the words within the sliding window, we could assign less weight to more distant words. One heuristic is \u0026mdash; given a maximum window size parameter defined, $s_{\\text{max}}$, the actual window size is randomly sampled between 1 and $s_{\\text{max}}$ for every training sample. Thus, each context word has the probability of 1/(its distance to the target word) being observed, while the adjacent words are always observed.\n Subsampling frequent words. Extremely frequent words might be too general to differentiate the context (i.e. think about stopwords). While on the other hand, rare words are more likely to carry distinct information. To balance the frequent and rare words, Mikolov et al. proposed to discard words $w$ with probability $1-\\sqrt{t/f(w)}$ during sampling. Here $f(w)$ is the word frequency and $t$ is an adjustable threshold.\n Learning phrases first. A phrase often stands as a conceptual unit, rather than a simple composition of individual words. For example, we cannot really tell \u0026ldquo;New York\u0026rdquo; is a city name even we know the meanings of \u0026ldquo;new\u0026rdquo; and \u0026ldquo;york\u0026rdquo;. Learning such phrases first and treating them as word units before training the word embedding model improves the outcome quality. A simple data-driven approach is based on unigram and bigram counts: $s_{\\text{phrase}} = \\frac{C(w_i w_j) - \\delta}{ C(w_i)C(w_j)}$, where $C(.)$ is simple count of an unigram $w_i$ or bigram $w_i w_j$ and $\\delta$ is a discounting threshold to prevent super infrequent words and phrases. Higher scores indicate higher chances of being phrases. To form phrases longer than two words, we can scan the vocabulary multiple times with decreasing score cutoff values.\n GloVe: Global Vectors The Global Vector (GloVe) model proposed by Pennington et al. (2014) aims to combine the count-based matrix factorization and the context-based skip-gram model together.\nWe all know the counts and co-occurrences can reveal the meanings of words. To distinguish from $p(w_O \\vert w_I)$ in the context of a word embedding word, we would like to define the co-ocurrence probability as:\n $$ p_{\\text{co}}(w_k \\vert w_i) = \\frac{C(w_i, w_k)}{C(w_i)} $$ $C(w_i, w_k)$ counts the co-occurrence between words $w_i$ and $w_k$.\nSay, we have two words, $w_i$=\u0026ldquo;ice\u0026rdquo; and $w_j$=\u0026ldquo;steam\u0026rdquo;. The third word $\\tilde{w}_k$=\u0026ldquo;solid\u0026rdquo; is related to \u0026ldquo;ice\u0026rdquo; but not \u0026ldquo;steam\u0026rdquo;, and thus we expect $p_{\\text{co}}(\\tilde{w}_k \\vert w_i)$ to be much larger than $p_{\\text{co}}(\\tilde{w}_k \\vert w_j)$ and therefore $\\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)}$ to be very large. If the third word $\\tilde{w}_k$ = \u0026ldquo;water\u0026rdquo; is related to both or $\\tilde{w}_k$ = \u0026ldquo;fashion\u0026rdquo; is unrelated to either of them, $\\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)}$ is expected to be close to one.\nThe intuition here is that the word meanings are captured by the ratios of co-occurrence probabilities rather than the probabilities themselves. The global vector models the relationship between two words regarding to the third context word as:\n $$ F(w_i, w_j, \\tilde{w}_k) = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} $$ Further, since the goal is to learn meaningful word vectors, $F$ is designed to be a function of the linear difference between two words $w_i - w_j$:\n $$ F((w_i - w_j)^\\top \\tilde{w}_k) = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} $$ With the consideration of $F$ being symmetric between target words and context words, the final solution is to model $F$ as an exponential function. Please read the original paper (Pennington et al., 2014) for more details of the equations.\n $$ \\begin{align} F({w_i}^\\top \\tilde{w}_k) \u0026= \\exp({w_i}^\\top \\tilde{w}_k) = p_{\\text{co}}(\\tilde{w}_k \\vert w_i) \\\\ F((w_i - w_j)^\\top \\tilde{w}_k) \u0026= \\exp((w_i - w_j)^\\top \\tilde{w}_k) = \\frac{\\exp(w_i^\\top \\tilde{w}_k)}{\\exp(w_j^\\top \\tilde{w}_k)} = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} \\end{align} $$ Finally,\n $$ {w_i}^\\top \\tilde{w}_k = \\log p_{\\text{co}}(\\tilde{w}_k \\vert w_i) = \\log \\frac{C(w_i, \\tilde{w}_k)}{C(w_i)} = \\log C(w_i, \\tilde{w}_k) - \\log C(w_i) $$ Since the second term $-\\log C(w_i)$ is independent of $k$, we can add bias term $b_i$ for $w_i$ to capture $-\\log C(w_i)$. To keep the symmetric form, we also add in a bias $\\tilde{b}_k$ for $\\tilde{w}_k$.\n $$ \\log C(w_i, \\tilde{w}_k) = {w_i}^\\top \\tilde{w}_k + b_i + \\tilde{b}_k $$ The loss function for the GloVe model is designed to preserve the above formula by minimizing the sum of the squared errors:\n $$ \\mathcal{L}_\\theta = \\sum_{i=1, j=1}^V f(C(w_i,w_j)) ({w_i}^\\top \\tilde{w}_j + b_i + \\tilde{b}_j - \\log C(w_i, \\tilde{w}_j))^2 $$ The weighting schema $f(c)$ is a function of the co-occurrence of $w_i$ and $w_j$ and it is an adjustable model configuration. It should be close to zero as $c \\to 0$; should be non-decreasing as higher co-occurrence should have more impact; should saturate when $c$ become extremely large. The paper proposed the following weighting function.\n $$ f(c) = \\begin{cases} (\\frac{c}{c_{\\max}})^\\alpha \u0026 \\text{if } c Examples: word2vec on \u0026ldquo;Game of Thrones\u0026rdquo; After reviewing all the theoretical knowledge above, let\u0026rsquo;s try a little experiment in word embedding extracted from \u0026ldquo;the Games of Thrones corpus\u0026rdquo;. The process is super straightforward using gensim.\nStep 1: Extract words\nimport sys from nltk.corpus import stopwords from nltk.tokenize import sent_tokenize STOP_WORDS = set(stopwords.words(\u0026#39;english\u0026#39;)) def get_words(txt): return filter( lambda x: x not in STOP_WORDS, re.findall(r\u0026#39;\\b(\\w+)\\b\u0026#39;, txt) ) def parse_sentence_words(input_file_names): \u0026#34;\u0026#34;\u0026#34;Returns a list of a list of words. Each sublist is a sentence.\u0026#34;\u0026#34;\u0026#34; sentence_words = [] for file_name in input_file_names: for line in open(file_name): line = line.strip().lower() line = line.decode(\u0026#39;unicode_escape\u0026#39;).encode(\u0026#39;ascii\u0026#39;,\u0026#39;ignore\u0026#39;) sent_words = map(get_words, sent_tokenize(line)) sent_words = filter(lambda sw: len(sw) \u0026gt; 1, sent_words) if len(sent_words) \u0026gt; 1: sentence_words += sent_words return sentence_words # You would see five .txt files after unzip \u0026#39;a_song_of_ice_and_fire.zip\u0026#39; input_file_names = [\u0026#34;001ssb.txt\u0026#34;, \u0026#34;002ssb.txt\u0026#34;, \u0026#34;003ssb.txt\u0026#34;, \u0026#34;004ssb.txt\u0026#34;, \u0026#34;005ssb.txt\u0026#34;] GOT_SENTENCE_WORDS= parse_sentence_words(input_file_names) Step 2: Feed a word2vec model\nfrom gensim.models import Word2Vec # size: the dimensionality of the embedding vectors. # window: the maximum distance between the current and predicted word within a sentence. model = Word2Vec(GOT_SENTENCE_WORDS, size=128, window=3, min_count=5, workers=4) model.wv.save_word2vec_format(\u0026#34;got_word2vec.txt\u0026#34;, binary=False) Step 3: Check the results\nIn the GoT word embedding space, the top similar words to \u0026ldquo;king\u0026rdquo; and \u0026ldquo;queen\u0026rdquo; are:\n model.most_similar('king', topn=10) (word, similarity with \u0026lsquo;king\u0026rsquo;) model.most_similar('queen', topn=10) (word, similarity with \u0026lsquo;queen\u0026rsquo;) (\u0026lsquo;kings\u0026rsquo;, 0.897245) (\u0026lsquo;cersei\u0026rsquo;, 0.942618) (\u0026lsquo;baratheon\u0026rsquo;, 0.809675) (\u0026lsquo;joffrey\u0026rsquo;, 0.933756) (\u0026lsquo;son\u0026rsquo;, 0.763614) (\u0026lsquo;margaery\u0026rsquo;, 0.931099) (\u0026lsquo;robert\u0026rsquo;, 0.708522) (\u0026lsquo;sister\u0026rsquo;, 0.928902) (\u0026lsquo;lords\u0026rsquo;, 0.698684) (\u0026lsquo;prince\u0026rsquo;, 0.927364) (\u0026lsquo;joffrey\u0026rsquo;, 0.696455) (\u0026lsquo;uncle\u0026rsquo;, 0.922507) (\u0026lsquo;prince\u0026rsquo;, 0.695699) (\u0026lsquo;varys\u0026rsquo;, 0.918421) (\u0026lsquo;brother\u0026rsquo;, 0.685239) (\u0026lsquo;ned\u0026rsquo;, 0.917492) (\u0026lsquo;aerys\u0026rsquo;, 0.684527) (\u0026lsquo;melisandre\u0026rsquo;, 0.915403) (\u0026lsquo;stannis\u0026rsquo;, 0.682932) (\u0026lsquo;robb\u0026rsquo;, 0.915272) Cited as:\n@article{weng2017wordembedding, title = \u0026quot;Learning word embedding\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-10-15-word-embedding/\u0026quot; } References [1] Tensorflow Tutorial Vector Representations of Words.\n[2] \u0026ldquo;Word2Vec Tutorial - The Skip-Gram Model\u0026rdquo; by Chris McCormick.\n[3] \u0026ldquo;On word embeddings - Part 2: Approximating the Softmax\u0026rdquo; by Sebastian Ruder.\n[4] Xin Rong. word2vec Parameter Learning Explained\n[5] Mikolov, Tomas, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Efficient estimation of word representations in vector space.\u0026quot; arXiv preprint arXiv:1301.3781 (2013).\n[6] Frederic Morin and Yoshua Bengio. \u0026ldquo;Hierarchical Probabilistic Neural Network Language Model.\u0026quot; Aistats. Vol. 5. 2005.\n[7] Michael Gutmann and Aapo Hyvärinen. \u0026ldquo;Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.\u0026quot; Proc. Intl. Conf. on Artificial Intelligence and Statistics. 2010.\n[8] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Distributed representations of words and phrases and their compositionality.\u0026quot; Advances in neural information processing systems. 2013.\n[9] Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Efficient estimation of word representations in vector space.\u0026quot; arXiv preprint arXiv:1301.3781 (2013).\n[10] Marco Baroni, Georgiana Dinu, and Germán Kruszewski. \u0026ldquo;Don\u0026rsquo;t count, predict! A systematic comparison of context-counting vs. context-predicting semantic vectors.\u0026quot; ACL (1). 2014.\n[11] Jeffrey Pennington, Richard Socher, and Christopher Manning. \u0026ldquo;Glove: Global vectors for word representation.\u0026quot; Proc. Conf. on empirical methods in natural language processing (EMNLP). 2014.\n","permalink":"https://lilianweng.github.io/posts/2017-10-15-word-embedding/","summary":"Human vocabulary comes in free text. In order to make a machine learning model understand and process the natural language, we need to transform the free-text words into numeric values. One of the simplest transformation approaches is to do a one-hot encoding in which each distinct word stands for one dimension of the resulting vector and a binary value indicates whether the word presents (1) or not (0).\nHowever, one-hot encoding is impractical computationally when dealing with the entire vocabulary, as the representation demands hundreds of thousands of dimensions.","title":"Learning Word Embedding"},{"content":"Professor Naftali Tishby passed away in 2021. Hope the post can introduce his cool idea of information bottleneck to more people.\nRecently I watched the talk \u0026ldquo;Information Theory in Deep Learning\u0026rdquo; by Prof Naftali Tishby and found it very interesting. He presented how to apply the information theory to study the growth and transformation of deep neural networks during training. Using the Information Bottleneck (IB) method, he proposed a new learning bound for deep neural networks (DNN), as the traditional learning theory fails due to the exponentially large number of parameters. Another keen observation is that DNN training involves two distinct phases: First, the network is trained to fully represent the input data and minimize the generalization error; then, it learns to forget the irrelevant details by compressing the representation of the input.\nMost of the materials in this post are from Prof Tishby’s talk and related papers.\nBasic Concepts Markov Chain\nA Markov process is a \u0026ldquo;memoryless\u0026rdquo; (also called \u0026ldquo;Markov Property\u0026rdquo;) stochastic process. A Markov chain is a type of Markov process containing multiple discrete states. That is being said, the conditional probability of future states of the process is only determined by the current state and does not depend on the past states.\nKullback–Leibler (KL) Divergence\nKL divergence measures how one probability distribution $p$ diverges from a second expected probability distribution $q$. It is asymmetric.\n $$ \\begin{aligned} D_{KL}(p \\| q) \u0026= \\sum_x p(x) \\log \\frac{p(x)}{q(x)} \\\\ \u0026= - \\sum_x p(x)\\log q(x) + \\sum_x p(x)\\log p(x) \\\\ \u0026= H(P, Q) - H(P) \\end{aligned} $$ $D_{KL}$ achieves the minimum zero when $p(x)$ == $q(x)$ everywhere.\nMutual Information\nMutual information measures the mutual dependence between two variables. It quantifies the \u0026ldquo;amount of information\u0026rdquo; obtained about one random variable through the other random variable. Mutual information is symmetric.\n $$ \\begin{aligned} I(X;Y) \u0026= D_{KL}[p(x,y) \\| p(x)p(y)] \\\\ \u0026= \\sum_{x \\in X, y \\in Y} p(x, y) \\log(\\frac{p(x, y)}{p(x)p(y)}) \\\\ \u0026= \\sum_{x \\in X, y \\in Y} p(x, y) \\log(\\frac{p(x|y)}{p(x)}) \\\\ \u0026= H(X) - H(X|Y) \\\\ \\end{aligned} $$ Data Processing Inequality (DPI)\nFor any markov chain: $X \\to Y \\to Z$, we would have $I(X; Y) \\geq I(X; Z)$.\nA deep neural network can be viewed as a Markov chain, and thus when we are moving down the layers of a DNN, the mutual information between the layer and the input can only decrease.\nReparametrization invariance\nFor two invertible functions $\\phi$, $\\psi$, the mutual information still holds: $I(X; Y) = I(\\phi(X); \\psi(Y))$.\nFor example, if we shuffle the weights in one layer of DNN, it would not affect the mutual information between this layer and another.\nDeep Neural Networks as Markov Chains The training data contains sampled observations from the joint distribution of $X$ and $Y$. The input variable $X$ and weights of hidden layers are all high-dimensional random variable. The ground truth target $Y$ and the predicted value $\\hat{Y}$ are random variables of smaller dimensions in the classification settings.\nFig. 1. The structure of a deep neural network, which consists of the target label $Y$, input layer $X$, hidden layers $h\\_1, \\dots, h\\_m$ and the final prediction $\\hat{Y}$. (Image source: Tishby and Zaslavsky, 2015) If we label the hidden layers of a DNN as $h_1, h_2, \\dots, h_m$ as in Fig. 1, we can view each layer as one state of a Markov Chain: $ h_i \\to h_{i+1}$. According to DPI, we would have:\n $$ \\begin{aligned} H(X) \\geq I(X; h_1) \\geq I(X; h_2) \\geq \\dots \\geq I(X; h_m) \\geq I(X; \\hat{Y}) \\\\ I(X; Y) \\geq I(h_1; Y) \\geq I(h_2; Y) \\geq \\dots \\geq I(h_m; Y) \\geq I(\\hat{Y}; Y) \\end{aligned} $$ A DNN is designed to learn how to describe $X$ to predict $Y$ and eventually, to compress $X$ to only hold the information related to $Y$. Tishby describes this processing as \u0026ldquo;successive refinement of relevant information\u0026rdquo;.\nInformation Plane Theorem A DNN has successive internal representations of $X$, a set of hidden layers $\\{T_i\\}$. The information plane theorem characterizes each layer by its encoder and decoder information. The encoder is a representation of the input data $X$, while the decoder translates the information in the current layer to the target ouput $Y$.\nPrecisely, in an information plane plot:\n X-axis: The sample complexity of $T_i$ is determined by the encoder mutual information $I(X; T_i)$. Sample complexity refers to how many samples you need to achieve certain accuracy and generalization. Y-axis: The accuracy (generalization error) is determined by the decoder mutual information $I(T_i; Y)$. Fig. 2. The encoder vs decoder mutual information of DNN hidden layers of 50 experiments. Different layers are color-coders, with green being the layer right next to the input and the orange being the furthest. There are three snapshots, at the initial epoch, 400 epochs and 9000 epochs respectively. (Image source: Shwartz-Ziv and Tishby, 2017) Each dot in Fig. 2. marks the encoder/ decoder mutual information of one hidden layer of one network simulation (no regularization is applied; no weights decay, no dropout, etc.). They move up as expected because the knowledge about the true labels is increasing (accuracy increases). At the early stage, the hidden layers learn a lot about the input $X$, but later they start to compress to forget some information about the input. Tishby believes that \u0026ldquo;the most important part of learning is actually forgetting\u0026rdquo;. Check out this nice video that demonstrates how the mutual information measures of layers are changing in epoch time.\nFig. 3. Here is an aggregated view of Fig 2. The compression happens after the generalization error becomes very small. (Image source: Tishby’ talk 15:15) Two Optimization Phases Tracking the normalized mean and standard deviation of each layer\u0026rsquo;s weights in time also reveals two optimization phases of the training process.\nFig. 4. The norm of mean and standard deviation of each layer's weight gradients for each layer as a function of training epochs. Different layers are color-coded. (Image source: Shwartz-Ziv and Tishby, 2017) Among early epochs, the mean values are three magnitudes larger than the standard deviations. After a sufficient number of epochs, the error saturates and the standard deviations become much noisier afterward. The further a layer is away from the output, the noisier it gets, because the noises can get amplified and accumulated through the back-prop process (not due to the width of the layer).\nLearning Theory \u0026ldquo;Old\u0026rdquo; Generalization Bounds The generalization bounds defined by the classic learning theory is:\n $$ \\epsilon^2 $\\epsilon$: The difference between the training error and the generalization error. The generalization error measures how accurate the prediction of an algorithm is for previously unseen data. $H_\\epsilon$: $\\epsilon$-cover of the hypothesis class. Typically we assume the size $\\vert H_\\epsilon \\vert \\sim (1/\\epsilon)^d$. $\\delta$: Confidence. $m$: The number of training samples. $d$: The VC dimension of the hypothesis. This definition states that the difference between the training error and the generalization error is bounded by a function of the hypothesis space size and the dataset size. The bigger the hypothesis space gets, the bigger the generalization error becomes. I recommend this tutorial on ML theory, part1 and part2, if you are interested in reading more on generalization bounds.\nHowever, it does not work for deep learning. The larger a network is, the more parameters it needs to learn. With this generalization bounds, larger networks (larger $d$) would have worse bounds. This is contrary to the intuition that larger networks are able to achieve better performance with higher expressivity.\n\u0026ldquo;New\u0026rdquo; Input compression bound To solve this counterintuitive observation, Tishby et al. proposed a new input compression bound for DNN.\nFirst let us have $T_\\epsilon$ as an $\\epsilon$-partition of the input variable $X$. This partition compresses the input with respect to the homogeneity to the labels into small cells. The cells in total can cover the whole input space. If the prediction outputs binary values, we can replace the cardinality of the hypothesis, $\\vert H_\\epsilon \\vert$, with $2^{\\vert T_\\epsilon \\vert}$.\n $$ |H_\\epsilon| \\sim 2^{|X|} \\to 2^{|T_\\epsilon|} $$ When $X$ is large, the size of $X$ is approximately $2^{H(X)}$. Each cell in the $\\epsilon$-partition is of size $2^{H(X \\vert T_\\epsilon)}$. Therefore we have $\\vert T_\\epsilon \\vert \\sim \\frac{2^{H(X)}}{2^{H(X \\vert T_\\epsilon)}} = 2^{I(T_\\epsilon; X)}$. Then the input compression bound becomes:\n $$ \\epsilon^2 Fig. 5. The black line is the optimal achievable information bottleneck (IB) limit. The red line corresponds to the upper bound on the out-of-sample IB distortion, when trained on a finite sample set. $\\Delta C$ is the complexity gap and $\\Delta G$ is the generalization gap. (Recreated based on Tishby’ talk 24:50) Network Size and Training Data Size The Benefit of More Hidden Layers Having more layers give us computational benefits and speed up the training process for good generalization.\nFig. 6. The optimization time is much shorter (fewer epochs) with more hidden layers. (Image source: Shwartz-Ziv and Tishby, 2017) Compression through stochastic relaxation: According to the diffusion equation, the relaxation time of layer $k$ is proportional to the exponential of this layer\u0026rsquo;s compression amount $\\Delta S_k$: $\\Delta t_k \\sim \\exp(\\Delta S_k)$. We can compute the layer compression as $\\Delta S_k = I(X; T_k) - I(X; T_{k-1})$. Because $\\exp(\\sum_k \\Delta S_k) \\geq \\sum_k \\exp(\\Delta S_k)$, we would expect an exponential decrease in training epochs with more hidden layers (larger $k$).\nThe Benefit of More Training Samples Fitting more training data requires more information captured by the hidden layers. With increased training data size, the decoder mutual information (recall that this is directly related to the generalization error), $I(T; Y)$, is pushed up and gets closer to the theoretical information bottleneck bound. Tishby emphasized that It is the mutual information, not the layer size or the VC dimension, that determines generalization, different from standard theories.\nFig. 7. The training data of different sizes is color-coded. The information plane of multiple converged networks are plotted. More training data leads to better generalization. (Image source: Shwartz-Ziv and Tishby, 2017) Cited as:\n@article{weng2017infotheory, title = \u0026quot;Anatomize Deep Learning with Information Theory\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-09-28-information-bottleneck/\u0026quot; } References [1] Naftali Tishby. Information Theory of Deep Learning\n[2] Machine Learning Theory - Part 1: Introduction\n[3] Machine Learning Theory - Part 2: Generalization Bounds\n[4] New Theory Cracks Open the Black Box of Deep Learning by Quanta Magazine.\n[5] Naftali Tishby and Noga Zaslavsky. \u0026ldquo;Deep learning and the information bottleneck principle.\u0026quot; IEEE Information Theory Workshop (ITW), 2015.\n[6] Ravid Shwartz-Ziv and Naftali Tishby. \u0026ldquo;Opening the Black Box of Deep Neural Networks via Information.\u0026quot; arXiv preprint arXiv:1703.00810, 2017.\n","permalink":"https://lilianweng.github.io/posts/2017-09-28-information-bottleneck/","summary":"Professor Naftali Tishby passed away in 2021. Hope the post can introduce his cool idea of information bottleneck to more people.\nRecently I watched the talk \u0026ldquo;Information Theory in Deep Learning\u0026rdquo; by Prof Naftali Tishby and found it very interesting. He presented how to apply the information theory to study the growth and transformation of deep neural networks during training. Using the Information Bottleneck (IB) method, he proposed a new learning bound for deep neural networks (DNN), as the traditional learning theory fails due to the exponentially large number of parameters.","title":"Anatomize Deep Learning with Information Theory"},{"content":"[Updated on 2018-09-30: thanks to Yoonju, we have this post translated in Korean!] [Updated on 2019-04-18: this post is also available on arXiv.]\nGenerative adversarial network (GAN) has shown great results in many generative tasks to replicate the real-world rich content such as images, human language, and music. It is inspired by game theory: two models, a generator and a critic, are competing with each other while making each other stronger at the same time. However, it is rather challenging to train a GAN model, as people are facing issues like training instability or failure to converge.\nHere I would like to explain the maths behind the generative adversarial network framework, why it is hard to be trained, and finally introduce a modified version of GAN intended to solve the training difficulties.\nKullback–Leibler and Jensen–Shannon Divergence Before we start examining GANs closely, let us first review two metrics for quantifying the similarity between two probability distributions.\n(1) KL (Kullback–Leibler) divergence measures how one probability distribution $p$ diverges from a second expected probability distribution $q$.\n $$ D_{KL}(p \\| q) = \\int_x p(x) \\log \\frac{p(x)}{q(x)} dx $$ $D_{KL}$ achieves the minimum zero when $p(x)$ == $q(x)$ everywhere.\nIt is noticeable according to the formula that KL divergence is asymmetric. In cases where $p(x)$ is close to zero, but $q(x)$ is significantly non-zero, the $q$\u0026rsquo;s effect is disregarded. It could cause buggy results when we just want to measure the similarity between two equally important distributions.\n(2) Jensen–Shannon Divergence is another measure of similarity between two probability distributions, bounded by $[0, 1]$. JS divergence is symmetric (yay!) and more smooth. Check this Quora post if you are interested in reading more about the comparison between KL divergence and JS divergence.\n $$ D_{JS}(p \\| q) = \\frac{1}{2} D_{KL}(p \\| \\frac{p + q}{2}) + \\frac{1}{2} D_{KL}(q \\| \\frac{p + q}{2}) $$ Fig. 1. Given two Gaussian distribution, $p$ with mean=0 and std=1 and $q$ with mean=1 and std=1. The average of two distributions is labelled as $m=(p+q)/2$. KL divergence $D_{KL}$ is asymmetric but JS divergence $D_{JS}$ is symmetric. Some believe (Huszar, 2015) that one reason behind GANs' big success is switching the loss function from asymmetric KL divergence in traditional maximum-likelihood approach to symmetric JS divergence. We will discuss more on this point in the next section.\nGenerative Adversarial Network (GAN) GAN consists of two models:\n A discriminator $D$ estimates the probability of a given sample coming from the real dataset. It works as a critic and is optimized to tell the fake samples from the real ones. A generator $G$ outputs synthetic samples given a noise variable input $z$ ($z$ brings in potential output diversity). It is trained to capture the real data distribution so that its generative samples can be as real as possible, or in other words, can trick the discriminator to offer a high probability. Fig. 2. Architecture of a generative adversarial network. (Image source: www.kdnuggets.com/2017/01/generative-...-learning.html) These two models compete against each other during the training process: the generator $G$ is trying hard to trick the discriminator, while the critic model $D$ is trying hard not to be cheated. This interesting zero-sum game between two models motivates both to improve their functionalities.\nGiven,\n Symbol Meaning Notes $p_{z}$ Data distribution over noise input $z$ Usually, just uniform. $p_{g}$ The generator\u0026rsquo;s distribution over data $x$ $p_{r}$ Data distribution over real sample $x$ On one hand, we want to make sure the discriminator $D$\u0026rsquo;s decisions over real data are accurate by maximizing $\\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)]$. Meanwhile, given a fake sample $G(z), z \\sim p_z(z)$, the discriminator is expected to output a probability, $D(G(z))$, close to zero by maximizing $\\mathbb{E}_{z \\sim p_{z}(z)} [\\log (1 - D(G(z)))]$.\nOn the other hand, the generator is trained to increase the chances of $D$ producing a high probability for a fake example, thus to minimize $\\mathbb{E}_{z \\sim p_{z}(z)} [\\log (1 - D(G(z)))]$.\nWhen combining both aspects together, $D$ and $G$ are playing a minimax game in which we should optimize the following loss function:\n $$ \\begin{aligned} \\min_G \\max_D L(D, G) \u0026 = \\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)] + \\mathbb{E}_{z \\sim p_z(z)} [\\log(1 - D(G(z)))] \\\\ \u0026 = \\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)] + \\mathbb{E}_{x \\sim p_g(x)} [\\log(1 - D(x)] \\end{aligned} $$ ($\\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)]$ has no impact on $G$ during gradient descent updates.)\nWhat is the optimal value for D? Now we have a well-defined loss function. Let\u0026rsquo;s first examine what is the best value for $D$.\n $$ L(G, D) = \\int_x \\bigg( p_{r}(x) \\log(D(x)) + p_g (x) \\log(1 - D(x)) \\bigg) dx $$ Since we are interested in what is the best value of $D(x)$ to maximize $L(G, D)$, let us label\n $$ \\tilde{x} = D(x), A=p_{r}(x), B=p_g(x) $$ And then what is inside the integral (we can safely ignore the integral because $x$ is sampled over all the possible values) is:\n $$ \\begin{aligned} f(\\tilde{x}) \u0026 = A log\\tilde{x} + B log(1-\\tilde{x}) \\\\ \\frac{d f(\\tilde{x})}{d \\tilde{x}} \u0026 = A \\frac{1}{ln10} \\frac{1}{\\tilde{x}} - B \\frac{1}{ln10} \\frac{1}{1 - \\tilde{x}} \\\\ \u0026 = \\frac{1}{ln10} (\\frac{A}{\\tilde{x}} - \\frac{B}{1-\\tilde{x}}) \\\\ \u0026 = \\frac{1}{ln10} \\frac{A - (A + B)\\tilde{x}}{\\tilde{x} (1 - \\tilde{x})} \\\\ \\end{aligned} $$ Thus, set $\\frac{d f(\\tilde{x})}{d \\tilde{x}} = 0$, we get the best value of the discriminator: $D^*(x) = \\tilde{x}^* = \\frac{A}{A + B} = \\frac{p_{r}(x)}{p_{r}(x) + p_g(x)} \\in [0, 1]$.\nOnce the generator is trained to its optimal, $p_g$ gets very close to $p_{r}$. When $p_g = p_{r}$, $D^*(x)$ becomes $1/2$.\nWhat is the global optimal? When both $G$ and $D$ are at their optimal values, we have $p_g = p_{r}$ and $D^*(x) = 1/2$ and the loss function becomes:\n $$ \\begin{aligned} L(G, D^*) \u0026= \\int_x \\bigg( p_{r}(x) \\log(D^*(x)) + p_g (x) \\log(1 - D^*(x)) \\bigg) dx \\\\ \u0026= \\log \\frac{1}{2} \\int_x p_{r}(x) dx + \\log \\frac{1}{2} \\int_x p_g(x) dx \\\\ \u0026= -2\\log2 \\end{aligned} $$ What does the loss function represent? According to the formula listed in the previous section, JS divergence between $p_{r}$ and $p_g$ can be computed as:\n $$ \\begin{aligned} D_{JS}(p_{r} \\| p_g) =\u0026 \\frac{1}{2} D_{KL}(p_{r} || \\frac{p_{r} + p_g}{2}) + \\frac{1}{2} D_{KL}(p_{g} || \\frac{p_{r} + p_g}{2}) \\\\ =\u0026 \\frac{1}{2} \\bigg( \\log2 + \\int_x p_{r}(x) \\log \\frac{p_{r}(x)}{p_{r} + p_g(x)} dx \\bigg) + \\\\\u0026 \\frac{1}{2} \\bigg( \\log2 + \\int_x p_g(x) \\log \\frac{p_g(x)}{p_{r} + p_g(x)} dx \\bigg) \\\\ =\u0026 \\frac{1}{2} \\bigg( \\log4 + L(G, D^*) \\bigg) \\end{aligned} $$ Thus,\n $$ L(G, D^*) = 2D_{JS}(p_{r} \\| p_g) - 2\\log2 $$ Essentially the loss function of GAN quantifies the similarity between the generative data distribution $p_g$ and the real sample distribution $p_{r}$ by JS divergence when the discriminator is optimal. The best $G^*$ that replicates the real data distribution leads to the minimum $L(G^*, D^*) = -2\\log2$ which is aligned with equations above.\n Other Variations of GAN: There are many variations of GANs in different contexts or designed for different tasks. For example, for semi-supervised learning, one idea is to update the discriminator to output real class labels, $1, \\dots, K-1$, as well as one fake class label $K$. The generator model aims to trick the discriminator to output a classification label smaller than $K$.\n Tensorflow Implementation: carpedm20/DCGAN-tensorflow\nProblems in GANs Although GAN has shown great success in the realistic image generation, the training is not easy; The process is known to be slow and unstable.\nHard to achieve Nash equilibrium Salimans et al. (2016) discussed the problem with GAN\u0026rsquo;s gradient-descent-based training procedure. Two models are trained simultaneously to find a Nash equilibrium to a two-player non-cooperative game. However, each model updates its cost independently with no respect to another player in the game. Updating the gradient of both models concurrently cannot guarantee a convergence.\nLet\u0026rsquo;s check out a simple example to better understand why it is difficult to find a Nash equilibrium in an non-cooperative game. Suppose one player takes control of $x$ to minimize $f_1(x) = xy$, while at the same time the other player constantly updates $y$ to minimize $f_2(y) = -xy$.\nBecause $\\frac{\\partial f_1}{\\partial x} = y$ and $\\frac{\\partial f_2}{\\partial y} = -x$, we update $x$ with $x-\\eta \\cdot y$ and $y$ with $y+ \\eta \\cdot x$ simulitanously in one iteration, where $\\eta$ is the learning rate. Once $x$ and $y$ have different signs, every following gradient update causes huge oscillation and the instability gets worse in time, as shown in Fig. 3.\nFig. 3. A simulation of our example for updating $x$ to minimize $xy$ and updating $y$ to minimize $-xy$. The learning rate $\\eta = 0.1$. With more iterations, the oscillation grows more and more unstable. Low dimensional supports Term Explanation Manifold A topological space that locally resembles Euclidean space near each point. Precisely, when this Euclidean space is of dimension $n$, the manifold is referred as $n$-manifold. Support A real-valued function $f$ is the subset of the domain containing those elements which are not mapped to zero. Arjovsky and Bottou (2017) discussed the problem of the supports of $p_r$ and $p_g$ lying on low dimensional manifolds and how it contributes to the instability of GAN training thoroughly in a very theoretical paper \u0026ldquo;Towards principled methods for training generative adversarial networks\u0026rdquo;.\nThe dimensions of many real-world datasets, as represented by $p_r$, only appear to be artificially high. They have been found to concentrate in a lower dimensional manifold. This is actually the fundamental assumption for Manifold Learning. Thinking of the real world images, once the theme or the contained object is fixed, the images have a lot of restrictions to follow, i.e., a dog should have two ears and a tail, and a skyscraper should have a straight and tall body, etc. These restrictions keep images aways from the possibility of having a high-dimensional free form.\n$p_g$ lies in a low dimensional manifolds, too. Whenever the generator is asked to a much larger image like 64x64 given a small dimension, such as 100, noise variable input $z$, the distribution of colors over these 4096 pixels has been defined by the small 100-dimension random number vector and can hardly fill up the whole high dimensional space.\nBecause both $p_g$ and $p_r$ rest in low dimensional manifolds, they are almost certainly gonna be disjoint (See Fig. 4). When they have disjoint supports, we are always capable of finding a perfect discriminator that separates real and fake samples 100% correctly. Check the paper if you are curious about the proof.\nFig. 4. Low dimensional manifolds in high dimension space can hardly have overlaps. (Left) Two lines in a three-dimension space. (Right) Two surfaces in a three-dimension space. Vanishing gradient When the discriminator is perfect, we are guaranteed with $D(x) = 1, \\forall x \\in p_r$ and $D(x) = 0, \\forall x \\in p_g$. Therefore the loss function $L$ falls to zero and we end up with no gradient to update the loss during learning iterations. Fig. 5 demonstrates an experiment when the discriminator gets better, the gradient vanishes fast.\nFig. 5. First, a DCGAN is trained for 1, 10 and 25 epochs. Then, with the **generator fixed**, a discriminator is trained from scratch and measure the gradients with the original cost function. We see the gradient norms **decay quickly** (in log scale), in the best case 5 orders of magnitude after 4000 discriminator iterations. (Image source: Arjovsky and Bottou, 2017) As a result, training a GAN faces a dilemma:\n If the discriminator behaves badly, the generator does not have accurate feedback and the loss function cannot represent the reality. If the discriminator does a great job, the gradient of the loss function drops down to close to zero and the learning becomes super slow or even jammed. This dilemma clearly is capable to make the GAN training very tough.\nMode collapse During the training, the generator may collapse to a setting where it always produces same outputs. This is a common failure case for GANs, commonly referred to as Mode Collapse. Even though the generator might be able to trick the corresponding discriminator, it fails to learn to represent the complex real-world data distribution and gets stuck in a small space with extremely low variety.\nFig. 6. A DCGAN model is trained with an MLP network with 4 layers, 512 units and ReLU activation function, configured to lack a strong inductive bias for image generation. The results shows a significant degree of mode collapse. (Image source: Arjovsky, Chintala, \u0026 Bottou, 2017.) Lack of a proper evaluation metric Generative adversarial networks are not born with a good objection function that can inform us the training progress. Without a good evaluation metric, it is like working in the dark. No good sign to tell when to stop; No good indicator to compare the performance of multiple models.\nImproved GAN Training The following suggestions are proposed to help stabilize and improve the training of GANs.\nFirst five methods are practical techniques to achieve faster convergence of GAN training, proposed in \u0026ldquo;Improve Techniques for Training GANs\u0026rdquo;. The last two are proposed in \u0026ldquo;Towards principled methods for training generative adversarial networks\u0026rdquo; to solve the problem of disjoint distributions.\n(1) Feature Matching\nFeature matching suggests to optimize the discriminator to inspect whether the generator\u0026rsquo;s output matches expected statistics of the real samples. In such a scenario, the new loss function is defined as $| \\mathbb{E}_{x \\sim p_r} f(x) - \\mathbb{E}_{z \\sim p_z(z)}f(G(z)) |_2^2 $, where $f(x)$ can be any computation of statistics of features, such as mean or median.\n(2) Minibatch Discrimination\nWith minibatch discrimination, the discriminator is able to digest the relationship between training data points in one batch, instead of processing each point independently.\nIn one minibatch, we approximate the closeness between every pair of samples, $c(x_i, x_j)$, and get the overall summary of one data point by summing up how close it is to other samples in the same batch, $o(x_i) = \\sum_{j} c(x_i, x_j)$. Then $o(x_i)$ is explicitly added to the input of the model.\n(3) Historical Averaging\nFor both models, add $ | \\Theta - \\frac{1}{t} \\sum_{i=1}^t \\Theta_i |^2 $ into the loss function, where $\\Theta$ is the model parameter and $\\Theta_i$ is how the parameter is configured at the past training time $i$. This addition piece penalizes the training speed when $\\Theta$ is changing too dramatically in time.\n(4) One-sided Label Smoothing\nWhen feeding the discriminator, instead of providing 1 and 0 labels, use soften values such as 0.9 and 0.1. It is shown to reduce the networks' vulnerability.\n(5) Virtual Batch Normalization (VBN)\nEach data sample is normalized based on a fixed batch (\u0026ldquo;reference batch\u0026rdquo;) of data rather than within its minibatch. The reference batch is chosen once at the beginning and stays the same through the training.\nTheano Implementation: openai/improved-gan\n(6) Adding Noises.\nBased on the discussion in the previous section, we now know $p_r$ and $p_g$ are disjoint in a high dimensional space and it causes the problem of vanishing gradient. To artificially \u0026ldquo;spread out\u0026rdquo; the distribution and to create higher chances for two probability distributions to have overlaps, one solution is to add continuous noises onto the inputs of the discriminator $D$.\n(7) Use Better Metric of Distribution Similarity\nThe loss function of the vanilla GAN measures the JS divergence between the distributions of $p_r$ and $p_g$. This metric fails to provide a meaningful value when two distributions are disjoint.\nWasserstein metric is proposed to replace JS divergence because it has a much smoother value space. See more in the next section.\nWasserstein GAN (WGAN) What is Wasserstein distance? Wasserstein Distance is a measure of the distance between two probability distributions. It is also called Earth Mover\u0026rsquo;s distance, short for EM distance, because informally it can be interpreted as the minimum energy cost of moving and transforming a pile of dirt in the shape of one probability distribution to the shape of the other distribution. The cost is quantified by: the amount of dirt moved x the moving distance.\nLet us first look at a simple case where the probability domain is discrete. For example, suppose we have two distributions $P$ and $Q$, each has four piles of dirt and both have ten shovelfuls of dirt in total. The numbers of shovelfuls in each dirt pile are assigned as follows:\n $$ P_1 = 3, P_2 = 2, P_3 = 1, P_4 = 4\\\\ Q_1 = 1, Q_2 = 2, Q_3 = 4, Q_4 = 3 $$ In order to change $P$ to look like $Q$, as illustrated in Fig. 7, we:\n First move 2 shovelfuls from $P_1$ to $P_2$ =\u0026gt; $(P_1, Q_1)$ match up. Then move 2 shovelfuls from $P_2$ to $P_3$ =\u0026gt; $(P_2, Q_2)$ match up. Finally move 1 shovelfuls from $Q_3$ to $Q_4$ =\u0026gt; $(P_3, Q_3)$ and $(P_4, Q_4)$ match up. If we label the cost to pay to make $P_i$ and $Q_i$ match as $\\delta_i$, we would have $\\delta_{i+1} = \\delta_i + P_i - Q_i$ and in the example:\n $$ \\begin{aligned} \\delta_0 \u0026= 0\\\\ \\delta_1 \u0026= 0 + 3 - 1 = 2\\\\ \\delta_2 \u0026= 2 + 2 - 2 = 2\\\\ \\delta_3 \u0026= 2 + 1 - 4 = -1\\\\ \\delta_4 \u0026= -1 + 4 - 3 = 0 \\end{aligned} $$ Finally the Earth Mover\u0026rsquo;s distance is $W = \\sum \\vert \\delta_i \\vert = 5$.\nFig. 7. Step-by-step plan of moving dirt between piles in $P$ and $Q$ to make them match. When dealing with the continuous probability domain, the distance formula becomes:\n $$ W(p_r, p_g) = \\inf_{\\gamma \\sim \\Pi(p_r, p_g)} \\mathbb{E}_{(x, y) \\sim \\gamma}[\\| x-y \\|] $$ In the formula above, $\\Pi(p_r, p_g)$ is the set of all possible joint probability distributions between $p_r$ and $p_g$. One joint distribution $\\gamma \\in \\Pi(p_r, p_g)$ describes one dirt transport plan, same as the discrete example above, but in the continuous probability space. Precisely $\\gamma(x, y)$ states the percentage of dirt should be transported from point $x$ to $y$ so as to make $x$ follows the same probability distribution of $y$. That\u0026rsquo;s why the marginal distribution over $x$ adds up to $p_g$, $\\sum_{x} \\gamma(x, y) = p_g(y)$ (Once we finish moving the planned amount of dirt from every possible $x$ to the target $y$, we end up with exactly what $y$ has according to $p_g$.) and vice versa $\\sum_{y} \\gamma(x, y) = p_r(x)$.\nWhen treating $x$ as the starting point and $y$ as the destination, the total amount of dirt moved is $\\gamma(x, y)$ and the travelling distance is $| x-y |$ and thus the cost is $\\gamma(x, y) \\cdot | x-y |$. The expected cost averaged across all the $(x,y)$ pairs can be easily computed as:\n $$ \\sum_{x, y} \\gamma(x, y) \\| x-y \\| = \\mathbb{E}_{x, y \\sim \\gamma} \\| x-y \\| $$ Finally, we take the minimum one among the costs of all dirt moving solutions as the EM distance. In the definition of Wasserstein distance, the $\\inf$ (infimum, also known as greatest lower bound) indicates that we are only interested in the smallest cost.\nWhy Wasserstein is better than JS or KL divergence? Even when two distributions are located in lower dimensional manifolds without overlaps, Wasserstein distance can still provide a meaningful and smooth representation of the distance in-between.\nThe WGAN paper exemplified the idea with a simple example.\nSuppose we have two probability distributions, $P$ and $Q$:\n $$ \\forall (x, y) \\in P, x = 0 \\text{ and } y \\sim U(0, 1)\\\\ \\forall (x, y) \\in Q, x = \\theta, 0 \\leq \\theta \\leq 1 \\text{ and } y \\sim U(0, 1)\\\\ $$ Fig. 8. There is no overlap between $P$ and $Q$ when $\\theta \\neq 0$. When $\\theta \\neq 0$:\n $$ \\begin{aligned} D_{KL}(P \\| Q) \u0026= \\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{0} = +\\infty \\\\ D_{KL}(Q \\| P) \u0026= \\sum_{x=\\theta, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{0} = +\\infty \\\\ D_{JS}(P, Q) \u0026= \\frac{1}{2}(\\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{1/2} + \\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{1/2}) = \\log 2\\\\ W(P, Q) \u0026= |\\theta| \\end{aligned} $$ But when $\\theta = 0$, two distributions are fully overlapped:\n $$ \\begin{aligned} D_{KL}(P \\| Q) \u0026= D_{KL}(Q \\| P) = D_{JS}(P, Q) = 0\\\\ W(P, Q) \u0026= 0 = \\lvert \\theta \\rvert \\end{aligned} $$ $D_{KL}$ gives us inifity when two distributions are disjoint. The value of $D_{JS}$ has sudden jump, not differentiable at $\\theta = 0$. Only Wasserstein metric provides a smooth measure, which is super helpful for a stable learning process using gradient descents.\nUse Wasserstein distance as GAN loss function It is intractable to exhaust all the possible joint distributions in $\\Pi(p_r, p_g)$ to compute $\\inf_{\\gamma \\sim \\Pi(p_r, p_g)}$. Thus the authors proposed a smart transformation of the formula based on the Kantorovich-Rubinstein duality to:\n $$ W(p_r, p_g) = \\frac{1}{K} \\sup_{\\| f \\|_L \\leq K} \\mathbb{E}_{x \\sim p_r}[f(x)] - \\mathbb{E}_{x \\sim p_g}[f(x)] $$ where $\\sup$ (supremum) is the opposite of $inf$ (infimum); we want to measure the least upper bound or, in even simpler words, the maximum value.\nLipschitz continuity?\nThe function $f$ in the new form of Wasserstein metric is demanded to satisfy $| f |_L \\leq K$, meaning it should be K-Lipschitz continuous.\nA real-valued function $f: \\mathbb{R} \\rightarrow \\mathbb{R}$ is called $K$-Lipschitz continuous if there exists a real constant $K \\geq 0$ such that, for all $x_1, x_2 \\in \\mathbb{R}$,\n$$ \\lvert f(x_1) - f(x_2) \\rvert \\leq K \\lvert x_1 - x_2 \\rvert $$\nHere $K$ is known as a Lipschitz constant for function $f(.)$. Functions that are everywhere continuously differentiable is Lipschitz continuous, because the derivative, estimated as $\\frac{\\lvert f(x_1) - f(x_2) \\rvert}{\\lvert x_1 - x_2 \\rvert}$, has bounds. However, a Lipschitz continuous function may not be everywhere differentiable, such as $f(x) = \\lvert x \\rvert$.\nExplaining how the transformation happens on the Wasserstein distance formula is worthy of a long post by itself, so I skip the details here. If you are interested in how to compute Wasserstein metric using linear programming, or how to transfer Wasserstein metric into its dual form according to the Kantorovich-Rubinstein Duality, read this awesome post.\nSuppose this function $f$ comes from a family of K-Lipschitz continuous functions, $\\{ f_w \\}_{w \\in W}$, parameterized by $w$. In the modified Wasserstein-GAN, the \u0026ldquo;discriminator\u0026rdquo; model is used to learn $w$ to find a good $f_w$ and the loss function is configured as measuring the Wasserstein distance between $p_r$ and $p_g$.\n $$ L(p_r, p_g) = W(p_r, p_g) = \\max_{w \\in W} \\mathbb{E}_{x \\sim p_r}[f_w(x)] - \\mathbb{E}_{z \\sim p_r(z)}[f_w(g_\\theta(z))] $$ Thus the \u0026ldquo;discriminator\u0026rdquo; is not a direct critic of telling the fake samples apart from the real ones anymore. Instead, it is trained to learn a $K$-Lipschitz continuous function to help compute Wasserstein distance. As the loss function decreases in the training, the Wasserstein distance gets smaller and the generator model\u0026rsquo;s output grows closer to the real data distribution.\nOne big problem is to maintain the $K$-Lipschitz continuity of $f_w$ during the training in order to make everything work out. The paper presents a simple but very practical trick: After every gradient update, clamp the weights $w$ to a small window, such as $[-0.01, 0.01]$, resulting in a compact parameter space $W$ and thus $f_w$ obtains its lower and upper bounds to preserve the Lipschitz continuity.\nFig. 9. Algorithm of Wasserstein generative adversarial network. (Image source: Arjovsky, Chintala, \u0026 Bottou, 2017.) Compared to the original GAN algorithm, the WGAN undertakes the following changes:\n After every gradient update on the critic function, clamp the weights to a small fixed range, $[-c, c]$. Use a new loss function derived from the Wasserstein distance, no logarithm anymore. The \u0026ldquo;discriminator\u0026rdquo; model does not play as a direct critic but a helper for estimating the Wasserstein metric between real and generated data distribution. Empirically the authors recommended RMSProp optimizer on the critic, rather than a momentum based optimizer such as Adam which could cause instability in the model training. I haven\u0026rsquo;t seen clear theoretical explanation on this point through. Sadly, Wasserstein GAN is not perfect. Even the authors of the original WGAN paper mentioned that \u0026ldquo;Weight clipping is a clearly terrible way to enforce a Lipschitz constraint\u0026rdquo; (Oops!). WGAN still suffers from unstable training, slow convergence after weight clipping (when clipping window is too large), and vanishing gradients (when clipping window is too small).\nSome improvement, precisely replacing weight clipping with gradient penalty, has been discussed in Gulrajani et al. 2017. I will leave this to a future post.\nExample: Create New Pokemons! Just for fun, I tried out carpedm20/DCGAN-tensorflow on a tiny dataset, Pokemon sprites. The dataset only has 900-ish pokemon images, including different levels of same pokemon species.\nLet\u0026rsquo;s check out what types of new pokemons the model is able to create. Unfortunately due to the tiny training data, the new pokemons only have rough shapes without details. The shapes and colors do look better with more training epoches! Hooray!\nFig. 10. Train carpedm20/DCGAN-tensorflow on a set of Pokemon sprite images. The sample outputs are listed after training epoches = 7, 21, 49. If you are interested in a commented version of carpedm20/DCGAN-tensorflow and how to modify it to train WGAN and WGAN with gradient penalty, check lilianweng/unified-gan-tensorflow.\n Cited as:\n@article{weng2017gan, title = \u0026quot;From GAN to WGAN\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-08-20-gan/\u0026quot; } OR\n@misc{weng2019gan, title={From GAN to WGAN}, author={Lilian Weng}, year={2019}, eprint={1904.08994}, archivePrefix={arXiv}, primaryClass={cs.LG} } References [1] Goodfellow, Ian, et al. \u0026ldquo;Generative adversarial nets.\u0026quot; NIPS, 2014.\n[2] Tim Salimans, et al. \u0026ldquo;Improved techniques for training gans.\u0026quot; NIPS 2016.\n[3] Martin Arjovsky and Léon Bottou. \u0026ldquo;Towards principled methods for training generative adversarial networks.\u0026quot; arXiv preprint arXiv:1701.04862 (2017).\n[4] Martin Arjovsky, Soumith Chintala, and Léon Bottou. \u0026ldquo;Wasserstein GAN.\u0026quot; arXiv preprint arXiv:1701.07875 (2017).\n[5] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville. Improved training of wasserstein gans. arXiv preprint arXiv:1704.00028 (2017).\n[6] Computing the Earth Mover\u0026rsquo;s Distance under Transformations\n[7] Wasserstein GAN and the Kantorovich-Rubinstein Duality\n[8] zhuanlan.zhihu.com/p/25071913\n[9] Ferenc Huszár. \u0026ldquo;How (not) to Train your Generative Model: Scheduled Sampling, Likelihood, Adversary?.\u0026quot; arXiv preprint arXiv:1511.05101 (2015).\n","permalink":"https://lilianweng.github.io/posts/2017-08-20-gan/","summary":"[Updated on 2018-09-30: thanks to Yoonju, we have this post translated in Korean!] [Updated on 2019-04-18: this post is also available on arXiv.]\nGenerative adversarial network (GAN) has shown great results in many generative tasks to replicate the real-world rich content such as images, human language, and music. It is inspired by game theory: two models, a generator and a critic, are competing with each other while making each other stronger at the same time.","title":"From GAN to WGAN"},{"content":"The machine learning models have started penetrating into critical areas like health care, justice systems, and financial industry. Thus to figure out how the models make the decisions and make sure the decisioning process is aligned with the ethnic requirements or legal regulations becomes a necessity.\nMeanwhile, the rapid growth of deep learning models pushes the requirement of interpreting complicated models further. People are eager to apply the power of AI fully on key aspects of everyday life. However, it is hard to do so without enough trust in the models or an efficient procedure to explain unintended behavior, especially considering that the deep neural networks are born as black-boxes.\nThink of the following cases:\n The financial industry is highly regulated and loan issuers are required by law to make fair decisions and explain their credit models to provide reasons whenever they decide to decline loan application. Medical diagnosis model is responsible for human life. How can we be confident enough to treat a patient as instructed by a black-box model? When using a criminal decision model to predict the risk of recidivism at the court, we have to make sure the model behaves in an equitable, honest and nondiscriminatory manner. If a self-driving car suddenly acts abnormally and we cannot explain why, are we gonna be comfortable enough to use the technique in real traffic in large scale? At Affirm, we are issuing tens of thousands of installment loans every day and our underwriting model has to provide declination reasons when the model rejects one\u0026rsquo;s loan application. That\u0026rsquo;s one of the many motivations for me to dig deeper and write this post. Model interpretability is a big field in machine learning. This review is never met to exhaust every study, but to serve as a starting point.\n Interpretable Models Lipton (2017) summarized the properties of an interpretable model in a theoretical review paper, \u0026ldquo;The mythos of model interpretability\u0026rdquo;: A human can repeat (\u0026ldquo;simulatability\u0026rdquo;) the computation process with a full understanding of the algorithm (\u0026ldquo;algorithmic transparency\u0026rdquo;) and every individual part of the model owns an intuitive explanation (\u0026ldquo;decomposability\u0026rdquo;).\nMany classic models have relatively simpler formation and naturally, come with a model-specific interpretation method. Meanwhile, new tools are being developed to help create better interpretable models (Been, Khanna, \u0026amp; Koyejo, 2016; Lakkaraju, Bach \u0026amp; Leskovec, 2016).\nRegression A general form of a linear regression model is:\n$$ y = w_0 + w_1 x_1 + w_2 x_2 + … + w_n x_n $$\nThe coefficients describe the change of the response triggered by one unit increase of the independent variables. The coefficients are not comparable directly unless the features have been standardized (check sklearn.preprocessing.StandardScalar and RobustScaler), since one unit of different features can refer to very different things. Without standardization, the product $w_i \\dot x_i$ can be used to quantify one feature\u0026rsquo;s contribution to the response.\nNaive Bayes Naive Bayes is named as \u0026ldquo;Naive\u0026rdquo; because it works on a very simplified assumption that features are independent of each other and each contributes to the output independently.\nGiven a feature vector $\\mathbf{x} = [x_1, x_2, \\dots, x_n]$ and a class label $c \\in \\{1, 2, \\dots, C\\}$, the probability of this data point belonging to this class is:\n $$ \\begin{aligned} p(c | x_1, x_2, \\dots, x_n) \u0026\\propto p(c, x_1, x_2, \\dots, x_n)\\\\ \u0026\\propto p(c) p(x_1 | c) p(x_2 | c) \\dots p(x_n | c)\\\\ \u0026\\propto p(c) \\prod_{i=1}^n p(x_i | c). \\end{aligned} $$ The Naive Bayes classifier is then defined as:\n$$ \\hat{y} = \\arg\\max_{c \\in 1, \\dots, C} p(c) \\prod_{i=1}^n p(x_i | c) $$\nBecause the model has learned the prior $p(x_i \\vert c)$ during the training, the contribution of an individual feature value can be easily measured by the posterior, $p(c \\vert x_i) = p(c)p(x_i \\vert c) / p(x_i)$.\nDecision Tree/Decision Lists Decision lists are a set of boolean functions, usually constructed by the syntax like if... then... else.... The if-condition contains a function involving one or multiple features and a boolean output. Decision lists are born with good interpretability and can be visualized in a tree structure. Many research on decision lists is driven by medical applications, where the interpretability is almost as crucial as the model itself.\nA few types of decision lists are briefly described below:\n Falling Rule Lists (FRL) (Wang and Rudin, 2015) has fully enforced monotonicity on feature values. One key point, for example in the binary classification context, is that the probability of prediction $Y=1$ associated with each rule decreases as one moves down the decision lists. Bayesian Rule List (BRL) (Letham et al., 2015) is a generative model that yields a posterior distribution over possible decision lists. Interpretable Decision Sets (IDS) (Lakkaraju, Bach \u0026amp; Leskovec, 2016) is a prediction framework to create a set of classification rules. The learning is optimized for both accuracy and interpretability simultaneously. IDS is closely related to the BETA method I\u0026rsquo;m gonna describe later for interpreting black-box models. Random Forests Weirdly enough, many people believe that the Random Forests model is a black box, which is not true. Considering that the output of random forests is the majority vote by a large number of independent decision trees and each tree is naturally interpretable.\nIt is not very hard to gauge the influence of individual features if we look into a single tree at a time. The global feature importance of random forests can be quantified by the total decrease in node impurity averaged over all trees of the ensemble (\u0026ldquo;mean decrease impurity\u0026rdquo;).\nFor one instance, because the decision paths in all the trees are well tracked, we can use the difference between the mean value of data points in a parent node between that of a child node to approximate the contribution of this split. Read more in this series of blog posts: Interpreting Random Forests.\nInterpreting Black-Box Models A lot of models are not designed to be interpretable. Approaches to explaining a black-box model aim to extract information from the trained model to justify its prediction outcome, without knowing how the model works in details. To keep the interpretation process independent from the model implementation is good for real-world applications: Even when the base model is being constantly upgraded and refined, the interpretation engine built on top would not worry about the changes.\nWithout the concern of keeping the model transparent and interpretable, we can endow the model with greater power of expressivity by adding more parameters and nonlinearity computation. That\u0026rsquo;s how deep neural networks become successful in tasks involving rich inputs.\nThere is no hard requirement on how the explanation should be presented, but the primary goal is mainly to answer: Can I trust this model? When we rely on the model to make a critical or life-and-death decision, we have to make sure the model is trustworthy ahead of time.\nThe interpretation framework should balance between two goals:\n Fidelity: the prediction produced by an explanation should agree with the original model as much as possible. Interpretability: the explanation should be simple enough to be human-understandable. Side Notes: The next three methods are designed for local interpretation.\n Prediction Decomposition Robnik-Sikonja and Kononenko (2008) proposed to explain the model prediction for one instance by measuring the difference between the original prediction and the one made with omitting a set of features.\nLet\u0026rsquo;s say we need to generate an explanation for a classification model $f: \\mathbf{X} \\rightarrow \\mathbf{Y}$. Given a data point $x \\in X$ which consists of $a$ individual values of attribute $A_i$, $i = 1, \\dots, a$, and is labeled with class $y \\in Y$. The prediction difference is quantified by computing the difference between the model predicted probabilities with or without knowing $A_i$:\n$$ \\text{probDiff}_i (y | x) = p(y| x) - p(y | x \\backslash A_i) $$\n(The paper also discussed on using the odds ratio or the entropy-based information metric to quantify the prediction difference.)\nProblem: If the target model outputs a probability, then great, getting $ p(y \\vert x) $ is straightforward. Otherwise, the model prediction has to run through an appropriate post-modeling calibration to translate the prediction score into probabilities. This calibration layer is another piece of complication.\nAnother problem: If we generate $x \\backslash A_i$ by replacing $A_i$ with a missing value (like None, NaN, etc.), we have to rely on the model\u0026rsquo;s internal mechanism for missing value imputation. A model which replaces these missing cases with the median should have output very different from a model which imputes a special placeholder. One solution as presented in the paper is to replace $A_i$ with all possible values of this feature and then sum up the prediction weighted by how likely each value shows in the data:\n $$ \\begin{aligned} p(y \\vert x \\backslash A_i) \u0026= \\sum_{s=1}^{m_i} p(A_i=a_s \\vert x \\backslash A_i) p(y \\vert x \\leftarrow A_i=a_s) \\\\ \u0026\\approx \\sum_{s=1}^{m_i} p(A_i=a_s) p(y \\vert x \\leftarrow A_i=a_s) \\end{aligned} $$ Where $p(y \\vert x \\leftarrow A_i=a_s)$ is the probability of getting label $y$ if we replace the feature $A_i$ with value $a_s$ in the feature vector of $x$. There are $m_i$ unique values of $A_i$ in the training set.\nWith the help of the measures of prediction difference when omitting known features, we can decompose the impact of each individual feature on the prediction.\nFig. 1. Explanations for a SVM model predicting the survival of one male adult first-class passenger in the Titanic dataset. The information difference is very similar to the probability difference, but it measures the amount of information necessary to find out $y$ is true for the given instance without the knowledge of $A\\_i$: $\\text{infDiff}\\_i (y|x) = \\log\\_2 p(y|x) - \\log\\_2 p(y|x \\backslash A\\_i)$. Explanations for particular instance are depicted with dark bars. The light shaded half-height bars are average positive and negative explanations for given attributes' values. In this case, being a male adult makes it very less likely to survive; the class level does not impact as much. Local Gradient Explanation Vector This method (Baehrens, et al. 2010) is able to explain the local decision taken by arbitrary nonlinear classification algorithms, using the local gradients that characterize how a data point has to be moved to change its predicted label.\nLet\u0026rsquo;s say, we have a Bayes Classifier which is trained on the data set $X$ and outputs probabilities over the class labels $Y$, $p(Y=y \\vert X=x)$. And one class label $y$ is drawn from the class label pool, $\\{1, 2, \\dots, C\\}$. This Bayes classifier is constructed as:\n$$ f^{*}(x) = \\arg \\min_{c \\in \\{1, \\dots, C\\}} p(Y \\neq c \\vert X = x) $$\nThe local explanation vector is defined as the derivative of the probability prediction function at the test point $x = x_0$. A large entry in this vector highlights a feature with a big influence on the model decision; A positive sign indicates that increasing the feature would lower the probability of $x_0$ assigned to $f^{*}(x_0)$.\nHowever, this approach requires the model output to be a probability (similar to the \u0026ldquo;Prediction Decomposition\u0026rdquo; method above). What if the original model (labelled as $f$) is not calibrated to yield probabilities? As suggested by the paper, we can approximate $f$ by another classifier in a form that resembles the Bayes classifier $f^{*}$:\n(1) Apply Parzen window to the training data to estimate the weighted class densities:\n$$ \\hat{p}_{\\sigma}(x, y=c) = \\frac{1}{n} \\sum_{i \\in I_c} k_{\\sigma} (x - x_i) $$\nWhere $I_c$ is the index set containing the indices of data points assigned to class $c$ by the model $f$, $I_c = \\{i \\vert f(x_i) = c\\}$. $k_{\\sigma}$ is a kernel function. Gaussian kernel is a popular one among many candidates.\n(2) Then, apply the Bayes' rule to approximate the probability $p(Y=c \\vert X=x)$ for all classes:\n $$ \\begin{aligned} \\hat{p}_{\\sigma}(y=c | x) \u0026= \\frac{\\hat{p}_{\\sigma}(x, y=c)}{\\hat{p}_{\\sigma}(x, y=c) + \\hat{p}_{\\sigma}(x, y \\neq c)} \\\\ \u0026\\approx \\frac{\\sum_{i \\in I_c} k_{\\sigma} (x - x_i)}{\\sum_i k_{\\sigma} (x - x_i)} \\end{aligned} $$ (3) The final estimated Bayes classifier takes the form:\n$$ \\hat{f}_{\\sigma} = \\arg\\min_{c \\in \\{1, \\dots, C\\}} \\hat{p}_{\\sigma}(y \\neq c \\vert x) $$\nNoted that we can generate the labeled data with the original model $f$, as much as we want, not restricted by the size of the training data. The hyperparameter $\\sigma$ is selected to optimize the chances of $\\hat{f}_{\\sigma}(x) = f(x)$ to achieve high fidelity.\nFig. 2. An example of how local gradient explanation vector is applied on simple object classification with Gaussian Processes Classifier (GPC). The GPC model outputs the probability by nature. (a) shows the training points and their labels in red (positive 1) and blue (negative -1). (b) illustrates a probability function for the positive class. (c-d) shows the local gradients and the directions of the local explanation vectors. Side notes: As you can see both the methods above require the model prediction to be a probability. Calibration of the model output adds another layer of complication.\n LIME (Local Interpretable Model-Agnostic Explanations) LIME, short for local interpretable model-agnostic explanation, can approximate a black-box model locally in the neighborhood of the prediction we are interested (Ribeiro, Singh, \u0026amp; Guestrin, 2016).\nSame as above, let us label the black-box model as $f$. LIME presents the following steps:\n(1) Convert the dataset into interpretable data representation: $x \\Rightarrow x_b$.\n Text classifier: a binary vector indicating the presence or absence of a word Image classifier: a binary vector indicating the presence or absence of a contiguous patch of similar pixels (super-pixel). Fig. 3. An example of converting an image into interpretable data representation. (Image source: www.oreilly.com/learning/introduction-to-local-interpretable-model-agnostic-explanations-lime) (2) Given a prediction $f(x)$ with the corresponding interpretable data representation $x_b$, let us sample instances around $x_b$ by drawing nonzero elements of $x_b$ uniformly at random where the number of such draws is also uniformly sampled. This process generates a perturbed sample $z_b$ which contains a fraction of nonzero elements of $x_b$.\nThen we recover $z_b$ back into the original input $z$ and get a prediction score $f(z)$ by the target model.\nUse many such sampled data points $z_b \\in \\mathcal{Z}_b$ and their model predictions, we can learn an explanation model (such as in a form as simple as a regression) with local fidelity. The sampled data points are weighted differently based on how close they are to $x_b$. The paper used a lasso regression with preprocessing to select top $k$ most significant features beforehand, named \u0026ldquo;K-LASSO\u0026rdquo;.\nFig. 4. The pink and blue areas are two classes predicted by the black-box model $f$. the big red cross is the point to be explained and other smaller crosses (predicted as pink by $f$) and dots (predicted as blue by $f$) are sampled data points. Even though the model can be very complicated, we are still able to learn a local explanation model as simple as the grey dash line. (Image source: homes.cs.washington.edu/~marcotcr/blog/lime) Examining whether the explanation makes sense can directly decide whether the model is trustworthy because sometimes the model can pick up spurious correlation or generalization. One interesting example in the paper is to apply LIME on an SVM text classifier for differentiating \u0026ldquo;Christianity\u0026rdquo; from \u0026ldquo;Atheism\u0026rdquo;. The model achieved a pretty good accuracy (94% on held-out testing set!), but the LIME explanation demonstrated that decisions were made by very arbitrary reasons, such as counting the words \u0026ldquo;re\u0026rdquo;, \u0026ldquo;posting\u0026rdquo; and \u0026ldquo;host\u0026rdquo; which have no connection with neither \u0026ldquo;Christianity\u0026rdquo; nor \u0026ldquo;Atheism\u0026rdquo; directly. After such a diagnosis, we learned that even the model gives us a nice accuracy, it cannot be trusted. It also shed lights on ways to improve the model, such as better preprocessing on the text.\nFig. 5. Illustration of how to use LIME on an image classifier. (Image source: www.oreilly.com/learning/introduction-to-local-interpretable-model-agnostic-explanations-lime) For more detailed non-paper explanation, please read this blog post by the author. A very nice read.\n Side Notes: Interpreting a model locally is supposed to be easier than interpreting the model globally, but harder to maintain (thinking about the curse of dimensionality). Methods described below aim to explain the behavior of a model as a whole. However, the global approach is unable to capture the fine-grained interpretation, such as a feature might be important in this region but not at all in another.\n Feature Selection Essentially all the classic feature selection methods (Yang and Pedersen, 1997; Guyon and Elisseeff, 2003) can be considered as ways to explain a model globally. Feature selection methods decompose the contribution of multiple features so that we can explain the overall model output by individual feature impact.\nThere are a ton of resources on feature selection so I would skip the topic in this post.\nBETA (Black Box Explanation through Transparent Approximations) BETA, short for black box explanation through transparent approximations, is closely connected to Interpretable Decision Sets (Lakkaraju, Bach \u0026amp; Leskovec, 2016). BETA learns a compact two-level decision set in which each rule explains part of the model behavior unambiguously.\nThe authors proposed an novel objective function so that the learning process is optimized for high fidelity (high agreement between explanation and the model), low unambiguity (little overlaps between decision rules in the explanation), and high interpretability (the explanation decision set is lightweight and small). These aspects are combined into one objection function to optimize for.\nFig. 6. Measures for desiderata of a good model explanation: fidelity, unambiguity, and interpretability. Given the target model is $\\mathcal{B}$, its explanation is a two level decision set $\\Re$ containing a set of rules ${(q\\_1, s\\_1, c\\_1), \\dots, (q\\_M, s\\_M, c\\_M)}$, where $q\\_i$ and $s\\_i$ are conjunctions of predicates of the form (feature, operator, value) and $c\\_i$ is a class label. Check the paper for more details. (Image source: arxiv.org/abs/1707.01154) Explainable Artificial Intelligence I borrow the name of this section from the DARPA project \u0026ldquo;Explainable Artificial Intelligence\u0026rdquo;. This Explainable AI (XAI) program aims to develop more interpretable models and to enable human to understand, appropriately trust, and effectively manage the emerging generation of artificially intelligent techniques.\nWith the progress of the deep learning applications, people start worrying about that we may never know even if the model goes bad. The complicated structure, the large number of learnable parameters, the nonlinear mathematical operations and some intriguing properties (Szegedy et al., 2014) lead to the un-interpretability of deep neural networks, creating a true black-box. Although the power of deep learning is originated from this complexity \u0026mdash; more flexible to capture rich and intricate patterns in the real-world data.\nStudies on adversarial examples (OpenAI Blog: Robust Adversarial Examples, Attacking Machine Learning with Adversarial Examples, Goodfellow, Shlens \u0026amp; Szegedy, 2015; Nguyen, Yosinski, \u0026amp; Clune, 2015) raise the alarm on the robustness and safety of AI applications. Sometimes the models could show unintended, unexpected and unpredictable behavior and we have no fast/good strategy to tell why.\nFig. 7. Illustrations of adversarial examples. (a-d) are adversarial images that are generated by adding human-imperceptible noises onto original images (Szegedy et al., 2013). A well-trained neural network model can successfully classify original ones but fail adversarial ones. (e-h) are patterns that are generated (Nguyen, Yosinski \u0026 Clune, 2015). A well-trained neural network model labels them into (e) school bus, (f) guitar, (g) peacock and (h) Pekinese respectively. (Image source: Wang, Raj \u0026 Xing, 2017) Nvidia recently developed a method to visualize the most important pixel points in their self-driving cars' decisioning process. The visualization provides insights on how AI thinks and what the system relies on while operating the car. If what the AI believes to be important agrees with how human make similar decisions, we can naturally gain more confidence in the black-box model.\nMany exciting news and findings are happening in this evolving field every day. Hope my post can give you some pointers and encourage you to investigate more into this topic :)\n Cited as:\n@article{weng2017gan, title = \u0026quot;How to Explain the Prediction of a Machine Learning Model?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-08-01-interpretation/\u0026quot; } References [1] Zachary C. Lipton. \u0026ldquo;The mythos of model interpretability.\u0026quot; arXiv preprint arXiv:1606.03490 (2016).\n[2] Been Kim, Rajiv Khanna, and Oluwasanmi O. Koyejo. \u0026ldquo;Examples are not enough, learn to criticize! criticism for interpretability.\u0026rdquo; Advances in Neural Information Processing Systems. 2016.\n[3] Himabindu Lakkaraju, Stephen H. Bach, and Jure Leskovec. \u0026ldquo;Interpretable decision sets: A joint framework for description and prediction.\u0026quot; Proc. 22nd ACM SIGKDD Intl. Conf. on Knowledge Discovery and Data Mining. ACM, 2016.\n[4] Robnik-Šikonja, Marko, and Igor Kononenko. \u0026ldquo;Explaining classifications for individual instances.\u0026quot; IEEE Transactions on Knowledge and Data Engineering 20.5 (2008): 589-600.\n[5] Baehrens, David, et al. \u0026ldquo;How to explain individual classification decisions.\u0026quot; Journal of Machine Learning Research 11.Jun (2010): 1803-1831.\n[6] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. \u0026ldquo;Why should I trust you?: Explaining the predictions of any classifier.\u0026quot; Proc. 22nd ACM SIGKDD Intl. Conf. on Knowledge Discovery and Data Mining. ACM, 2016.\n[7] Yiming Yang, and Jan O. Pedersen. \u0026ldquo;A comparative study on feature selection in text categorization.\u0026quot; Intl. Conf. on Machine Learning. Vol. 97. 1997.\n[8] Isabelle Guyon, and André Elisseeff. \u0026ldquo;An introduction to variable and feature selection.\u0026quot; Journal of Machine Learning Research 3.Mar (2003): 1157-1182.\n[9] Ian J. Goodfellow, Jonathon Shlens, and Christian Szegedy. \u0026ldquo;Explaining and harnessing adversarial examples.\u0026quot; ICLR 2015.\n[10] Christian Szegedy, Wojciech Zaremba, Ilya Sutskever, Joan Bruna, Dumitru Erhan, Ian Goodfellow, Rob Fergus. \u0026ldquo;Intriguing properties of neural networks.\u0026quot; Intl. Conf. on Learning Representations (2014)\n[11] Nguyen, Anh, Jason Yosinski, and Jeff Clune. \u0026ldquo;Deep neural networks are easily fooled: High confidence predictions for unrecognizable images.\u0026quot; Proc. IEEE Conference on Computer Vision and Pattern Recognition. 2015.\n[12] Benjamin Letham, Cynthia Rudin, Tyler H. McCormick, and David Madigan. \u0026ldquo;Interpretable classifiers using rules and Bayesian analysis: Building a better stroke prediction model.\u0026quot; The Annals of Applied Statistics 9, No. 3 (2015): 1350-1371.\n[13] Haohan Wang, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800 (2017).\n[14] OpenAI Blog: Robust Adversarial Examples\n[15] Attacking Machine Learning with Adversarial Examples\n[16] Reading an AI Car’s Mind: How NVIDIA’s Neural Net Makes Decisions\n","permalink":"https://lilianweng.github.io/posts/2017-08-01-interpretation/","summary":"The machine learning models have started penetrating into critical areas like health care, justice systems, and financial industry. Thus to figure out how the models make the decisions and make sure the decisioning process is aligned with the ethnic requirements or legal regulations becomes a necessity.\nMeanwhile, the rapid growth of deep learning models pushes the requirement of interpreting complicated models further. People are eager to apply the power of AI fully on key aspects of everyday life.","title":"How to Explain the Prediction of a Machine Learning Model?"},{"content":"In the Part 2 tutorial, I would like to continue the topic on stock price prediction and to endow the recurrent neural network that I have built in Part 1 with the capability of responding to multiple stocks. In order to distinguish the patterns associated with different price sequences, I use the stock symbol embedding vectors as part of the input.\n Dataset During the search, I found this library for querying Yahoo! Finance API. It would be very useful if Yahoo hasn’t shut down the historical data fetch API. You may find it useful for querying other information though. Here I pick the Google Finance link, among a couple of free data sources for downloading historical stock prices.\nThe data fetch code can be written as simple as:\nimport urllib2 from datetime import datetime BASE_URL = \u0026#34;https://www.google.com/finance/historical?\u0026#34; \u0026#34;output=csv\u0026amp;q={0}\u0026amp;startdate=Jan+1%2C+1980\u0026amp;enddate={1}\u0026#34; symbol_url = BASE_URL.format( urllib2.quote(\u0026#39;GOOG\u0026#39;), # Replace with any stock you are interested. urllib2.quote(datetime.now().strftime(\u0026#34;%b+%d,+%Y\u0026#34;), \u0026#39;+\u0026#39;) ) When fetching the content, remember to add try-catch wrapper in case the link fails or the provided stock symbol is not valid.\ntry: f = urllib2.urlopen(symbol_url) with open(\u0026#34;GOOG.csv\u0026#34;, \u0026#39;w\u0026#39;) as fin: print \u0026gt;\u0026gt; fin, f.read() except urllib2.HTTPError: print \u0026#34;Fetching Failed: {}\u0026#34;.format(symbol_url) The full working data fetcher code is available here.\nModel Construction The model is expected to learn the price sequences of different stocks in time. Due to the different underlying patterns, I would like to tell the model which stock it is dealing with explicitly. Embedding is more favored than one-hot encoding, because:\n Given that the train set includes $N$ stocks, the one-hot encoding would introduce $N$ (or $N-1$) additional sparse feature dimensions. Once each stock symbol is mapped onto a much smaller embedding vector of length $k$, $k \\ll N$, we end up with a much more compressed representation and smaller dataset to take care of. Since embedding vectors are variables to learn. Similar stocks could be associated with similar embeddings and help the prediction of each others, such as \u0026ldquo;GOOG\u0026rdquo; and \u0026ldquo;GOOGL\u0026rdquo; which you will see in Fig. 5. later. In the recurrent neural network, at one time step $t$, the input vector contains input_size (labelled as $w$) daily price values of $i$-th stock, $(p_{i, tw}, p_{i, tw+1}, \\dots, p_{i, (t+1)w-1})$. The stock symbol is uniquely mapped to a vector of length embedding_size (labelled as $k$), $(e_{i,0}, e_{i,1}, \\dots, e_{i,k})$. As illustrated in Fig. 1., the price vector is concatenated with the embedding vector and then fed into the LSTM cell.\nAnother alternative is to concatenate the embedding vectors with the last state of the LSTM cell and learn new weights $W$ and bias $b$ in the output layer. However, in this way, the LSTM cell cannot tell apart prices of one stock from another and its power would be largely restrained. Thus I decided to go with the former approach.\nFig. 1. The architecture of the stock price prediction RNN model with stock symbol embeddings. Two new configuration settings are added into RNNConfig:\n embedding_size controls the size of each embedding vector; stock_count refers to the number of unique stocks in the dataset. Together they define the size of the embedding matrix, for which the model has to learn embedding_size $\\times$ stock_count additional variables compared to the model in Part 1.\nclass RNNConfig(): # ... old ones embedding_size = 3 stock_count = 50 Define the Graph \u0026mdash; Let\u0026rsquo;s start going through some code \u0026mdash;\n(1) As demonstrated in tutorial Part 1: Define the Graph, let us define a tf.Graph() named lstm_graph and a set of tensors to hold input data, inputs, targets, and learning_rate in the same way. One more placeholder to define is a list of stock symbols associated with the input prices. Stock symbols have been mapped to unique integers beforehand with label encoding.\n# Mapped to an integer. one label refers to one stock symbol. stock_labels = tf.placeholder(tf.int32, [None, 1]) (2) Then we need to set up an embedding matrix to play as a lookup table, containing the embedding vectors of all the stocks. The matrix is initialized with random numbers in the interval [-1, 1] and gets updated during training.\n# NOTE: config = RNNConfig() and it defines hyperparameters. # Convert the integer labels to numeric embedding vectors. embedding_matrix = tf.Variable( tf.random_uniform([config.stock_count, config.embedding_size], -1.0, 1.0) ) (3) Repeat the stock labels num_steps times to match the unfolded version of RNN and the shape of inputs tensor during training. The transformation operation tf.tile receives a base tensor and creates a new tensor by replicating its certain dimensions multiples times; precisely the $i$-th dimension of the input tensor gets multiplied by multiples[i] times. For example, if the stock_labels is [[0], [0], [2], [1]] tiling it by [1, 5] produces [[0 0 0 0 0], [0 0 0 0 0], [2 2 2 2 2], [1 1 1 1 1]].\nstacked_stock_labels = tf.tile(stock_labels, multiples=[1, config.num_steps]) (4) Then we map the symbols to embedding vectors according to the lookup table embedding_matrix.\n# stock_label_embeds.get_shape() = (?, num_steps, embedding_size). stock_label_embeds = tf.nn.embedding_lookup(embedding_matrix, stacked_stock_labels) (5) Finally, combine the price values with the embedding vectors. The operation tf.concat concatenates a list of tensors along the dimension axis. In our case, we want to keep the batch size and the number of steps unchanged, but only extend the input vector of length input_size to include embedding features.\n# inputs.get_shape() = (?, num_steps, input_size) # stock_label_embeds.get_shape() = (?, num_steps, embedding_size) # inputs_with_embeds.get_shape() = (?, num_steps, input_size + embedding_size) inputs_with_embeds = tf.concat([inputs, stock_label_embeds], axis=2) The rest of code runs the dynamic RNN, extracts the last state of the LSTM cell, and handles weights and bias in the output layer. See Part 1: Define the Graph for the details.\nTraining Session Please read Part 1: Start Training Session if you haven\u0026rsquo;t for how to run a training session in Tensorflow.\nBefore feeding the data into the graph, the stock symbols should be transformed to unique integers with label encoding.\nfrom sklearn.preprocessing import LabelEncoder label_encoder = LabelEncoder() label_encoder.fit(list_of_symbols) The train/test split ratio remains same, 90% for training and 10% for testing, for every individual stock.\nVisualize the Graph After the graph is defined in code, let us check the visualization in Tensorboard to make sure that components are constructed correctly. Essentially it looks very much like our architecture illustration in Fig. 1.\nFig. 2. Tensorboard visualization of the graph defined above. Two modules, \"train\" and \"save\", have been removed from the main graph. Other than presenting the graph structure or tracking the variables in time, Tensorboard also supports embeddings visualization. In order to communicate the embedding values to Tensorboard, we need to add proper tracking in the training logs.\n(0) In my embedding visualization, I want to color each stock with its industry sector. This metadata should stored in a csv file. The file has two columns, the stock symbol and the industry sector. It does not matter whether the csv file has header, but the order of the listed stocks must be consistent with label_encoder.classes_.\nimport csv embedding_metadata_path = os.path.join(your_log_file_folder, \u0026#39;metadata.csv\u0026#39;) with open(embedding_metadata_path, \u0026#39;w\u0026#39;) as fout: csv_writer = csv.writer(fout) # write the content into the csv file. # for example, csv_writer.writerows([\u0026#34;GOOG\u0026#34;, \u0026#34;information_technology\u0026#34;]) (1) Set up the summary writer first within the training tf.Session.\nfrom tensorflow.contrib.tensorboard.plugins import projector with tf.Session(graph=lstm_graph) as sess: summary_writer = tf.summary.FileWriter(your_log_file_folder) summary_writer.add_graph(sess.graph) (2) Add the tensor embedding_matrix defined in our graph lstm_graph into the projector config variable and attach the metadata csv file.\nprojector_config = projector.ProjectorConfig() # You can add multiple embeddings. Here we add only one. added_embedding = projector_config.embeddings.add() added_embedding.tensor_name = embedding_matrix.name # Link this tensor to its metadata file. added_embedding.metadata_path = embedding_metadata_path (3) This line creates a file projector_config.pbtxt in the folder your_log_file_folder. TensorBoard will read this file during startup.\nprojector.visualize_embeddings(summary_writer, projector_config) Results The model is trained with top 50 stocks with largest market values in the S\u0026amp;P 500 index.\n(Run the following command within github.com/lilianweng/stock-rnn)\npython main.py --stock_count=50 --embed_size=3 --input_size=3 --max_epoch=50 --train And the following configuration is used:\nstock_count = 100 input_size = 3 embed_size = 3 num_steps = 30 lstm_size = 256 num_layers = 1 max_epoch = 50 keep_prob = 0.8 batch_size = 64 init_learning_rate = 0.05 learning_rate_decay = 0.99 init_epoch = 5 Price Prediction As a brief overview of the prediction quality, Fig. 3 plots the predictions for test data of \u0026ldquo;KO\u0026rdquo;, \u0026ldquo;AAPL\u0026rdquo;, \u0026ldquo;GOOG\u0026rdquo; and \u0026ldquo;NFLX\u0026rdquo;. The overall trends matched up between the true values and the predictions. Considering how the prediction task is designed, the model relies on all the historical data points to predict only next 5 (input_size) days. With a small input_size, the model does not need to worry about the long-term growth curve. Once we increase input_size, the prediction would be much harder.\nFig. 3. True and predicted stock prices of AAPL, MSFT and GOOG in the test set. The prices are normalized across consecutive prediction sliding windows (See Part 1: Normalization. The y-axis values get multiplied by 5 for a better comparison between true and predicted trends. Embedding Visualization One common technique to visualize the clusters in embedding space is t-SNE (Maaten and Hinton, 2008), which is well supported in Tensorboard. t-SNE, short for “t-Distributed Stochastic Neighbor Embedding, is a variation of Stochastic Neighbor Embedding (Hinton and Roweis, 2002), but with a modified cost function that is easier to optimize.\n Similar to SNE, t-SNE first converts the high-dimensional Euclidean distances between data points into conditional probabilities that represent similarities. t-SNE defines a similar probability distribution over the data points in the low-dimensional space, and it minimizes the Kullback–Leibler divergence between the two distributions with respect to the locations of the points on the map. Check this post for how to adjust the parameters, Perplexity and learning rate (epsilon), in t-SNE visualization.\nFig. 4. Visualization of the stock embeddings using t-SNE. Each label is colored based on the stock industry sector. We have 5 clusters. Interstingly, GOOG, GOOGL and FB belong to the same cluster, while AMZN and AAPL stay in another. In the embedding space, we can measure the similarity between two stocks by examining the similarity between their embedding vectors. For example, GOOG is mostly similar to GOOGL in the learned embeddings (See Fig. 5).\nFig. 5. \"GOOG\" is clicked in the embedding visualization graph and top 20 similar neighbors are highlighted with colors from dark to light as the similarity decreases. Known Problems The prediction values get diminished and flatten quite a lot as the training goes. That\u0026rsquo;s why I multiplied the absolute values by a constant to make the trend is more visible in Fig. 3., as I\u0026rsquo;m more curious about whether the prediction on the up-or-down direction right. However, there must be a reason for the diminishing prediction value problem. Potentially rather than using simple MSE as the loss, we can adopt another form of loss function to penalize more when the direction is predicted wrong. The loss function decreases fast at the beginning, but it suffers from occasional value explosion (a sudden peak happens and then goes back immediately). I suspect it is related to the form of loss function too. A updated and smarter loss function might be able to resolve the issue. The full code in this tutorial is available in github.com/lilianweng/stock-rnn.\n","permalink":"https://lilianweng.github.io/posts/2017-07-22-stock-rnn-part-2/","summary":"In the Part 2 tutorial, I would like to continue the topic on stock price prediction and to endow the recurrent neural network that I have built in Part 1 with the capability of responding to multiple stocks. In order to distinguish the patterns associated with different price sequences, I use the stock symbol embedding vectors as part of the input.\n Dataset During the search, I found this library for querying Yahoo!","title":"Predict Stock Prices Using RNN: Part 2"},{"content":"This is a tutorial for how to build a recurrent neural network using Tensorflow to predict stock market prices. The full working code is available in github.com/lilianweng/stock-rnn. If you don\u0026rsquo;t know what is recurrent neural network or LSTM cell, feel free to check my previous post.\n One thing I would like to emphasize that because my motivation for writing this post is more on demonstrating how to build and train an RNN model in Tensorflow and less on solve the stock prediction problem, I didn\u0026rsquo;t try hard on improving the prediction outcomes. You are more than welcome to take my code as a reference point and add more stock prediction related ideas to improve it. Enjoy!\n Overview of Existing Tutorials There are many tutorials on the Internet, like:\n A noob\u0026rsquo;s guide to implementing RNN-LSTM using Tensorflow TensorFlow RNN Tutorial LSTM by Example using Tensorflow How to build a Recurrent Neural Network in TensorFlow RNNs in Tensorflow, a Practical Guide and Undocumented Features Sequence prediction using recurrent neural networks(LSTM) with TensorFlow Anyone Can Learn To Code an LSTM-RNN in Python How to do time series prediction using RNNs, TensorFlow and Cloud ML Engine Despite all these existing tutorials, I still want to write a new one mainly for three reasons:\n Early tutorials cannot cope with the new version any more, as Tensorflow is still under development and changes on API interfaces are being made fast. Many tutorials use synthetic data in the examples. Well, I would like to play with the real world data. Some tutorials assume that you have known something about Tensorflow API beforehand, which makes the reading a bit difficult. After reading a bunch of examples, I would like to suggest taking the official example on Penn Tree Bank (PTB) dataset as your starting point. The PTB example showcases a RNN model in a pretty and modular design pattern, but it might prevent you from easily understanding the model structure. Hence, here I will build up the graph in a very straightforward manner.\nThe Goal I will explain how to build an RNN model with LSTM cells to predict the prices of S\u0026amp;P500 index. The dataset can be downloaded from Yahoo! Finance ^GSPC. In the following example, I used S\u0026amp;P 500 data from Jan 3, 1950 (the maximum date that Yahoo! Finance is able to trace back to) to Jun 23, 2017. The dataset provides several price points per day. For simplicity, we will only use the daily close prices for prediction. Meanwhile, I will demonstrate how to use TensorBoard for easily debugging and model tracking.\nAs a quick recap: the recurrent neural network (RNN) is a type of artificial neural network with self-loop in its hidden layer(s), which enables RNN to use the previous state of the hidden neuron(s) to learn the current state given the new input. RNN is good at processing sequential data. Long short-term memory (LSTM) cell is a specially designed working unit that helps RNN better memorize the long-term context.\nFor more information in depth, please read my previous post or this awesome post.\nData Preparation The stock prices is a time series of length $N$, defined as $p_0, p_1, \\dots, p_{N-1}$ in which $p_i$ is the close price on day $i$, $0 \\le i \u0026lt; N$. Imagine that we have a sliding window of a fixed size $w$ (later, we refer to this as input_size) and every time we move the window to the right by size $w$, so that there is no overlap between data in all the sliding windows.\nFig. 1. The S\u0026P 500 prices in time. We use content in one sliding windows to make prediction for the next, while there is no overlap between two consecutive windows. The RNN model we are about to build has LSTM cells as basic hidden units. We use values from the very beginning in the first sliding window $W_0$ to the window $W_t$ at time $t$:\n $$ \\begin{aligned} W_0 \u0026= (p_0, p_1, \\dots, p_{w-1}) \\\\ W_1 \u0026= (p_w, p_{w+1}, \\dots, p_{2w-1}) \\\\ \\dots \\\\ W_t \u0026= (p_{tw}, p_{tw+1}, \\dots, p_{(t+1)w-1}) \\end{aligned} $$ to predict the prices in the following window $w_{t+1}$:\n$$ W_{t+1} = (p_{(t+1)w}, p_{(t+1)w+1}, \\dots, p_{(t+2)w-1}) $$\nEssentially we try to learn an approximation function, $f(W_0, W_1, \\dots, W_t) \\approx W_{t+1}$.\nFig. 2 The unrolled version of RNN. Considering how back propagation through time (BPTT) works, we usually train RNN in a “unrolled” version so that we don\u0026rsquo;t have to do propagation computation too far back and save the training complication.\nHere is the explanation on num_steps from Tensorflow\u0026rsquo;s tutorial:\n By design, the output of a recurrent neural network (RNN) depends on arbitrarily distant inputs. Unfortunately, this makes backpropagation computation difficult. In order to make the learning process tractable, it is common practice to create an \u0026ldquo;unrolled\u0026rdquo; version of the network, which contains a fixed number (num_steps) of LSTM inputs and outputs. The model is then trained on this finite approximation of the RNN. This can be implemented by feeding inputs of length num_steps at a time and performing a backward pass after each such input block.\n The sequence of prices are first split into non-overlapped small windows. Each contains input_size numbers and each is considered as one independent input element. Then any num_steps consecutive input elements are grouped into one training input, forming an \u0026ldquo;un-rolled\u0026rdquo; version of RNN for training on Tensorfow. The corresponding label is the input element right after them.\nFor instance, if input_size=3 and num_steps=2, my first few training examples would look like:\n $$ \\begin{aligned} \\text{Input}_1 \u0026= [[p_0, p_1, p_2], [p_3, p_4, p_5]]\\quad\\text{Label}_1 = [p_6, p_7, p_8] \\\\ \\text{Input}_2 \u0026= [[p_3, p_4, p_5], [p_6, p_7, p_8]]\\quad\\text{Label}_2 = [p_9, p_{10}, p_{11}] \\\\ \\text{Input}_3 \u0026= [[p_6, p_7, p_8], [p_9, p_{10}, p_{11}]]\\quad\\text{Label}_3 = [p_{12}, p_{13}, p_{14}] \\end{aligned} $$ Here is the key part for formatting the data:\nseq = [np.array(seq[i * self.input_size: (i + 1) * self.input_size]) for i in range(len(seq) // self.input_size)] # Split into groups of `num_steps` X = np.array([seq[i: i + self.num_steps] for i in range(len(seq) - self.num_steps)]) y = np.array([seq[i + self.num_steps] for i in range(len(seq) - self.num_steps)]) The complete code of data formatting is here.\nTrain / Test Split Since we always want to predict the future, we take the latest 10% of data as the test data.\nNormalization The S\u0026amp;P 500 index increases in time, bringing about the problem that most values in the test set are out of the scale of the train set and thus the model has to predict some numbers it has never seen before. Sadly and unsurprisingly, it does a tragic job. See Fig. 3.\nFig. 3 A very sad example when the RNN model have to predict numbers out of the scale of the training data. To solve the out-of-scale issue, I normalize the prices in each sliding window. The task becomes predicting the relative change rates instead of the absolute values. In a normalized sliding window $W'_t$ at time $t$, all the values are divided by the last unknown price\u0026mdash;the last price in $W_{t-1}$:\n$$ W'_t = (\\frac{p_{tw}}{p_{tw-1}}, \\frac{p_{tw+1}}{p_{tw-1}}, \\dots, \\frac{p_{(t+1)w-1}}{p_{tw-1}}) $$\nHere is a data archive stock-data-lilianweng.tar.gz of S \u0026amp; P 500 stock prices I crawled up to Jul, 2017. Feel free to play with it :)\nModel Construction Definitions lstm_size: number of units in one LSTM layer. num_layers: number of stacked LSTM layers. keep_prob: percentage of cell units to keep in the dropout operation. init_learning_rate: the learning rate to start with. learning_rate_decay: decay ratio in later training epochs. init_epoch: number of epochs using the constant init_learning_rate. max_epoch: total number of epochs in training input_size: size of the sliding window / one training data point batch_size: number of data points to use in one mini-batch. The LSTM model has num_layers stacked LSTM layer(s) and each layer contains lstm_size number of LSTM cells. Then a dropout mask with keep probability keep_prob is applied to the output of every LSTM cell. The goal of dropout is to remove the potential strong dependency on one dimension so as to prevent overfitting.\nThe training requires max_epoch epochs in total; an epoch is a single full pass of all the training data points. In one epoch, the training data points are split into mini-batches of size batch_size. We send one mini-batch to the model for one BPTT learning. The learning rate is set to init_learning_rate during the first init_epoch epochs and then decay by $\\times$ learning_rate_decay during every succeeding epoch.\n# Configuration is wrapped in one object for easy tracking and passing. class RNNConfig(): input_size=1 num_steps=30 lstm_size=128 num_layers=1 keep_prob=0.8 batch_size = 64 init_learning_rate = 0.001 learning_rate_decay = 0.99 init_epoch = 5 max_epoch = 50 config = RNNConfig() Define Graph A tf.Graph is not attached to any real data. It defines the flow of how to process the data and how to run the computation. Later, this graph can be fed with data within a tf.session and at this moment the computation happens for real.\n\u0026mdash; Let\u0026rsquo;s start going through some code \u0026mdash;\n(1) Initialize a new graph first.\nimport tensorflow as tf tf.reset_default_graph() lstm_graph = tf.Graph() (2) How the graph works should be defined within its scope.\nwith lstm_graph.as_default(): (3) Define the data required for computation. Here we need three input variables, all defined as tf.placeholder because we don\u0026rsquo;t know what they are at the graph construction stage.\n inputs: the training data X, a tensor of shape (# data examples, num_steps, input_size); the number of data examples is unknown, so it is None. In our case, it would be batch_size in training session. Check the input format example if confused. targets: the training label y, a tensor of shape (# data examples, input_size). learning_rate: a simple float. # Dimension = ( # number of data examples, # number of input in one computation step, # number of numbers in one input # ) # We don\u0026#39;t know the number of examples beforehand, so it is None. inputs = tf.placeholder(tf.float32, [None, config.num_steps, config.input_size]) targets = tf.placeholder(tf.float32, [None, config.input_size]) learning_rate = tf.placeholder(tf.float32, None) (4) This function returns one LSTMCell with or without dropout operation.\ndef _create_one_cell(): return tf.contrib.rnn.LSTMCell(config.lstm_size, state_is_tuple=True) if config.keep_prob \u0026lt; 1.0: return tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=config.keep_prob) (5) Let\u0026rsquo;s stack the cells into multiple layers if needed. MultiRNNCell helps connect sequentially multiple simple cells to compose one cell.\ncell = tf.contrib.rnn.MultiRNNCell( [_create_one_cell() for _ in range(config.num_layers)], state_is_tuple=True ) if config.num_layers \u0026gt; 1 else _create_one_cell() (6) tf.nn.dynamic_rnn constructs a recurrent neural network specified by cell (RNNCell). It returns a pair of (model outpus, state), where the outputs val is of size (batch_size, num_steps, lstm_size) by default. The state refers to the current state of the LSTM cell, not consumed here.\nval, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32) (7) tf.transpose converts the outputs from the dimension (batch_size, num_steps, lstm_size) to (num_steps, batch_size, lstm_size). Then the last output is picked.\n# Before transpose, val.get_shape() = (batch_size, num_steps, lstm_size) # After transpose, val.get_shape() = (num_steps, batch_size, lstm_size) val = tf.transpose(val, [1, 0, 2]) # last.get_shape() = (batch_size, lstm_size) last = tf.gather(val, int(val.get_shape()[0]) - 1, name=\u0026#34;last_lstm_output\u0026#34;) (8) Define weights and biases between the hidden and output layers.\nweight = tf.Variable(tf.truncated_normal([config.lstm_size, config.input_size])) bias = tf.Variable(tf.constant(0.1, shape=[config.input_size])) prediction = tf.matmul(last, weight) + bias (9) We use mean square error as the loss metric and the RMSPropOptimizer algorithm for gradient descent optimization.\nloss = tf.reduce_mean(tf.square(prediction - targets)) optimizer = tf.train.RMSPropOptimizer(learning_rate) minimize = optimizer.minimize(loss) Start Training Session (1) To start training the graph with real data, we need to start a tf.session first.\nwith tf.Session(graph=lstm_graph) as sess: (2) Initialize the variables as defined.\ntf.global_variables_initializer().run() (0) The learning rates for training epochs should have been precomputed beforehand. The index refers to the epoch index.\nlearning_rates_to_use = [ config.init_learning_rate * ( config.learning_rate_decay ** max(float(i + 1 - config.init_epoch), 0.0) ) for i in range(config.max_epoch)] (3) Each loop below completes one epoch training.\nfor epoch_step in range(config.max_epoch): current_lr = learning_rates_to_use[epoch_step] # Check https://github.com/lilianweng/stock-rnn/blob/master/data_wrapper.py # if you are curious to know what is StockDataSet and how generate_one_epoch() # is implemented. for batch_X, batch_y in stock_dataset.generate_one_epoch(config.batch_size): train_data_feed = { inputs: batch_X, targets: batch_y, learning_rate: current_lr } train_loss, _ = sess.run([loss, minimize], train_data_feed) (4) Don\u0026rsquo;t forget to save your trained model at the end.\nsaver = tf.train.Saver() saver.save(sess, \u0026#34;your_awesome_model_path_and_name\u0026#34;, global_step=max_epoch_step) The complete code is available here.\nUse TensorBoard Building the graph without visualization is like drawing in the dark, very obscure and error-prone. Tensorboard provides easy visualization of the graph structure and the learning process. Check out this hand-on tutorial, only 20 min, but it is very practical and showcases several live demos.\nBrief Summary\n Use with [tf.name_scope](https://www.tensorflow.org/api_docs/python/tf/name_scope)(\u0026quot;your_awesome_module_name\u0026quot;): to wrap elements working on the similar goal together. Many tf.* methods accepts name= argument. Assigning a customized name can make your life much easier when reading the graph. Methods like tf.summary.scalar and tf.summary.histogram help track the values of variables in the graph during iterations. In the training session, define a log file using tf.summary.FileWriter. with tf.Session(graph=lstm_graph) as sess: merged_summary = tf.summary.merge_all() writer = tf.summary.FileWriter(\u0026#34;location_for_keeping_your_log_files\u0026#34;, sess.graph) writer.add_graph(sess.graph) Later, write the training progress and summary results into the file.\n_summary = sess.run([merged_summary], test_data_feed) writer.add_summary(_summary, global_step=epoch_step) # epoch_step in range(config.max_epoch) Fig. 4a The RNN graph built by the example code. The \"train\" module has been \"removed from the main graph\", as it is not a real part of the model during the prediction time. Fig. 4b Click the \"output_layer\" module to expand it and check the structure in details. The full working code is available in github.com/lilianweng/stock-rnn.\nResults I used the following configuration in the experiment.\nnum_layers=1 keep_prob=0.8 batch_size = 64 init_learning_rate = 0.001 learning_rate_decay = 0.99 init_epoch = 5 max_epoch = 100 num_steps=30 (Thanks to Yury for cathcing a bug that I had in the price normalization. Instead of using the last price of the previous time window, I ended up with using the last price in the same window. The following plots have been corrected.)\nOverall predicting the stock prices is not an easy task. Especially after normalization, the price trends look very noisy.\nFig. 5a Predictoin results for the last 200 days in test data. Model is trained with input_size=1 and lstm_size=32. Fig. 5b Predictoin results for the last 200 days in test data. Model is trained with input_size=1 and lstm_size=128. Fig. 5c Predictoin results for the last 200 days in test data. Model is trained with input_size=5, lstm_size=128 and max_epoch=75 (instead of 50). The example code in this tutorial is available in github.com/lilianweng/stock-rnn:scripts.\n(Updated on Sep 14, 2017) The model code has been updated to be wrapped into a class: LstmRNN. The model training can be triggered by main.py, such as:\npython main.py --stock_symbol=SP500 --train --input_size=1 --lstm_size=128 ","permalink":"https://lilianweng.github.io/posts/2017-07-08-stock-rnn-part-1/","summary":"This is a tutorial for how to build a recurrent neural network using Tensorflow to predict stock market prices. The full working code is available in github.com/lilianweng/stock-rnn. If you don\u0026rsquo;t know what is recurrent neural network or LSTM cell, feel free to check my previous post.\n One thing I would like to emphasize that because my motivation for writing this post is more on demonstrating how to build and train an RNN model in Tensorflow and less on solve the stock prediction problem, I didn\u0026rsquo;t try hard on improving the prediction outcomes.","title":"Predict Stock Prices Using RNN: Part 1"},{"content":"(The post was originated from my talk for WiMLDS x Fintech meetup hosted by Affirm.)\nI believe many of you have watched or heard of the games between AlphaGo and professional Go player Lee Sedol in 2016. Lee has the highest rank of nine dan and many world championships. No doubt, he is one of the best Go players in the world, but he lost by 1-4 in this series versus AlphaGo. Before this, Go was considered to be an intractable game for computers to master, as its simple rules lay out an exponential number of variations in the board positions, many more than what in Chess. This event surely highlighted 2016 as a big year for AI. Because of AlphaGo, much attention has been attracted to the progress of AI.\nMeanwhile, many companies are spending resources on pushing the edges of AI applications, that indeed have the potential to change or even revolutionize how we are gonna live. Familiar examples include self-driving cars, chatbots, home assistant devices and many others. One of the secret receipts behind the progress we have had in recent years is deep learning.\nWhy Does Deep Learning Work Now? Deep learning models, in simple words, are large and deep artificial neural nets. A neural network (\u0026ldquo;NN\u0026rdquo;) can be well presented in a directed acyclic graph: the input layer takes in signal vectors; one or multiple hidden layers process the outputs of the previous layer. The initial concept of a neural network can be traced back to more than half a century ago. But why does it work now? Why do people start talking about them all of a sudden?\nFig. 1. A three-layer artificial neural network. (Image source: http://cs231n.github.io/convolutional-networks/#conv) The reason is surprisingly simple:\n We have a lot more data. We have much powerful computers. A large and deep neural network has many more layers + many more nodes in each layer, which results in exponentially many more parameters to tune. Without enough data, we cannot learn parameters efficiently. Without powerful computers, learning would be too slow and insufficient.\nHere is an interesting plot presenting the relationship between the data scale and the model performance, proposed by Andrew Ng in his \u0026ldquo;Nuts and Bolts of Applying Deep Learning\u0026rdquo; talk. On a small dataset, traditional algorithms (Regression, Random Forests, SVM, GBM, etc.) or statistical learning does a great job, but once the data scale goes up to the sky, the large NN outperforms others. Partially because compared to a traditional ML model, a neural network model has many more parameters and has the capability to learn complicated nonlinear patterns. Thus we expect the model to pick the most helpful features by itself without too much expert-involved manual feature engineering.\nFig. 2. The data scale versus the model performance. (Recreated based on: https://youtu.be/F1ka6a13S9I) Deep Learning Models Next, let\u0026rsquo;s go through a few classical deep learning models.\nConvolutional Neural Network Convolutional neural networks, short for \u0026ldquo;CNN\u0026rdquo;, is a type of feed-forward artificial neural networks, in which the connectivity pattern between its neurons is inspired by the organization of the visual cortex system. The primary visual cortex (V1) does edge detection out of the raw visual input from the retina. The secondary visual cortex (V2), also called prestriate cortex, receives the edge features from V1 and extracts simple visual properties such as orientation, spatial frequency, and color. The visual area V4 handles more complicated object attributes. All the processed visual features flow into the final logic unit, inferior temporal gyrus (IT), for object recognition. The shortcut between V1 and V4 inspires a special type of CNN with connections between non-adjacent layers: Residual Net (He, et al. 2016) containing \u0026ldquo;Residual Block\u0026rdquo; which supports some input of one layer to be passed to the component two layers later.\nFig. 3. Illustration of the human visual cortex system. (Image source: Wang \u0026 Raj 2017) Convolution is a mathematical term, here referring to an operation between two matrices. The convolutional layer has a fixed small matrix defined, also called kernel or filter. As the kernel is sliding, or convolving, across the matrix representation of the input image, it is computing the element-wise multiplication of the values in the kernel matrix and the original image values. Specially designed kernels can process images for common purposes like blurring, sharpening, edge detection and many others, fast and efficiently.\nFig. 4. The LeNet architecture consists of two sets of convolutional, activation, and pooling layers, followed by a fully-connected layer, activation, another fully-connected layer, and finally a softmax classifier (Image source: http://deeplearning.net/tutorial/lenet.html) Convolutional and pooling (or \u0026ldquo;sub-sampling\u0026rdquo; in Fig. 4) layers act like the V1, V2 and V4 visual cortex units, responding to feature extraction. The object recognition reasoning happens in the later fully-connected layers which consume the extracted features.\nRecurrent Neural Network A sequence model is usually designed to transform an input sequence into an output sequence that lives in a different domain. Recurrent neural network, short for \u0026ldquo;RNN\u0026rdquo;, is suitable for this purpose and has shown tremendous improvement in problems like handwriting recognition, speech recognition, and machine translation (Sutskever et al. 2011, Liwicki et al. 2007).\nA recurrent neural network model is born with the capability to process long sequential data and to tackle tasks with context spreading in time. The model processes one element in the sequence at one time step. After computation, the newly updated unit state is passed down to the next time step to facilitate the computation of the next element. Imagine the case when an RNN model reads all the Wikipedia articles, character by character, and then it can predict the following words given the context.\nFig. 5. A recurrent neural network with one hidden unit (left) and its unrolling version in time (right). The unrolling version illustrates what happens in time: $s\\_{t-1}$, $s\\_{t}$, and $s\\_{t+1}$ are the same unit with different states at different time steps $t-1$, $t$, and $t+1$. (Image source: LeCun, Bengio, and Hinton, 2015; Fig. 5) However, simple perceptron neurons that linearly combine the current input element and the last unit state may easily lose the long-term dependencies. For example, we start a sentence with \u0026ldquo;Alice is working at \u0026hellip;\u0026rdquo; and later after a whole paragraph, we want to start the next sentence with \u0026ldquo;She\u0026rdquo; or \u0026ldquo;He\u0026rdquo; correctly. If the model forgets the character\u0026rsquo;s name \u0026ldquo;Alice\u0026rdquo;, we can never know. To resolve the issue, researchers created a special neuron with a much more complicated internal structure for memorizing long-term context, named \u0026ldquo;Long-short term memory (LSTM)\u0026quot; cell. It is smart enough to learn for how long it should memorize the old information, when to forget, when to make use of the new data, and how to combine the old memory with new input. This introduction is so well written that I recommend everyone with interest in LSTM to read it. It has been officially promoted in the Tensorflow documentation ;-)\nFig. 6. The structure of a LSTM cell. (Image source: http://colah.github.io/posts/2015-08-Understanding-LSTMs) To demonstrate the power of RNNs, Andrej Karpathy built a character-based language model using RNN with LSTM cells. Without knowing any English vocabulary beforehand, the model could learn the relationship between characters to form words and then the relationship between words to form sentences. It could achieve a decent performance even without a huge set of training data.\nFig. 7. A character-based recurrent neural network model writes like a Shakespeare. (Image source: http://karpathy.github.io/2015/05/21/rnn-effectiveness) RNN: Sequence-to-Sequence Model The sequence-to-sequence model is an extended version of RNN, but its application field is distinguishable enough that I would like to list it in a separated section. Same as RNN, a sequence-to-sequence model operates on sequential data, but particularly it is commonly used to develop chatbots or personal assistants, both generating meaningful response for input questions. A sequence-to-sequence model consists of two RNNs, encoder and decoder. The encoder learns the contextual information from the input words and then hands over the knowledge to the decoder side through a \u0026ldquo;context vector\u0026rdquo; (or \u0026ldquo;thought vector\u0026rdquo;, as shown in Fig 8.). Finally, the decoder consumes the context vector and generates proper responses.\nFig. 8. A sequence-to-sequence model for generating Gmail auto replies. (Image source: https://research.googleblog.com/2015/11/computer-respond-to-this-email.html) Autoencoders Different from the previous models, autoencoders are for unsupervised learning. It is designed to learn a low-dimensional representation of a high-dimensional data set, similar to what Principal Components Analysis (PCA) does. The autoencoder model tries to learn an approximation function $ f(x) \\approx x $ to reproduce the input data. However, it is restricted by a bottleneck layer in the middle with a very small number of nodes. With limited capacity, the model is forced to form a very efficient encoding of the data, that is essentially the low-dimensional code we learned.\nFig. 9. An autoencoder model has a bottleneck layer with only a few neurons. (Image source: Geoffrey Hinton’s Coursera class \"Neural Networks for Machine Learning\" - Week 15) Hinton and Salakhutdinov used autoencoders to compress documents on a variety of topics. As shown in Fig 10, when both PCA and autoencoder were applied to reduce the documents onto two dimensions, autoencoder demonstrated a much better outcome. With the help of autoencoder, we can do efficient data compression to speed up the information retrieval including both documents and images.\nFig. 10. The outputs of PCA (left) and autoencoder (right) when both try to compress documents into two numbers. (Image source: Hinton \u0026 Salakhutdinov 2006) Reinforcement (Deep) Learning Since I started my post with AlphaGo, let us dig a bit more on why AlphaGo worked out. Reinforcement learning (\u0026ldquo;RL\u0026rdquo;) is one of the secrets behind its success. RL is a subfield of machine learning which allows machines and software agents to automatically determine the optimal behavior within a given context, with a goal to maximize the long-term performance measured by a given metric.\nFig. 11. AlphaGo neural network training pipeline and architecture. (Image source: Silver et al. 2016) The AlphaGo system starts with a supervised learning process to train a fast rollout policy and a policy network, relying on the manually curated training dataset of professional players' games. It learns what is the best strategy given the current position on the game board. Then it applies reinforcement learning by setting up self-play games. The RL policy network gets improved when it wins more and more games against previous versions of the policy network. In the self-play stage, AlphaGo becomes stronger and stronger by playing against itself without requiring additional external training data.\nGenerative Adversarial Network Generative adversarial network, short for \u0026ldquo;GAN\u0026rdquo;, is a type of deep generative models. GAN is able to create new examples after learning through the real data. It is consist of two models competing against each other in a zero-sum game framework. The famous deep learning researcher Yann LeCun gave it a super high praise: Generative Adversarial Network is the most interesting idea in the last ten years in machine learning. (See the Quora question: \u0026ldquo;What are some recent and potentially upcoming breakthroughs in deep learning?\u0026quot;)\nFig. 12. The architecture of a generative adversarial network. (Image source: http://www.kdnuggets.com/2017/01/generative-adversarial-networks-hot-topic-machine-learning.html) In the original GAN paper, GAN was proposed to generate meaningful images after learning from real photos. It comprises two independent models: the Generator and the Discriminator. The generator produces fake images and sends the output to the discriminator model. The discriminator works like a judge, as it is optimized for identifying the real photos from the fake ones. The generator model is trying hard to cheat the discriminator while the judge is trying hard not to be cheated. This interesting zero-sum game between these two models motivates both to develop their designed skills and improve their functionalities. Eventually, we take the generator model for producing new images.\nToolkits and Libraries After learning all these models, you may start wondering how you can implement the models and use them for real. Fortunately, we have many open source toolkits and libraries for building deep learning models. Tensorflow is fairly new but has attracted a lot of popularity. It turns out, TensorFlow was the most forked Github project of 2015. All that happened in a period of 2 months after its release in Nov 2015.\nHow to Learn? If you are very new to the field and willing to devote some time to studying deep learning in a more systematic way, I would recommend you to start with the book Deep Learning by Ian Goodfellow, Yoshua Bengio, and Aaron Courville. The Coursera course \u0026ldquo;Neural Networks for Machine Learning\u0026rdquo; by Geoffrey Hinton (Godfather of deep learning!). The content for the course was prepared around 2006, pretty old, but it helps you build up a solid foundation for understanding deep learning models and expedite further exploration.\nMeanwhile, maintain your curiosity and passion. The field is making progress every day. Even classical or widely adopted deep learning models may just have been proposed 1-2 years ago. Reading academic papers can help you learn stuff in depth and keep up with the cutting-edge findings.\nUseful resources Google Scholar: http://scholar.google.com arXiv cs section: https://arxiv.org/list/cs/recent Unsupervised Feature Learning and Deep Learning Tutorial Tensorflow Tutorials Data Science Weekly KDnuggets Tons of blog posts and online tutorials Related Cousera courses awesome-deep-learning-papers Blog posts mentioned Explained Visually: Image Kernels Understanding LSTM Networks The Unreasonable Effectiveness of Recurrent Neural Networks Computer, respond to this email. Interesting blogs worthy of checking www.wildml.com colah.github.io karpathy.github.io blog.openai.com Papers mentioned [1] He, Kaiming, et al. \u0026ldquo;Deep residual learning for image recognition.\u0026quot; Proc. IEEE Conf. on computer vision and pattern recognition. 2016.\n[2] Wang, Haohan, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800, 2017.\n[3] Sutskever, Ilya, James Martens, and Geoffrey E. Hinton. \u0026ldquo;Generating text with recurrent neural networks.\u0026quot; Proc. of the 28th Intl. Conf. on Machine Learning (ICML). 2011.\n[4] Liwicki, Marcus, et al. \u0026ldquo;A novel approach to on-line handwriting recognition based on bidirectional long short-term memory networks.\u0026quot; Proc. of 9th Intl. Conf. on Document Analysis and Recognition. 2007.\n[5] LeCun, Yann, Yoshua Bengio, and Geoffrey Hinton. \u0026ldquo;Deep learning.\u0026quot; Nature 521.7553 (2015): 436-444.\n[6] Hochreiter, Sepp, and Jurgen Schmidhuber. \u0026ldquo;Long short-term memory.\u0026quot; Neural computation 9.8 (1997): 1735-1780.\n[7] Cho, Kyunghyun. et al. \u0026ldquo;Learning phrase representations using RNN encoder-decoder for statistical machine translation.\u0026quot; Proc. Conference on Empirical Methods in Natural Language Processing 1724–1734 (2014).\n[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. \u0026ldquo;Reducing the dimensionality of data with neural networks.\u0026quot; science 313.5786 (2006): 504-507.\n[9] Silver, David, et al. \u0026ldquo;Mastering the game of Go with deep neural networks and tree search.\u0026quot; Nature 529.7587 (2016): 484-489.\n[10] Goodfellow, Ian, et al. \u0026ldquo;Generative adversarial nets.\u0026quot; NIPS, 2014.\n","permalink":"https://lilianweng.github.io/posts/2017-06-21-overview/","summary":"(The post was originated from my talk for WiMLDS x Fintech meetup hosted by Affirm.)\nI believe many of you have watched or heard of the games between AlphaGo and professional Go player Lee Sedol in 2016. Lee has the highest rank of nine dan and many world championships. No doubt, he is one of the best Go players in the world, but he lost by 1-4 in this series versus AlphaGo.","title":"An Overview of Deep Learning for Curious People"},{"content":"","permalink":"https://lilianweng.github.io/faq/","summary":"","title":"FAQ"}] \ No newline at end of file +[{"content":"Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics.\nThis post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models. At its core, the goal of prompt engineering is about alignment and model steerability. Check my previous post on controllable text generation.\n[My personal spicy take] In my opinion, some prompt engineering papers are not worthy 8 pages long, since those tricks can be explained in one or a few sentences and the rest is all about benchmarking. An easy-to-use and shared benchmark infrastructure should be more beneficial to the community. Iterative prompting or external tool use would not be trivial to set up. Also non-trivial to align the whole research community to adopt it.\nUseful Resources\n OpenAI Cookbook has many in-depth examples for how to utilize LLM efficiently. LangChain, a library for combining language models with other components to build applications. Prompt Engineering Guide repo contains a pretty comprehensive collection of education materials on prompt engineering. learnprompting.org PromptPerfect Semantic Kernel Basic Prompting Zero-shot and few-shot learning are two most basic approaches for prompting the model, pioneered by many LLM papers and commonly used for benchmarking LLM performance.\nZero-Shot Zero-shot learning is to simply feed the task text to the model and ask for results.\n(All the sentiment analysis examples are from SST-2)\nText: i'll bet the video game is a lot more fun than the film. Sentiment: Few-shot Few-shot learning presents a set of high-quality demonstrations, each consisting of both input and desired output, on the target task. As the model first sees good examples, it can better understand human intention and criteria for what kinds of answers are wanted. Therefore, few-shot learning often leads to better performance than zero-shot. However, it comes at the cost of more token consumption and may hit the context length limit when input and output text are long.\nText: (lawrence bounces) all over the stage, dancing, running, sweating, mopping his face and generally displaying the wacky talent that brought him fame in the first place. Sentiment: positive Text: despite all evidence to the contrary, this clunker has somehow managed to pose as an actual feature movie, the kind that charges full admission and gets hyped on tv and purports to amuse small children and ostensible adults. Sentiment: negative Text: for the first time in years, de niro digs deep emotionally, perhaps because he's been stirred by the powerful work of his co-stars. Sentiment: positive Text: i'll bet the video game is a lot more fun than the film. Sentiment: Many studies looked into how to construct in-context examples to maximize the performance and observed that choice of prompt format, training examples, and the order of the examples can lead to dramatically different performance, from near random guess to near SoTA.\nZhao et al. (2021) investigated the case of few-shot classification and proposed that several biases with LLM (they use GPT-3 in the experiments) contribute to such high variance: (1) Majority label bias exists if distribution of labels among the examples is unbalanced; (2) Recency bias refers to the tendency where the model may repeat the label at the end; (3) Common token bias indicates that LLM tends to produce common tokens more often than rare tokens. To conquer such bias, they proposed a method to calibrate the label probabilities output by the model to be uniform when the input string is N/A.\nTips for Example Selection Choose examples that are semantically similar to the test example using $k$-NN clustering in the embedding space (Liu et al., 2021)\n To select a diverse and representative set of examples, Su et al. (2022) proposed to use a graph-based approach: (1) First, construct a directed graph $G=(V, E)$ based on the embedding (e.g. by SBERT or other embedding models) cosine similarity between samples, where each node points to its $k$ nearest neighbors; (2) Start with a set of selected samples $\\mathcal{L}=\\emptyset$ and a set of remaining samples $\\mathcal{U}$. Each sample $u \\in \\mathcal{U}$ is scored by $$ \\text{score}(u) = \\sum_{v \\in \\{v \\mid (u, v) \\in E, v\\in \\mathcal{U}\\}} s(v)\\quad\\text{where }s(v)=\\rho^{- \\vert \\{\\ell \\in \\mathcal{L} \\vert (v, \\ell)\\in E \\}\\vert},\\quad\\rho \u0026gt; 1 $$ such that $s(v)$ is low if many of $v$\u0026rsquo;s neighbors are selected and thus the scoring encourages to pick diverse samples.\n Rubin et al. (2022) proposed to train embeddings via contrastive learning specific to one training dataset for in-context learning sample selection. Given each training pair $(x, y)$, the quality of one example $e_i$ (formatted input-output pair) can be measured by a conditioned probability assigned by LM: $\\text{score}(e_i) = P_\\text{LM}(y \\mid e_i, x)$. We can identify other examples with top-$k$ and bottom-$k$ scores as positive and negative sets of candidates for every training pair and use that for contrastive learning.\n Some researchers tried Q-Learning to do sample selection. (Zhang et al. 2022)\n Motivated by uncertainty-based active learning, Diao et al. (2023) suggested to identify examples with high disagreement or entropy among multiple sampling trials. Then annotate these examples to be used in few-shot prompts.\n Tips for Example Ordering A general suggestion is to keep the selection of examples diverse, relevant to the test sample and in random order to avoid majority label bias and recency bias. Increasing model sizes or including more training examples does not reduce variance among different permutations of in-context examples. Same order may work well for one model but badly for another. When the validation set is limited, consider choosing the order such that the model does not produce extremely unbalanced predictions or being overconfident about its predictions. (Lu et al. 2022) Instruction Prompting The purpose of presenting few-shot examples in the prompt is to explain our intent to the model; in other words, describe the task instruction to the model in the form of demonstrations. However, few-shot can be expensive in terms of token usage and restricts the input length due to limited context length. So, why not just give the instruction directly?\nInstructed LM (e.g. InstructGPT, natural instruction) finetunes a pretrained model with high-quality tuples of (task instruction, input, ground truth output) to make LM better understand user intention and follow instruction. RLHF (Reinforcement Learning from Human Feedback) is a common method to do so. The benefit of instruction following style fine-tuning improves the model to be more aligned with human intention and greatly reduces the cost of communication.\nWhen interacting with instruction models, we should describe the task requirement in details, trying to be specific and precise and avoiding say \u0026ldquo;not do something\u0026rdquo; but rather specify what to do.\nPlease label the sentiment towards the movie of the given movie review. The sentiment label should be \u0026quot;positive\u0026quot; or \u0026quot;negative\u0026quot;. Text: i'll bet the video game is a lot more fun than the film. Sentiment: Explaining the desired audience is another smart way to give instructions\n For example to produce education materials for kids, Describe what is quantum physics to a 6-year-old. And safe content, ... in language that is safe for work. In-context instruction learning (Ye et al. 2023) combines few-shot learning with instruction prompting. It incorporates multiple demonstration examples across different tasks in the prompt, each demonstration consisting of instruction, task input and output. Note that their experiments were only on classification tasks and the instruction prompt contains all label options.\nDefinition: Determine the speaker of the dialogue, \u0026quot;agent\u0026quot; or \u0026quot;customer\u0026quot;. Input: I have successfully booked your tickets. Ouput: agent Definition: Determine which category the question asks for, \u0026quot;Quantity\u0026quot; or \u0026quot;Location\u0026quot;. Input: What's the oldest building in US? Ouput: Location Definition: Classify the sentiment of the given movie review, \u0026quot;positive\u0026quot; or \u0026quot;negative\u0026quot;. Input: i'll bet the video game is a lot more fun than the film. Output: Self-Consistency Sampling Self-consistency sampling (Wang et al. 2022a) is to sample multiple outputs with temperature \u0026gt; 0 and then selecting the best one out of these candidates. The criteria for selecting the best candidate can vary from task to task. A general solution is to pick majority vote. For tasks that are easy to validate such as a programming question with unit tests, we can simply run through the interpreter and verify the correctness with unit tests.\nChain-of-Thought (CoT) Chain-of-thought (CoT) prompting (Wei et al. 2022) generates a sequence of short sentences to describe reasoning logics step by step, known as reasoning chains or rationales, to eventually lead to the final answer. The benefit of CoT is more pronounced for complicated reasoning tasks, while using large models (e.g. with more than 50B parameters). Simple tasks only benefit slightly from CoT prompting.\nTypes of CoT prompts Two main types of CoT prompting:\n Few-shot CoT. It is to prompt the model with a few demonstrations, each containing manually written (or model-generated) high-quality reasoning chains. (All the math reasoning examples are from GSM8k)\nQuestion: Tom and Elizabeth have a competition to climb a hill. Elizabeth takes 30 minutes to climb the hill. Tom takes four times as long as Elizabeth does to climb the hill. How many hours does it take Tom to climb up the hill? Answer: It takes Tom 30*4 = \u0026lt;\u0026lt;30*4=120\u0026gt;\u0026gt;120 minutes to climb the hill. It takes Tom 120/60 = \u0026lt;\u0026lt;120/60=2\u0026gt;\u0026gt;2 hours to climb the hill. So the answer is 2. === Question: Jack is a soccer player. He needs to buy two pairs of socks and a pair of soccer shoes. Each pair of socks cost $9.50, and the shoes cost $92. Jack has $40. How much more money does Jack need? Answer: The total cost of two pairs of socks is $9.50 x 2 = $\u0026lt;\u0026lt;9.5*2=19\u0026gt;\u0026gt;19. The total cost of the socks and the shoes is $19 + $92 = $\u0026lt;\u0026lt;19+92=111\u0026gt;\u0026gt;111. Jack need $111 - $40 = $\u0026lt;\u0026lt;111-40=71\u0026gt;\u0026gt;71 more. So the answer is 71. === Question: Marty has 100 centimeters of ribbon that he must cut into 4 equal parts. Each of the cut parts must be divided into 5 equal parts. How long will each final cut be? Answer: Zero-shot CoT. Use natural language statement like Let's think step by step to explicitly encourage the model to first generate reasoning chains and then to prompt with Therefore, the answer is to produce answers (Kojima et al. 2022 ). Or a similar statement Let's work this out it a step by step to be sure we have the right answer (Zhou et al. 2022). Question: Marty has 100 centimeters of ribbon that he must cut into 4 equal parts. Each of the cut parts must be divided into 5 equal parts. How long will each final cut be? Answer: Let's think step by step. Tips and Extensions Self-consistency sampling can improve reasoning accuracy by sampling a number of diverse answers and then taking the majority vote. (Wang et al. 2022a)\n Another approach for ensemble learning is to alter the example order or use model generated rationales to replace human-written ones to introduce randomness during multiple sample trials. Then aggregate model outputs with a majority vote to get final answer. (Wang et al. 2022b)\n If training examples are only associated with true answers (easy to verify!) but no rationales, we can follow the STaR (Self-Taught Reasoner; Zelikman et al. 2022) method : (1) Ask LLM to generate reasoning chains and only keep those leading to correct answers; (2) Then fine-tune the model with generated rationales and repeat the process until convergence. Note that higher temperature is more likely to generate incorrect rationales with correct answers. If training examples do not have ground truth answers, maybe consider using majority votes as the \u0026ldquo;correct\u0026rdquo; answers.\n Prompts with demonstrations of higher reasoning complexity can achieve better performance, where complexity is measured by the number of reasoning steps in the chains. When separating reasoning steps, newline \\n symbol works better than step i, period . or semicolon ;. (Fu et al. 2023)\n Complexity-based consistency is to explicitly prefer complex chains among all the generations by taking majority vote among only top $k$ complex chains. (Fu et al. 2023)\n Later, Shum et al. (2023) found that in their experiments CoT prompts with only complex examples can improve the accuracy of complex questions, but perform poorly in simple questions; evidence shown on GSM8k.\n Changing Q: to Question: is found to be helpful. (Fu et al. 2023)\n Ye \u0026amp; Durrett (2022) found that the benefit of including explanations in the prompt is small to moderate for NLP tasks that involve reasoning over text (i.e. QA and NLI) and the effects vary by models. They observed that explanations are more likely to be nonfactual than be inconsistent (i.e. whether explanation entails prediction). Nonfactual explanations most likely lead to incorrect predictions.\n Self-Ask (Press et al. 2022) is a method to repeatedly prompt the model to ask following-up questions to construct the thought process iteratively. Follow-up questions can be answered by search engine results. Similarly, IRCoT (Interleaving Retrieval CoT; Trivedi et al. 2022) and ReAct (Reason + Act; Yao et al. 2023) combines iterative CoT prompting with queries to Wikipedia APIs to search for relevant entities and content and then add it back into the context.\n Fig. 1. How Self-Ask works with external search queries.(Image source: Press et al. 2022). Automatic Prompt Design Prompt is a sequence of prefix tokens that increase the probability of getting desired output given input. Therefore we can treat them as trainable parameters and optimize them directly on the embedding space via gradient descent, such as AutoPrompt (Shin et al., 2020, Prefix-Tuning (Li \u0026amp; Liang (2021)), P-tuning (Liu et al. 2021) and Prompt-Tuning (Lester et al. 2021). This section in my \u0026ldquo;Controllable Neural Text Generation\u0026rdquo; post has a good coverage of them. The trend from AutoPrompt to Prompt-Tuning is that the setup gets gradually simplified.\nAPE (Automatic Prompt Engineer; Zhou et al. 2022) is a method to search over a pool of model-generated instruction candidates and then filters the candidate set according to a chosen score function to ultimately choose the best candidate with highest score.\n Prompt LLM to generate instruction candidates based on a small set of demonstrations in the form of input-output pairs. E.g. {{Given desired input-output pairs}}\\n\\nThe instruction is.\n Given a dataset of $\\mathcal{D}_\\text{train} = \\{(x, y)\\}$, we would like to find an instruction $\\rho$ such that $\\rho^* = \\arg\\max_\\rho \\mathbb{E}_{(x, y) \\in \\mathcal{D}_\\text{train}} [f(\\rho, x, y)]$, where $f(.)$ is a per-sample score function, such as execution accuracy $\\mathbb{1}[\\text{LM}(.\\vert \\rho, x)=y]$ or log probability: $p_\\text{LM}(y \\mid \\rho, x)$.\n Use an iterative Monte Carlo search method to improve the best candidates by proposing semantically similar variants via prompts like Generate a variation of the following instruction while keeping the semantic meaning.\\n\\nInput: ...\\n\\nOutput:...\n To construct chain-of-thought prompts automatically, Shum et al. (2023) suggested augment-prune-select, a three-step process:\n Augment: Generate multiple pseudo-chains of thought given question using few-shot or zero-shot CoT prompts; Prune: Prune pseudo chains based on whether generated answers match ground truths. Select: Apply a variance-reduced policy gradient strategy to learn the probability distribution over selected examples, while considering the probability distribution over examples as policy and the validation set accuracy as reward. Zhang et al. (2023) instead adopted clustering techniques to sample questions and then generates chains. They observed that LLMs tend to make certain types of mistakes. One type of errors can be similar in the emebedding space and thus get grouped together. By only sampling one or a few from frequent-error clusters, we can prevent too many wrong demonstrations of one error type and collect a diverse set of examples.\n Question clustering: Embed questions and run $k$-means for clustering. Demonstration selection: Select a set of representative questions from each cluster; i.e. one demonstration from one cluster. Samples in each cluster are sorted by distance to the cluster centroid and those closer to the centroid are selected first. Rationale generation: Use zero-shot CoT to generate reasoning chains for selected questions and construct few-shot prompt to run inference. Augmented Language Models A survey on augmented language models by Mialon et al. (2023) has great coverage over multiple categories of language models augmented with reasoning skills and the ability of using external tools. Recommend it.\nRetrieval Often we need to complete tasks that require latest knowledge after the model pretraining time cutoff or internal/private knowledge base. In that case, the model would not know the context if we don’t explicitly provide it in the prompt. Many methods for Open Domain Question Answering depend on first doing retrieval over a knowledge base and then incorporating the retrieved content as part of the prompt. The accuracy of such a process depends on the quality of both retrieval and generation steps.\nLazaridou et al. (2022) studied how to use Google Search for document retrieval to augment LLMs. Given a question $q$, clean text is extracted out of 20 URLs returned by Google, resulting in a set of documents. Because these documents are long, each document is split into paragraphs of 6 sentences, $\\{p\\}$. Paragraphs are ranked by TF-IDF based cosine similarity between evidence paragraphs and the query. Only the most relevant paragraph is used in the prompt to produce an answer $a$.\nFor closed-book QA, each demonstration is formatted as follows to construct few-shot prompts. Swapping the question with the evidence (longer distance between questions and answers) is found to consistently yield lower results across all datasets.\nEvidence: ... Question: ... Answer: ... The answer probability is computed in three ways:\n RAG style, $p(a_i \\mid q) = \\sum_{i=1}^n p_\\text{tf-idf} (p_i \\mid q) \\cdot p_\\text{LM}(a_i \\mid q, p_i)$, where $p_\\text{tf-idf} (p_i \\mid q)$ is the normalized cosine similarities between the TF-IDF passage and question representations. Noisy channel inference, $p(a_i\\mid q) = \\frac{p_\\text{LM}(q \\mid a_i, p_i) \\cdot p_\\text{LM}(a_i \\mid p_i)}{p_\\text{LM}(q \\mid p_i)}$ Product-of-Experts (PoE), combines all probabilities used above in addition to $p_\\text{LM}(p_i \\mid q)$. According to their experiments on generation and classification tasks, among three answer reranking scores - PoE \u0026gt; Noisy channel \u0026gt; RAG. Among individual probabilities, $p_\\text{LM}(a \\mid q, p_i)$ and $p_\\text{LM}(q \\mid p_i, a)$ are found to be most informative. $p_\\text{LM}(q \\mid p_i, a)$ captures how well the question can be explained by LM given evidence paragraph and answer and can reliably be used for reranking answer candidates.\nOne observation with SituatedQA dataset for questions grounded in different dates is that despite LM (pretraining cutoff is year 2020) has access to latest information via Google Search, its performance on post-2020 questions are still a lot worse than on pre-2020 questions. This suggests the existence of some discrepencies or conflicting parametric between contextual information and model internal knowledge.\nInterestingly it is found to be beneficial even with only \u0026ldquo;internal retrieval\u0026rdquo;, that is, to generate knowledge about a topic before answering the question (Liu et al. 2022). First we can use the following template to extract knowledge:\nGenerate some knowledge about the input. Examples: Input: What type of water formation is formed by clouds? Knowledge: Clouds are made of water vapor. Input: {question} Knowledge: And then with model-generated knowledge, prompt the LM further to get the answer.\nProgramming Language Both PAL (Program-aided language models); Gao et al. 2022) and PoT (Program of Thoughts prompting; Chen et al. 2022) ask LLM to generate programming language statements to resolve natural language reasoning problems, hence offloading the solution step to a runtime such as a Python interpreter. Such setup decouples complex computation and reasoning. It relies on a LM with good enough coding skills.\nFig. 2. Comparing CoT and PoT. (Image source: Chen et al. 2022). External APIs TALM (Tool Augmented Language Models; Parisi et al. 2022) is a language model augmented with text-to-text API calls. LM is guided to generate |tool-call and tool input text conditioned on task input text to construct API call requests. When |result shows up, the specified tool API is called and the returned result gets appended to the text sequence. The final output is generated following |output token.\nFig. 3. The format of API calls in TALM. (Image source: Parisi et al. 2022). TALM adopts a self-play approach to iteratively bootstrap the dataset of tool use examples and finetune LM with it. This iterative self-play pipeline mimics a RL process where LM is the policy network and it is trained by policy gradient with a binary reward signal.\nFig. 4. Self-play iterations help boost the model performance.(Image source: Parisi et al. 2022). Toolformer (Schick et al. 2023) is a LM that can use external tools via simple APIs, which is built in a self-supervised manner and only requires a handful of demonstrations for each API. The toolbox of Toolformer includes:\n Calculator to help LM with the lack of precise math skills; Q\u0026amp;A system to help with unfaithful content and hallucination; Search engine to provide up-to-date information after pretraining cut off time; Translation system to improve performance on low resource language; Calendar to make LM be aware of time progression. Fig. 5. Illustration of how to build Toolformer.(Image source: Schick et al. 2023). Toolformer is trained as follows:\n Prompting to annotate potential API calls. Ask a pre-trained LM to annotate a dataset via few-shot learning with API call usage examples. Formatting example:\nFig. 6. How dataset is annotated to do API calls.(Image source: Schick et al. 2023). Each API call is represented as a tuple of (API name, corresponding input), $c=(a_c, i_c)$ and its corresponding result is denoted as $r$. The API call sequences with and without results are labeled as follows, respectively:\n $$ \\begin{aligned} e(c) \u0026= \\langle\\texttt{API}\\rangle a_c(i_c) \\langle\\texttt{/API}\\rangle \\\\ e(c, r) \u0026= \\langle\\texttt{API}\\rangle a_c(i_c) \\to r \\langle\\texttt{/API}\\rangle \\end{aligned} $$ Sample API calls based on the probabilities $p_\\text{LM}(\\langle\\texttt{API}\\rangle \\mid \\text{prompt}(\\mathbf{x}), \\mathbf{x}_{1:i})$ and select top $k$ candidate positions for doing API calls at position $i$ if the probability is larger than a threshold.\n Then we sample potential API calls from the LM given the sequence $[\\text{prompt}(\\mathbf{x}), x_1, \\dots, x_{i-1}, \\langle\\texttt{API}\\rangle]$ as prefix and $\\langle\\texttt{/API}\\rangle$ as suffix.\n Filter annotations based on whether API calls help model predict future tokens. Use a self-supervised loss to decide which API calls are actually helpful.\n Execute each API call $c_i$ to get corresponding result $r_i$.\n Compute weighted cross entropy loss for the LM over tokens $x_i, \\dots, x_n$ when the model is prefixed with the prompt. Two versions are computed, one with API result and the other with empty sequence $\\varepsilon$.\n $$ \\begin{aligned} L^+_i \u0026= L_i(e(c_i, r_i)) \\\\ L^-_i \u0026= \\min(L_i(\\varepsilon), L_i(e(c_i, \\varepsilon))) \\\\ \\end{aligned} $$ Only API calls with $L^-_i - L^+_i$ larger than a threshold are kept, meaning that adding this API call and its results help the model predict future tokens.\n Fine-tune LM on this annotated dataset. The new training sequences are constructed as $\\mathbf{x}^* = x_{1:i-1}, e(c_i, r_i), x_{i:n}$ . The training data is a combination of the original dataset (e.g. a subset of CCNet, as in the paper) and its augmented version.\n At inference time, decoding runs until the model produces \u0026ldquo;$\\to$ \u0026quot; token, indicating that it is expecting response from an API call next.\nToolformer currently does not support tool use in a chain (i.e. using the output of one tool as an input for another tool) or in an interactive way (i.e. adopt API response after human selection). Both are interesting future directions to expand the model for.\nCitation Cited as:\n Weng, Lilian. (Mar 2023). Prompt Engineering. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/.\n Or\n@article{weng2023prompt, title = \u0026quot;Prompt Engineering\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Mar\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/\u0026quot; } References [1] Zhao et al. \u0026ldquo;Calibrate Before Use: Improving Few-shot Performance of Language Models.\u0026quot; ICML 2021\n[2] Liu et al. \u0026ldquo;What Makes Good In-Context Examples for GPT-3?\u0026quot; arXiv preprint arXiv:2101.06804 (2021).\n[3] Lu et al. \u0026ldquo;Fantastically Ordered Prompts and Where to Find Them: Overcoming Few-Shot Prompt Order Sensitivity.\u0026quot; ACL 2022\n[4] Ye et al. \u0026ldquo;In-Context Instruction Learning.\u0026quot; arXiv preprint arXiv:2302.14691 (2023).\n[5] Su et al. \u0026ldquo;Selective annotation makes language models better few-shot learners.\u0026quot; arXiv preprint arXiv:2209.01975 (2022).\n[6] Rubin et al. \u0026ldquo;Learning to retrieve prompts for in-context learning.\u0026quot; NAACL-HLT 2022\n[7] Wei et al. \u0026ldquo;Chain of thought prompting elicits reasoning in large language models.\u0026quot; NeurIPS 2022\n[8] Wang et al. \u0026ldquo;Self-Consistency Improves Chain of Thought Reasoning in Language Models.\u0026quot; ICLR 2023.\n[9] Diao et al. \u0026ldquo;Active Prompting with Chain-of-Thought for Large Language Models.\u0026quot; arXiv preprint arXiv:2302.12246 (2023).\n[10] Zelikman et al. \u0026ldquo;STaR: Bootstrapping Reasoning With Reasoning.\u0026quot; arXiv preprint arXiv:2203.14465 (2022).\n[11] Ye \u0026amp; Durrett. \u0026ldquo;The unreliability of explanations in few-shot in-context learning.\u0026quot; arXiv preprint arXiv:2205.03401 (2022).\n[12] Trivedi et al. \u0026ldquo;Interleaving retrieval with chain-of-thought reasoning for knowledge-intensive multi-step questions.\u0026quot; arXiv preprint arXiv:2212.10509 (2022).\n[13] Press et al. \u0026ldquo;Measuring and narrowing the compositionality gap in language models.\u0026quot; arXiv preprint arXiv:2210.03350 (2022).\n[14] Yao et al. \u0026ldquo;ReAct: Synergizing reasoning and acting in language models.\u0026quot; ICLR 2023.\n[15] Fu et al. \u0026ldquo;Complexity-based prompting for multi-step reasoning.\u0026quot; arXiv preprint arXiv:2210.00720 (2022).\n[16] Wang et al. \u0026ldquo;Rationale-augmented ensembles in language models.\u0026quot; arXiv preprint arXiv:2207.00747 (2022).\n[17] Zhang et al. \u0026ldquo;Automatic chain of thought prompting in large language models.\u0026quot; arXiv preprint arXiv:2210.03493 (2022).\n[18] Shum et al. \u0026ldquo;Automatic Prompt Augmentation and Selection with Chain-of-Thought from Labeled Data.\u0026quot; arXiv preprint arXiv:2302.12822 (2023).\n[19] Zhou et al. \u0026ldquo;Large Language Models Are Human-Level Prompt Engineers.\u0026quot; ICLR 2023.\n[20] Lazaridou et al. \u0026ldquo;Internet augmented language models through few-shot prompting for open-domain question answering.\u0026quot; arXiv preprint arXiv:2203.05115 (2022).\n[21] Chen et al. \u0026ldquo;Program of Thoughts Prompting: Disentangling Computation from Reasoning for Numerical Reasoning Tasks.\u0026quot; arXiv preprint arXiv:2211.12588 (2022).\n[22] Gao et al. \u0026ldquo;PAL: Program-aided language models.\u0026quot; arXiv preprint arXiv:2211.10435 (2022).\n[23] Parisi et al. \u0026ldquo;TALM: Tool Augmented Language Models\u0026rdquo; arXiv preprint arXiv:2205.12255 (2022).\n[24] Schick et al. \u0026ldquo;Toolformer: Language Models Can Teach Themselves to Use Tools.\u0026quot; arXiv preprint arXiv:2302.04761 (2023).\n[25] Mialon et al. \u0026ldquo;Augmented Language Models: a Survey\u0026rdquo; arXiv preprint arXiv:2302.07842 (2023).\n","permalink":"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/","summary":"Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics.\nThis post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models.","title":"Prompt Engineering"},{"content":"Many new Transformer architecture improvements have been proposed since my last post on \u0026ldquo;The Transformer Family\u0026rdquo; about three years ago. Here I did a big refactoring and enrichment of that 2020 post \u0026mdash; restructure the hierarchy of sections and improve many sections with more recent papers. Version 2.0 is a superset of the old version, about twice the length.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size. $h$ The number of heads in multi-head attention layer. $L$ The segment length of input sequence. $N$ The total number of attention layers in the model; not considering MoE. $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$ The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size. $\\mathbf{W}^k \\in \\mathbb{R}^{d \\times d_k}$ The key weight matrix. $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_k}$ The query weight matrix. $\\mathbf{W}^v \\in \\mathbb{R}^{d \\times d_v}$ The value weight matrix. Often we have $d_k = d_v = d$. $\\mathbf{W}^k_i, \\mathbf{W}^q_i \\in \\mathbb{R}^{d \\times d_k/h}; \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ The weight matrices per head. $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ The output weight matrix. $\\mathbf{Q} = \\mathbf{X}\\mathbf{W}^q \\in \\mathbb{R}^{L \\times d_k}$ The query embedding inputs. $\\mathbf{K} = \\mathbf{X}\\mathbf{W}^k \\in \\mathbb{R}^{L \\times d_k}$ The key embedding inputs. $\\mathbf{V} = \\mathbf{X}\\mathbf{W}^v \\in \\mathbb{R}^{L \\times d_v}$ The value embedding inputs. $\\mathbf{q}_i, \\mathbf{k}_i \\in \\mathbb{R}^{d_k}, \\mathbf{v}_i \\in \\mathbb{R}^{d_v}$ Row vectors in query, key, value matrices, $\\mathbf{Q}$, $\\mathbf{K}$ and $\\mathbf{V}$. $S_i$ A collection of key positions for the $i$-th query $\\mathbf{q}_i$ to attend to. $\\mathbf{A} \\in \\mathbb{R}^{L \\times L}$ The self-attention matrix between a input sequence of lenght $L$ and itself. $\\mathbf{A} = \\text{softmax}(\\mathbf{Q}\\mathbf{K}^\\top / \\sqrt{d_k})$. $a_{ij} \\in \\mathbf{A}$ The scalar attention score between query $\\mathbf{q}_i$ and key $\\mathbf{k}_j$. $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ position encoding matrix, where the $i$-th row $\\mathbf{p}_i$ is the positional encoding for input $\\mathbf{x}_i$. Transformer Basics The Transformer (which will be referred to as \u0026ldquo;vanilla Transformer\u0026rdquo; to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later simplified Transformer was shown to achieve great performance in language modeling tasks, like in encoder-only BERT or decoder-only GPT.\nAttention and Self-Attention Attention is a mechanism in neural network that a model can learn to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.\nSelf-attention is a type of attention mechanism where the model makes prediction for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.\nThere are various forms of attention / self-attention, Transformer (Vaswani et al., 2017) relies on the scaled dot-product attention: given a query matrix $\\mathbf{Q}$, a key matrix $\\mathbf{K}$ and a value matrix $\\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:\n $$ \\text{attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q} {\\mathbf{K}}^\\top}{\\sqrt{d_k}})\\mathbf{V} $$ And for a query and a key vector $\\mathbf{q}_i, \\mathbf{k}_j \\in \\mathbb{R}^d$ (row vectors in query and key matrices), we have a scalar score:\n $$ a_{ij} = \\text{softmax}(\\frac{\\mathbf{q}_i {\\mathbf{k}_j}^\\top}{\\sqrt{d_k}}) = \\frac{\\exp(\\mathbf{q}_i {\\mathbf{k}_j}^\\top)}{ \\sqrt{d_k} \\sum_{r \\in \\mathcal{S}_i} \\exp(\\mathbf{q}_i {\\mathbf{k}_r}^\\top) } $$ where $\\mathcal{S}_i$ is a collection of key positions for the $i$-th query to attend to.\nSee my old post for other types of attention if interested.\nMulti-Head Self-Attention The multi-head self-attention module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.\n $$ \\begin{aligned} \\text{MultiHeadAttn}(\\mathbf{X}_q, \\mathbf{X}_k, \\mathbf{X}_v) \u0026= [\\text{head}_1; \\dots; \\text{head}_h] \\mathbf{W}^o \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{X}_v\\mathbf{W}^v_i) \\end{aligned} $$ where $[.;.]$ is a concatenation operation. $\\mathbf{W}^q_i, \\mathbf{W}^k_i \\in \\mathbb{R}^{d \\times d_k/h}, \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ are weight matrices to map input embeddings of size $L \\times d$ into query, key and value matrices. And $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ is the output linear transformation. All the weights should be learned during training.\nFig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in Vaswani, et al., 2017) Encoder-Decoder Architecture The encoder generates an attention-based representation with capability to locate a specific piece of information from a large context. It consists of a stack of 6 identity modules, each containing two submodules, a multi-head self-attention layer and a point-wise fully connected feed-forward network. By point-wise, it means that it applies the same linear transformation (with same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.\nThe function of Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future.\nFig. 2. The architecture of the vanilla Transformer model. (Image source: Figure 17) Positional Encoding Because self-attention operation is permutation invariant, it is important to use proper positional encoding to provide order information to the model. The positional encoding $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ has the same dimension as the input embedding, so it can be added on the input directly. The vanilla Transformer considered two types of encodings:\nSinusoidal Positional Encoding Sinusoidal positional encoding is defined as follows, given the token position $i=1,\\dots,L$ and the dimension $\\delta=1,\\dots,d$:\n $$ \\text{PE}(i,\\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ In this way each dimension of the positional encoding corresponds to a sinusoid of different wavelengths in different dimensions, from $2\\pi$ to $10000 \\cdot 2\\pi$.\nFig. 3. Sinusoidal positional encoding with $L=32$ and $d=128$. The value is between -1 (black) and 1 (white) and the value 0 is in gray. Learned Positional Encoding Learned positional encoding assigns each element with a learned column vector which encodes its absolute position (Gehring, et al. 2017) and furthermroe this encoding can be learned differently per layer (Al-Rfou et al. 2018).\nRelative Position Encoding Shaw et al. (2018)) incorporated relative positional information into $\\mathbf{W}^k$ and $\\mathbf{W}^v$. Maximum relative position is clipped to a maximum absolute value of $k$ and this clipping operation enables the model to generalize to unseen sequence lengths. Therefore, $2k + 1$ unique edge labels are considered and let us denote $\\mathbf{P}^k, \\mathbf{P}^v \\in \\mathbb{R}^{2k+1}$ as learnable relative position representations.\n $$ A_{ij}^k = P^k_{\\text{clip}(j - i, k)} \\quad A_{ij}^v = P^v_{\\text{clip}(j - i, k)} \\quad \\text{where }\\text{clip}(x, k) = \\text{clip}(x, -k, k) $$ Transformer-XL (Dai et al., 2019) proposed a type of relative positional encoding based on reparametrization of dot-product of keys and queries. To keep the positional information flow coherently across segments, Transformer-XL encodes the relative position instead, as it could be sufficient enough to know the position offset for making good predictions, i.e. $i-j$, between one key vector $\\mathbf{k}_{\\tau, j}$ and its query $\\mathbf{q}_{\\tau, i}$.\nIf omitting the scalar $1/\\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:\n $$ \\begin{aligned} a_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top = (\\mathbf{x}_i + \\mathbf{p}_i)\\mathbf{W}^q ((\\mathbf{x}_j + \\mathbf{p}_j)\\mathbf{W}^k)^\\top \\\\ \u0026= \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top \\end{aligned} $$ Transformer-XL reparameterizes the above four terms as follows:\n $$ a_{ij}^\\text{rel} = \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{content-based addressing} + \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{content-dependent positional bias} + \\underbrace{ \\color{red}{\\mathbf{u}} \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{global content bias} + \\underbrace{ \\color{red}{\\mathbf{v}} \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{global positional bias} $$ Replace $\\mathbf{p}_j$ with relative positional encoding $\\mathbf{r}_{i-j} \\in \\mathbf{R}^{d}$; Replace $\\mathbf{p}_i\\mathbf{W}^q$ with two trainable parameters $\\mathbf{u}$ (for content) and $\\mathbf{v}$ (for location) in two different terms; Split $\\mathbf{W}^k$ into two matrices, $\\mathbf{W}^k_E$ for content information and $\\mathbf{W}^k_R$ for location information. Rotary Position Embedding Rotary position embedding (RoPE; Su et al. 2021) encodes the absolution position with a rotation matrix and multiplies key and value matrices of every attention layer with it to inject relative positional information at every layer.\nWhen encoding relative positional information into the inner product of the $i$-th key and the $j$-th query, we would like to formulate the function in a way that the inner product is only about the relative position $i-j$. Rotary Position Embedding (RoPE) makes use of the rotation operation in Euclidean space and frames the relative position embedding as simply rotating feature matrix by an angle proportional to its position index.\nGiven a vector $\\mathbf{z}$, if we want to rotate it counterclockwise by $\\theta$, we can multiply it by a rotation matrix to get $R\\mathbf{z}$ where the rotation matrix $R$ is defined as:\n $$ R = \\begin{bmatrix} \\cos\\theta \u0026 -\\sin\\theta \\\\ \\sin\\theta \u0026 \\cos\\theta \\end{bmatrix} $$ When generalizing to higher dimensional space, RoPE divide the $d$-dimensional space into $d/2$ subspaces and constructs a rotation matrix $R$ of size $d \\times d$ for token at position $i$:\n $$ R^d_{\\Theta, i} = \\begin{bmatrix} \\cos i\\theta_1 \u0026 -\\sin i\\theta_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \u0026 0 \\\\ \\sin i\\theta_1 \u0026 \\cos i\\theta_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \u0026 0 \\\\ 0 \u0026 0 \u0026 \\cos i\\theta_2 \u0026 -\\sin i\\theta_2 \u0026 \\dots \u0026 0 \u0026 0 \\\\ 0 \u0026 0 \u0026 \\sin i\\theta_1 \u0026 \\cos i\\theta_1 \u0026 \\dots \u0026 0 \u0026 0 \\\\ \\vdots \u0026 \\vdots \u0026 \\vdots \u0026 \\vdots \u0026 \\ddots \u0026 \\vdots \u0026 \\vdots \\\\ 0 \u0026 0 \u0026 0 \u0026 0 \u0026 \\dots \u0026 \\cos i\\theta_{d/2} \u0026 -\\sin i\\theta_{d/2} \\\\ 0 \u0026 0 \u0026 0 \u0026 0 \u0026 \\dots \u0026 \\sin i\\theta_{d/2} \u0026 \\cos i\\theta_{d/2} \\\\ \\end{bmatrix} $$ where in the paper we have $\\Theta = {\\theta_i = 10000^{-2(i−1)/d}, i \\in [1, 2, \u0026hellip;, d/2]}$. Note that this is essentially equivalent to sinusoidal positional encoding but formulated as a rotation matrix.\nThen both key and query matrices incorporates the positional information by multiplying with this rotation matrix:\n $$ \\begin{aligned} \u0026 \\mathbf{q}_i^\\top \\mathbf{k}_j = (R^d_{\\Theta, i} \\mathbf{W}^q\\mathbf{x}_i)^\\top (R^d_{\\Theta, j} \\mathbf{W}^k\\mathbf{x}_j) = \\mathbf{x}_i^\\top\\mathbf{W}^q R^d_{\\Theta, j-i}\\mathbf{W}^k\\mathbf{x}_j \\\\ \u0026 \\text{ where } R^d_{\\Theta, j-i} = (R^d_{\\Theta, i})^\\top R^d_{\\Theta, j} \\end{aligned} $$ Fig. 4. Visual illustration of how rotary position embedding is implemented.(Image source: Su et al., 2021) Longer Context The length of an input sequence for transformer models at inference time is upper-bounded by the context length used for training. Naively increasing context length leads to high consumption in both time ($\\mathcal{O}(L^2d)$) and memory ($\\mathcal{O}(L^2)$) and may not be supported due to hardware constraints.\nThis section introduces several improvements in transformer architecture to better support long context at inference; E.g. using additional memory, design for better context extrapolation, or recurrency mechanism.\nContext Memory The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments. This context segmentation causes several issues:\n The model cannot capture very long term dependencies. It is hard to predict the first few tokens in each segment given no or thin context. The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens. Transformer-XL (Dai et al., 2019; \u0026ldquo;XL\u0026rdquo; means \u0026ldquo;extra long\u0026rdquo;) modifies the architecture to reuse hidden states between segments with an additional memory. The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.\nFig. 5. A comparison between the training phrase of vanilla Transformer \u0026 Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019). Let\u0026rsquo;s label the hidden state of the $n$-th layer for the $(\\tau + 1)$-th segment in the model as $\\mathbf{h}_{\\tau+1}^{(n)} \\in \\mathbb{R}^{L \\times d}$. In addition to the hidden state of the last layer for the same segment $\\mathbf{h}_{\\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment $\\mathbf{h}_{\\tau}^{(n)}$. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.\n $$ \\begin{aligned} \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \u0026= [\\text{stop-gradient}(\\mathbf{h}_{\\tau}^{(n-1)}) \\circ \\mathbf{h}_{\\tau+1}^{(n-1)}] \\\\ \\mathbf{Q}_{\\tau+1}^{(n)} \u0026= \\mathbf{h}_{\\tau+1}^{(n-1)}\\mathbf{W}^q \\\\ \\mathbf{K}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^k \\\\ \\mathbf{V}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^v \\\\ \\mathbf{h}_{\\tau+1}^{(n)} \u0026= \\text{transformer-layer}(\\mathbf{Q}_{\\tau+1}^{(n)}, \\mathbf{K}_{\\tau+1}^{(n)}, \\mathbf{V}_{\\tau+1}^{(n)}) \\end{aligned} $$ Note that both keys and values rely on extended hidden states, while queries only consume hidden states at the current step. The concatenation operation $[. \\circ .]$ is along the sequence length dimension. And Transformer-XL needs to use relative positional encoding because previous and current segments would be assigned with the same encoding if we encode absolute positions, which is undesired.\nCompressive Transformer (Rae et al. 2019) extends Transformer-XL by compressing past memories to support longer sequences. It explicitly adds memory slots of size $m_m$ per layer for storing past activations of this layer to preserve long context. When some past activations become old enough, they are compressed and saved in an additional compressed memory of size $m_{cm}$ per layer.\nFig. 6. Compressive transformer maintains two types of memory slots, memory and compressed memory, to support long context. (Image source: Rae et al. 2019). Both memory and compressed memory are FIFO queues. Given the model context length $L$, the compression function of compression rate $c$ is defined as $f_c: \\mathbb{R}^{L \\times d} \\to \\mathbb{R}^{[\\frac{L}{c}] \\times d}$, mapping $L$ oldest activations to $[\\frac{L}{c}]$ compressed memory elements. There are several choices of compression functions:\n Max/mean pooling of kernel and stride size $c$; 1D convolution with kernel and stride size $c$ (need to learn additional parameters); Dilated convolution (need to learn additional parameters). In their experiments, convolution compression works out the best on EnWik8 dataset; Most used memories. Compressive transformer has two additional training losses:\n Auto-encoding loss (lossless compression objective) measures how well we can reconstruct the original memories from compressed memories\n $$ \\mathcal{L}_{ac} = \\| \\textbf{old_mem}^{(i)} - g(\\textbf{new_cm}^{(i)}) \\|_2 $$ where $g: \\mathbb{R}^{[\\frac{L}{c}] \\times d} \\to \\mathbb{R}^{L \\times d}$ reverses the compression function $f$. Attention-reconstruction loss (lossy objective) reconstructs content-based attention over memory vs compressed memory and minimize the difference:\n $$ \\mathcal{L}_{ar} = \\|\\text{attn}(\\mathbf{h}^{(i)}, \\textbf{old_mem}^{(i)}) − \\text{attn}(\\mathbf{h}^{(i)}, \\textbf{new_cm}^{(i)})\\|_2 $$ Transformer-XL with a memory of size $m$ has a maximum temporal range of $m \\times N$, where $N$ is the number of layers in the model, and attention cost $\\mathcal{O}(L^2 + Lm)$. In comparison, compressed transformer has a temporal range of $(m_m + c \\cdot m_{cm}) \\times N$ and attention cost $\\mathcal{O}(L^2 + L(m_m + m_{cm}))$. A larger compression rate $c$ gives better tradeoff between temporal range length and attention cost.\nAttention weights, from oldest to newest, are stored in three locations: compressed memory → memory → causally masked sequence. In the experiments, they observed an increase in attention weights from oldest activations stored in the regular memory, to activations stored in the compressed memory, implying that the network is learning to preserve salient information.\nFig. 7. Attention weights with one standard deviation as error bars versus memory positions, from oldest (left) to newest (right). (Image source: Rae et al. 2019). Non-Differentiable External Memory $k$NN-LM (Khandelwal et al. 2020) enhances a pretrained LM with a separate $k$NN model by linearly interpolating the next token probabilities predicted by both models. The $k$NN model is built upon an external key-value store which can store any large pre-training dataset or OOD new dataset. This datastore is preprocessed to save a large number of pairs, (LM embedding representation of context, next token) and the nearest neighbor retrieval happens in the LM embedding space. Because the datastore can be gigantic, we need to rely on libraries for fast dense vector search such as FAISS or ScaNN. The indexing process only happens once and parallelism is easy to implement at inference time.\nAt inference time, the next token probability is a weighted sum of two predictions:\n $$ \\begin{aligned} p(y \\vert \\mathbf{x}) \u0026= \\lambda \\; p_\\text{kNN}(y \\vert \\mathbf{x}) + (1- \\lambda) \\; p_\\text{LM}(y \\vert \\mathbf{x}) \\\\ p_\\text{kNN}(y \\vert \\mathbf{x}) \u0026\\propto \\sum_{(k_i, w_i) \\in \\mathcal{N}} \\mathbb{1}[y = w_i] \\exp(-d(k_i, f(\\mathbf{x}))) \\end{aligned} $$ where $\\mathcal{N}$ contains a set of nearest neighbor data points retrieved by $k$NN; $d(., .)$ is a distance function such as L2 distance.\nAccording to the experiments, larger datastore size or larger $k$ is correlated with better perplexity. The weighting scalar $\\lambda$ should be tuned, but in general it is expected to be larger for out-of-domain data compared to in-domain data and larger datastore can afford a larger $\\lambda$.\nSPALM (Adaptive semiparametric language models; Yogatama et al. 2021) incorporates both (1) Transformer-XL style memory for hidden states from external context as short-term memory and (2) $k$NN-LM style key-value store as long memory.\nFig. 8. Illustration of how SPALM combines context memory of past hidden states (short term memory) with an external key-value datastore (long term memory) to support longer context. (Image source: Yogatama et al. 2021). SPALM runs $k$NN search to fetch $k$ tokens with most relevant context. For each token we can get the same embedding representation provided by a pretrained LM, denoted as $\\{\\mathbf{y}_i\\}_{i=1}^k$. The gating mechanism first aggregates the retrieved token embeddings with a simple attention layer using $\\mathbf{h}^R_t$ (the hidden state for token $x_t$ at layer $R$) as a query and then learns a gating parameter $\\mathbf{g}_t$ to balance between local information $\\mathbf{h}^R_t$ and long-term information $\\mathbf{m}_t$.\n $$ \\begin{aligned} \\mathbf{m}_t \u0026= \\sum_{i=1}^k \\frac{\\exp(\\mathbf{y}_i^\\top \\mathbf{h}^R_t)}{\\sum_{j=1}^k \\exp(\\mathbf{y}_j^\\top \\mathbf{h}^R_t)} \\cdot \\mathbf{y}_i \\\\ \\mathbf{g}_t \u0026= \\sigma(\\mathbf{w}_g^\\top \\mathbf{h}_t^R) \\\\ \\mathbf{z}_t \u0026= (1 - \\mathbf{g}_t) \\odot \\mathbf{m}_t + \\mathbf{g}_t \\odot \\mathbf{h}^R_t \\\\ p(x_{t+1}\\mid \\mathbf{x}_{\\leq t}) \u0026= \\text{softmax}(\\mathbf{z}_t; \\mathbf{W}) \\end{aligned} $$ where $\\mathbf{w}_g$ is a parameter vector to learn; $\\sigma(.)$ is sigmoid; $\\mathbf{W}$ is the word embedding matrix shared between both input and output tokens. Different from $k$NN-LM, they didn\u0026rsquo;t find the nearest neighbor distance to be helpful in the aggregation of retrieved tokens.\nDuring training, the key representations in the long-term memory stay constant, produced by a pretrained LM, but the value encoder, aka the word embedding matrix, gets updated.\nMemorizing Transformer (Wu et al. 2022) adds a $k$NN-augmented attention layer near the top stack of a decoder-only Transformer. This special layer maintains a Transformer-XL style FIFO cache of past key-value pairs.\nThe same QKV values are used for both local attention and $k$NN mechanisms. The $k$NN lookup returns top-$k$ (key, value) pairs for each query in the input sequence and then they are processed through the self-attention stack to compute a weighted average of retrieved values. Two types of attention are combined with a learnable per-head gating parameter. To prevent large distributional shifts in value magnitude, both keys and values in the cache are normalized.\nWhat they found during experiments with Memorizing Transformer:\n It is observed in some experiments that training models with a small memory and then finetuned with a larger memory works better than training with a large memory from scratch. The smaller Memorizing Transformer with just 8k tokens in memory can match the perplexity of a larger vanilla Transformer with 5X more trainable parameters. Increasing the size of external memory provided consistent gains up to a size of 262K. A non-memory transformer can be finetuned to use memory. Fig. 9. Fine-tuning a vanilla Transformer with a key-value memory can achieve similar performance as training a memorizing transformer from scratch. (Image source: Wu et al. 2022). Distance-Enhanced Attention Scores Distance Aware Transformer(DA-Transformer; Wu, et al. 2021) and Attention with Linear Biases (ALiBi; Press et al. 2022) are motivated by similar ideas \u0026mdash; in order to encourage the model to extrapolate over longer context than what the model is trained on, we can explicitly attach the positional information to every pair of attention score based on the distance between key and query tokens.\nNote that the default positional encoding in vanilla Transformer only adds positional information to the input sequence, while later improved encoding mechanisms alter attention scores of every layer, such as rotary position embedding, and they take on form very similar to distance enhanced attention scores.\nDA-Transformer (Wu, et al. 2021) multiplies attention scores at each layer by a learnable bias that is formulated as a function of the distance between key and query. Different attention heads use different parameters to distinguish diverse preferences to short-term vs long-term context. Given two positions, $i, j$, DA-Transformer uses the following weighting function to alter the self-attention score:\n $$ \\begin{aligned} \\mathbf{R}^{(i)} \u0026= \\alpha_i \\mathbf{R} \\quad \\text{where }R_{ij} = \\vert i-j \\vert\\\\ f(\\mathbf{R}^{(i)}; \\beta_i) \u0026= \\frac{1 + \\exp(\\beta_i)}{1 + \\exp(\\beta_i - \\mathbf{R}^{(i)})} \\\\ \\text{attn}(\\mathbf{Q}^{(i)}, \\mathbf{K}^{(i)}, \\mathbf{V}^{(i)}) \u0026= \\text{row-softmax}\\Big(\\frac{\\text{ReLU}(\\mathbf{Q}^{(i)}\\mathbf{K}^{(i)\\top})f(\\mathbf{R}^{(i)})}{\\sqrt{d}}\\Big) \\mathbf{V}^{(i)} \\end{aligned} $$ where $\\alpha_i$ is a learnable parameters to weight relative distance differently per head where the head is indexed by superscript $^{(i)}$; $\\beta_i$ is a learnable parameter to control the upper bound and ascending slope wrt the distance for the $i$-th attention head. The weighting function $f(.)$ is designed in a way that: (1) $f(0)=1$; (2) $f(\\mathbf{R}^{(i)}) = 0$ when $\\mathbf{R}^{(i)} \\to -\\infty$; (3) $f(\\mathbf{R}^{(i)})$ is bounded when $\\mathbf{R}^{(i)} \\to +\\infty$; (4) the scale is tunable; (5) and the function is monotonic. The extra time complexity brought by $f(\\mathbf{R}^{(i)})$ is $\\mathcal{O}(L^2)$ and it is small relative to the self attention time complexity $\\mathcal{O}(L^2 d)$. The extra memory consumption is minimal, ~$\\mathcal{O}(2h)$.\nInstead of multipliers, ALiBi (Press et al. 2022) adds a constant bias term on query-key attention scores, proportional to pairwise distances. The bias introduces a strong recency preference and penalizes keys that are too far away. The penalties are increased at different rates within different heads. $$ \\text{softmax}(\\mathbf{q}_i \\mathbf{K}^\\top + \\alpha_i \\cdot [0, -1, -2, \\dots, -(i-1)]) $$ where $\\alpha_i$ is a head-specific weighting scalar. Different from DA-transformer, $\\alpha_i$ is not learned but fixed as a geometric sequence; for example, for 8 heads, ${\\alpha_i} = {\\frac{1}{2}, \\frac{1}{2^2}, \\dots, \\frac{1}{2^8}}$. The overall idea is very much similar to what relative positional encoding aims to solve.\nFig. 10. Illustration of how ALiBi enhances attention scores with a positional bias term. (Image source: Press et al. 2021). With ALiBi, Press et al. (2022) trained a 1.3B model on context length 1024 during training and extrapolated to 2046 at inference time.\nFig. 11. Extrapolation experiments for running inference with Transformers of different configs, including sinusoidal positional encoding, rotary positional encoding, simplified relative positional encoding in T5 and ALiBi. All models were trained with small context length but inference ran for much longer context. (Image source: Press et al. 2021). Make it Recurrent Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN. Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, an Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.\nOn a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.\nFig. 12. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019). Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\\mathbf{h}^t \\in \\mathbb{R}^{L \\times d}$ at step $t$ for an adjustable number of steps. At step 0, $\\mathbf{h}^0$ is initialized to be same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.\n $$ \\begin{aligned} \\mathbf{A}^t \u0026= \\text{LayerNorm}(\\mathbf{h}^{t-1} + \\text{MultiHeadAttention}(\\mathbf{h}^{t-1} + \\mathbf{P}^t) \\\\ \\mathbf{h}^t \u0026= \\text{LayerNorm}(\\mathbf{A}^{t-1} + \\text{Transition}(\\mathbf{A}^t)) \\end{aligned} $$ where $\\text{Transition}(.)$ is either a separable convolution or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of $\\mathbf{A}^t$ individually) affine transformation + one ReLU.\nThe positional encoding $\\mathbf{P}^t$ uses sinusoidal position signal but with an additional time dimension:\n $$ \\text{PE}(i, t, \\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\sin(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\cos(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ Fig. 13. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation $\\mathbf{h}^T$. (Image source: Figure 2 in Dehghani, et al. 2019) In the adaptive version of Universal Transformer, the number of recurrent steps $T$ is dynamically determined by ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.\nAdaptive Modeling Adaptive modeling refers to a mechanism that can adjust the amount of computation according to different inputs. For example, some tokens may only need local information and thus demand a shorter attention span; Or some tokens are relatively easier to predict and do not need to be processed through the entire attention stack.\nAdaptive Attention Span One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further sometime than others; or one attention head may had different attention pattern from the other. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.\nThis is the motivation for Adaptive Attention Span. Sukhbaatar et al (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 14) and thus the optimal span would be trained separately per head.\nFig. 14. Two attention heads in the same model, A \u0026 B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B look further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019) Given the $i$-th token, we need to compute the attention weights between this token and other keys within its attention span of size $s$:\n $$ \\begin{aligned} e_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top \\\\ a_{ij} \u0026= \\text{softmax}(e_{ij}) = \\frac{\\exp(e_{ij})}{\\sum_{r=i-s}^{i-1} \\exp(e_{ir})} \\\\ \\mathbf{y}_i \u0026= \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{v}_r = \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{x}_r\\mathbf{W}^v \\end{aligned} $$ A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. $m_z$ is parameterized by $z \\in [0, s]$ and $z$ is to be learned:\n $$ m_z(x) = \\text{clip}(\\frac{1}{R}(R+z-x), 0, 1) $$ where $R$ is a hyper-parameter which defines the softness of $m_z$.\nFig. 15. The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.) The soft mask function is applied to the softmax elements in the attention weights:\n $$ a_{ij} = \\frac{m_z(i-j)\\exp(s_{ij})}{\\sum_{r=i-s}^{i-1}m_z(i-r) \\exp(s_{ir})} $$ In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i=1, \\dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\\sum_{i=1}^h z^{(i)}$.\nUsing Adaptive Computation Time, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S \\sigma(\\mathbf{v} \\cdot \\mathbf{x}_t +b)$, where the vector $\\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.\nIn the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.\nDepth-Adaptive Transformer At inference time, it is natural to assume that some tokens are easier to predict and thus do not require as much computation as others. Therefore we may only process its prediction through a limited number of layers to achieve a good balance between speed and performance.\nBoth Depth-Adaptive Transformer (Elabyad et al. 2020) and Confident Adaptive Language Model (CALM; Schuster et al. 2022) are motivated by this idea and learn to predict optimal numbers of layers needed for different input tokens.\nDepth-adaptive transformer (Elabyad et al. 2020) attaches an output classifier to every layer to produce exit predictions based on activations of that layer. The classifier weight matrices can be different per layer or shared across layers. During training, the model sample different sequences of exits such that the model is optimized with hidden states of different layers. The learning objective incorporates likelihood probabilities predicted at different layers, $n=1, \\dots, N$:\n $$ \\text{LL}^n_t = \\log p(y_t \\vert \\mathbf{h}^n_{t-1}) \\quad \\text{LL}^n = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} LL^n_t $$ Adaptive depth classifiers outputs a parametric distribution $q_t$. It is trained with cross entropy loss against an oracle distribution $q^*_t$. The paper explored three confiurations for how to learn such a classifier $q_t$.\nFig. 16. Illustration of three types of adaptive depth classifiers. (Image source: Elabyad et al. 2020). Sequence-specific depth classifier: All tokens of the same sequence share the same exit block. It depends on the average of the encoder representation of the sequence. Given an input sequence $\\mathbf{x}$ of length $L$, the classifier takes $\\bar{\\mathbf{x}} = \\frac{1}{L} \\sum_{t=1}^L \\mathbf{x}_t$ as input and outputs a multinomial distribution of $N$ dimensions, corresponding to $N$ layers.\n $$ \\begin{aligned} q(n \\vert \\mathbf{x}) \u0026=\\text{softmax}(\\mathbf{W}_n \\bar{\\mathbf{x}} + b_n) \\in \\mathbb{R}^N \\\\ q_\\text{lik}^*(\\mathbf{x}, \\mathbf{y}) \u0026= \\delta(\\arg\\max_n \\text{LL}^n - \\lambda n) \\\\ \\text{or }q_\\text{corr}^*(\\mathbf{x}, \\mathbf{y}) \u0026= \\delta(\\arg\\max_n C^n - \\lambda n) \\text{ where }C^n = \\vert\\{t \\vert y_t = \\arg\\max_y p(y \\vert \\mathbf{h}^n_{t-1})\\}\\vert \\\\ \\end{aligned} $$ where $\\delta$ is dirac delta (unit impulse) function and $-\\lambda n$ is a regularization term to encourage lower layer exits. The ground truth $q^*$ can be prepared in two way, based on maximum likelihood $q_\\text{lik}^*$ or correctness $q_\\text{corr}^*$. \n Token-specific depth classifier (multinomial): Each token is decoded with different exit block, predicted conditioned on the first decoder hidden state $\\mathbf{h}^1_t$:\n $$ q_t(n \\vert \\mathbf{x}, \\mathbf{y}_{ Token-specific depth classifier (geometric-like): A binary exit prediction distribution is made per layer per token, $\\mathcal{X}^n_t$. The RBF kernel $\\kappa(t, t’) = \\exp(\\frac{\\vert t - t’ \\vert^2}{\\sigma})$ is used to smooth the predictions to incorporate the impact of current decision on future time steps.\n $$ \\begin{aligned} \\mathcal{X}^n_t \u0026= \\text{sigmoid}(\\mathbf{w}_n^\\top \\mathbf{h}^n_t + b_n)\\quad \\forall n \\in [1, \\dots, N-1] \\\\ q_t(n \\vert \\mathbf{x}, \\mathbf{y}_{ At inference time, the confidence threshold for making an exit decision needs to be calibrated. Depth-adaptive transformer finds such a threshold on a validation set via grid search. CALM (Schuster et al. 2022) applied the Learn then Test (LTT) framework (Angelopoulos et al. 2021) to identify a subset of valid thresholds and chose the minimum value as the threshold for inference. Except for training per-layer exit classifier, CALM also explored other methods for adaptive depth prediction, including the softmax responses (i.e. difference between top two softmax outputs) and hidden state saturation (i.e. $\\cos(\\mathbf{h}^n_t, \\mathbf{h}^{n+1}_t)$) as confidence scores for exit decisions. They found softmax responses result in best inference speedup.\nEfficient Attention The computation and memory cost of the vanilla Transformer grows quadratically with sequence length and hence it is hard to be applied on very long sequences. Many efficiency improvements for Transformer architecture have something to do with the self-attention module - making it cheaper, smaller or faster to run. See the survey paper on Efficient Transformers (Tay et al. 2020).\nSparse Attention Patterns Fixed Local Context A simple alternation to make self-attention less expensive is to restrict the attention span of each token to local context only, so that self-attention grows linearly with the sequence length.\nThe idea was introduced by Image Transformer (Parmer, et al 2018), which formulates image generation as sequence modeling using an encoder-decoder transformer architecture:\n The encoder generates a contextualized, per-pixel-channel representation of the source image; Then the decoder autoregressively generates an output image, one channel per pixel at each time step. Let\u0026rsquo;s label the representation of the current pixel to be generated as the query $\\mathbf{q}$. Other positions whose representations will be used for computing $\\mathbf{q}$ are key vector $\\mathbf{k}_1, \\mathbf{k}_2, \\dots$ and they together form a memory matrix $\\mathbf{M}$. The scope of $\\mathbf{M}$ defines the context window for pixel query $\\mathbf{q}$.\nImage Transformer introduced two types of localized $\\mathbf{M}$, as illustrated below.\nFig. 17. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in Parmer et al, 2018) 1D Local Attention: The input image is flattened in the raster scanning order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as $\\mathbf{q}$ and a fixed number of additional pixels generated before this query block.\n 2D Local Attention: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.\n Strided Context Sparse Transformer (Child et al., 2019) introduced factorized self-attention, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.\nGiven a set of attention connectivity pattern $\\mathcal{S} = \\{S_1, \\dots, S_n\\}$, where each $S_i$ records a set of key positions that the $i$-th query vector attends to.\n $$ \\begin{aligned} \\text{Attend}(\\mathbf{X}, \\mathcal{S}) \u0026= \\Big( a(\\mathbf{x}_i, S_i) \\Big)_{i \\in \\{1, \\dots, L\\}} \\\\ \\text{ where } a(\\mathbf{x}_i, S_i) \u0026= \\text{softmax}\\Big(\\frac{(\\mathbf{x}_i \\mathbf{W}^q)(\\mathbf{x}_j \\mathbf{W}^k)_{j \\in S_i}^\\top}{\\sqrt{d_k}}\\Big) (\\mathbf{x}_j \\mathbf{W}^v)_{j \\in S_i} \\end{aligned} $$ Note that although the size of $S_i$ is not fixed, $a(\\mathbf{x}_i, S_i)$ is always of size $d_v$ and thus $\\text{Attend}(\\mathbf{X}, \\mathcal{S}) \\in \\mathbb{R}^{L \\times d_v}$.\nIn anto-regressive models, one attention span is defined as $S_i = \\{j: j \\leq i\\}$ as it allows each token to attend to all the positions in the past.\nIn factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \\leq i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.\nPrecisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A^{(m)}_i \\subset S_i, m = 1,\\dots, p$. Therefore the path between the output position $i$ and any $j$ has a maximum length $p + 1$. For example, if $(j, a, b, c, \\dots, i)$ is a path of indices between $i$ and $j$, we would have $j \\in A_a^{(1)}, a \\in A_b^{(2)}, b \\in A_c^{(3)}, \\dots$, so on and so forth.\nSparse Factorized Attention\nSparse Transformer proposed two types of fractorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.\nFig. 18. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few of extra annotations.) Strided attention with stride $\\ell \\sim \\sqrt{n}$. This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\\ell$ pixels in the raster scanning order (naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{ t, t+1, \\dots, i\\} \\text{, where } t = \\max(0, i - \\ell) \\\\ A_i^{(2)} \u0026= \\{j: (i-j) \\mod \\ell = 0\\} \\end{aligned} $$ Fixed attention. A small set of tokens summarize previous locations and propagate that information to all future locations.\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{j: \\lfloor \\frac{j}{\\ell} \\rfloor = \\lfloor \\frac{i}{\\ell} \\rfloor \\} \\\\ A_i^{(2)} \u0026= \\{j: j \\mod \\ell \\in \\{\\ell-c, \\dots, \\ell-1\\} \\} \\end{aligned} $$ where $c$ is a hyperparameter. If $c=1$, it restricts the representation whereas many depend on a few positions. The paper chose $c\\in \\{ 8, 16, 32 \\}$ for $\\ell \\in \\{ 128, 256 \\}$.\n Use Factorized Self-Attention in Transformer\nThere are three ways to use sparse factorized attention patterns in Transformer architecture:\n One attention type per residual block and then interleave them, $\\text{attn}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, A^{(n \\mod p)}) \\mathbf{W}^o$, where $n$ is the index of the current residual block. Set up a single head which attends to locations that all the factorized heads attend to, $\\text{attn}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, \\cup_{m=1}^p A^{(m)}) \\mathbf{W}^o $. Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt a pattern presented above, 1 or 2. $\\rightarrow$ This option often performs the best. Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention \u0026amp; FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the paper for more details or my previous post on techniques for scaling up model training.\nBlockwise Attention (Qiu et al. 2019) introduces a sparse block matrix to only allow each token to attend to a small set of other tokens. Each attention matrix of size $L \\times L$ is partitioned into $n \\times n$ smaller blocks of size $\\frac{L}{n}\\times\\frac{L}{n}$ and a sparse block matrix $\\mathbf{M} \\in \\{0, 1\\}^{L \\times L}$ is defined by a permutation $\\pi$ of ${1, \\dots, n}$, which records the column index per row in the block matrix.\n $$ \\begin{aligned} \\text{attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}, \\mathbf{M}) \u0026= \\text{softmax}\\Big(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d}} \\odot \\mathbf{M}\\Big)\\mathbf{V} \\\\ (\\mathbf{A} \\odot \\mathbf{M})_{ij} \u0026= \\begin{cases} A_{ij} \u0026 \\text{if }M_{ij} = 1 \\\\ -\\infty \u0026 \\text{if }M_{ij} = 0 \\\\ \\end{cases} \\\\ \\text{where } M_{ij} \u0026= \\begin{cases} 1 \u0026 \\text{if }\\pi\\big(\\lfloor\\frac{(i-1)n}{L} + 1\\rfloor\\big) = \\lfloor\\frac{(j-1)n}{L} + 1\\rfloor \\\\ 0 \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ The actual implementation of Blockwise Attention only stores QKV as block matrices, each of size $n\\times n$:\n $$ \\text{Blockwise-attn}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}, \\mathbf{M}) = \\begin{bmatrix} \\text{softmax}\\big(\\frac{\\hat{\\mathbf{q}}_1\\hat{\\mathbf{k}}_{\\pi(1)}^\\top}{\\sqrt{d}} \\Big)\\hat{\\mathbf{v}}_{\\pi(1)} \\\\ \\vdots \\\\ \\text{softmax}\\big(\\frac{\\hat{\\mathbf{q}}_n\\hat{\\mathbf{k}}_{\\pi(n)}^\\top}{\\sqrt{d}} \\odot \\Big)\\hat{\\mathbf{v}}_{\\pi(n)} \\\\ \\end{bmatrix} $$ where $\\hat{\\mathbf{q}}_i$, $\\hat{\\mathbf{k}}_i$ and $\\hat{\\mathbf{v}}_i$ are the $i$-the row in the QKV block matrix respectively. Each $\\mathbf{q}_i\\mathbf{k}_{\\pi(i)}^\\top, \\forall i = 1, \\dots, n$ is of size $\\frac{N}{n}\\times\\frac{N}{n}$ and therefore Blockwise Attention is able to reduce the memory complexity of attention matrix from $\\mathcal{O}(L^2)$ to $\\mathcal{O}(\\frac{L}{n}\\times\\frac{L}{n} \\times n) = \\mathcal{O}(L^2/n)$.\nCombination of Local and Global Context ETC (Extended Transformer Construction; Ainslie et al. 2019), Longformer (Beltagy et al. 2020) and Big Bird (Zaheer et al. 2020) models combine both local and global context when building an attention matrix. All these models can be initialized from existing pretrained models.\nGlobal-Local Attention of ETC (Ainslie et al. 2019) takes two inputs, (1) the long input $\\mathbf{x}^l$ of size $n_l$ which is the regular input sequence and (2) the global input $\\mathbf{x}^g$ of size $n_g$ which contains a smaller number of auxiliary tokens, $n_g \\ll n_l$. Attention is thus split into four components based on directional attention across these two inputs: g2g, g2l, l2g and l2l. Because the l2l attention piece can be very large, it is restricted to a fixed size attention span of radius $w$ (i.e. local attention span) and the l2l matrix can be reshaped to $n_l \\times (2w+1)$.\nETC utilizes four binary matrices to handle structured inputs, $\\mathbf{M}^{g2g}$, $\\mathbf{M}^{g2l}$, $\\mathbf{M}^{l2g}$ and $\\mathbf{M}^{l2l}$. For example, each element $z^g_i \\in \\mathbb{R}^d$ in the attention output $z^g = (z^g_1, \\dots, z^g_{n_g})$ for g2g attention piece is formatted as:\n $$ \\begin{aligned} a^{g2g}_{ij} = \\frac{1}{\\sqrt{d}} x^g_i \\mathbf{W}^Q (x^g_j \\mathbf{W}^K + P^K_{ij})^\\top - (1- M^{g2g}_{ij})C \\\\ A^{g2g}_{ij} = \\frac{\\exp(a^{g2g}_{ij})}{\\sum_{k=1}^{n_g} \\exp(a^{g2g}_{ik})} \\quad z^g_i = \\sum^{n_g}_{j=1} A^{g2g}_{ij} x^g_j \\mathbf{W}^V \\end{aligned} $$ where $P^K_{ij}$ is a learnable vector for relative position encoding and $C$ is a very large constant ($C=10000$ in the paper) to offset any attention weights when mask is off.\nFig. 19. Attention patterns of ETC, Longformer and Big Bird. One more update in ETC is to incorporate a CPC (contrastive predictive coding) task using NCE loss into the pretraining stage, besides the MLM task: The representation of one sentence should be similar to the representation of context around it when this sentence is masked.\nThe global input $\\mathbf{x}^g$ for ETC is constructed as follows: Assuming there are some segments within the long inputs (e.g. by sentence), each segment is attached with one auxiliary token to learn global inputs. Relative position encoding is used to mark the global segment tokens with the token position. Hard masking in one direction (i.e., tokens before vs after are labeled differently) is found to bring performance gains in some datasets.\nAttention pattern in Longformer contains three components:\n Local attention: Similar to ETC, local attention is controlled by a sliding window of fixed size $w$; Global attention of preselected tokens: Longformer has a few pre-selected tokens (e.g. [CLS] token) assigned with global attention span, that is, attending to all other tokens in the input sequence. Dilated attention: Dilated sliding window of fixed size $r$ and gaps of dilation size $d$, similar to Sparse Transformer; Big Bird is quite similar to Longformer, equipped with both local attention and a few preselected tokens with global attention span, but Big Bird replaces dilated attention with a new mechanism where all tokens attend to a set of random tokens. The design is motivated by the fact that attention pattern can be viewed as a directed graph and a random graph has the property that information is able to rapidly flow between any pair of nodes.\nLongformer uses smaller window size at lower layers and larger window sizes at higher layers. Ablation studies showed that this setup works better than reversed or fixed size config. Lower layers do not have dilated sliding windows to better learn to use immediate local context. Longformer also has a staged training procedure where initially the model is trained with small window size to learn from local context and then subsequent stages of training have window sizes increased and learning rate decreased.\nContent-based Attention The improvements proposed by Reformer (Kitaev, et al. 2020) aim to solve the following pain points in vanilla Transformer:\n Quadratic time and memory complexity within self-attention module. Memory in a model with $N$ layers is $N$-times larger than in a single-layer model because we need to store activations for back-propagation. The intermediate FF layers are often quite large. Reformer proposed two main changes:\n Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $\\mathcal{O}(L^2)$ to $\\mathcal{O}(L\\log L)$. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers). Locality-Sensitive Hashing Attention\nIn $\\mathbf{Q} \\mathbf{K}^\\top$ part of the attention formula, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query $\\mathbf{q}_i \\in \\mathbf{Q}$, we are looking for row vectors in $\\mathbf{K}$ closest to $\\mathbf{q}_i$. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.\nA hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. The Reformer adopts a hashing scheme as such, given a fixed random matrix $\\mathbf{R} \\in \\mathbb{R}^{d \\times b/2}$ (where $b$ is a hyperparam), the hash function is $h(x) = \\arg\\max([xR; −xR])$.\n$$ \\mathbf{o}_i = \\sum_{j \\in S_i} \\exp(\\mathbf{q}_i \\cdot \\mathbf{k}_j - Z(i, S_i)) \\mathbf{v}_j \\text{, where } S_i = \\{j: j \\leq i\\} $$ -- Fig. 20. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in Kitaev, et al. 2020). In LSH attention, a query can only attend to positions in the same hashing bucket, $S_i = \\{j: h(\\mathbf{q}_i) = h(\\mathbf{k}_j)\\}$. It is carried out in the following process, as illustrated in Fig. 20:\n (a) The attention matrix for full attention is often sparse. (b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets. (c) Set $\\mathbf{Q} = \\mathbf{K}$ (precisely $\\mathbf{k}_j = \\mathbf{q}_j / |\\mathbf{q}_j|$), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this \u0026ldquo;shared-QK\u0026rdquo; config does not affect the performance of the Transformer. (d) Apply batching where chunks of $m$ consecutive queries are grouped together. Fig. 21. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020). Reversible Residual Network\nAnother improvement by Reformer is to use reversible residual layers (Gomez et al. 2017). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.\nGiven a layer $x \\mapsto y$, the normal residual layer does $y = x + F(x)$, but the reversible layer splits both input and output into pairs $(x_1, x_2) \\mapsto (y_1, y_2)$ and then executes the following:\n $$ y_1 = x_1 + F(x_2),\\; y_2 = x_2 + G(y_1) $$ and reversing is easy:\n $$ x_2 = y_2 - G(y_1), \\; x_1 = y_1 − F(x_2) $$ Reformer applies the same idea to Transformer by combination attention ($F$) and feed-forward layers ($G$) within a reversible net block:\n $$ Y_1 = X_1 + \\text{Attention}(X_2), \\; Y_2 = X_2 + \\text{FeedForward}(Y_1) $$ The memory can be further reduced by chunking the feed-forward computation:\n $$ Y_2 = [Y_2^{(1)}; \\dots; Y_2^{(c)}] = [X_2^{(1)} + \\text{FeedForward}(Y_1^{(1)}); \\dots; X_2^{(c)} + \\text{FeedForward}(Y_1^{(c)})] $$ The resulting reversible Transformer does not need to store activation in every layer.\nRouting Transformer (Roy et al. 2021) is also built on content-based clustering of keys and queries. Instead of using a static hashing function like LSH, it utilizes online $k$-means clustering and combines it with local, temporal sparse attention to reduce the attention complexity from $O(L^2)$ to $O(L^{1.5})$.\nWithin routing attention, both keys and queries are clustered with $k$-means clustering method and the same set of centroids $\\boldsymbol{\\mu} = (\\mu_1, \\dots, \\mu_k) \\in \\mathbb{R}^{k \\times d}$. Queries are routed to keys that get assigned to the same centroid. The total complexity is $O(Lkd + L^2d/k)$, where $O(Lkd)$ is for running clustering assignments and $O(L^2d/k)$ is for attention computation. The cluster centroids are updated by EMA (exponential moving average) using all associated keys and queries.\nIn the experiments for Routing Transformer, some best config only has routing attention enabled in the last two layers of the model and half of the attention heads, while the other half utilizing local attention. They also observed that local attention is a pretty strong baseline and larger attention window always leads to better results.\nLow-Rank Attention Linformer (Wang et al. 2020) approximates the full attention matrix with a low rank matrix, reducing the time \u0026amp; space complexity to be linear. Instead of using expensive SVD to identify low rank decomposition, Linformer adds two linear projections $\\mathbf{E}_i, \\mathbf{F}_i \\in \\mathbb{R}^{L \\times k}$ for key and value matrices, respectively, reducing their dimensions from $L \\times d$ to $k \\times d$. As long as $k \\ll L$, the attention memory can be greatly reduced.\n $$ \\begin{aligned} \\overline{\\text{head}}_i \u0026= \\text{attn}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{E}_i\\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{F}_i\\mathbf{X}_v\\mathbf{W}^v_i) \\\\ \u0026= \\underbrace{\\text{softmax}\\Big( \\frac{\\mathbf{X}_q\\mathbf{W}^q_i (\\mathbf{E}_i \\mathbf{X}_k\\mathbf{W}^k_i)^\\top}{\\sqrt{d}} \\Big)}_{\\text{low rank attention matrix }\\bar{A} \\in \\mathbb{R}^{k \\times d}} \\mathbf{F}_i \\mathbf{X}_v\\mathbf{W}^v_i \\end{aligned} $$ Additional techniques can be applied to further improve efficiency of Linformer:\n Parameter sharing between projection layers, such as head-wise, key-value and layer-wise (across all layers) sharing. Use different $k$ at different layers, as heads in higher layers tend to have a more skewed distribution (lower rank) and thus we can use smaller $k$ at higher layers. Use different types of projections; e.g. mean/max pooling, convolution layer with kernel and stride $L/k$. Fig. 22. (Left) Informer has two projection layers added for keys and values. (Right) Plot of inference time as a function of sequence length. (Image source: Wang et al. 2020). Random Feature Attention (RFA; Peng et al. 2021) relies on random feature methods (Rahimi \u0026amp; Recht, 2007) to approximate softmax operation in self-attention with low rank feature maps in order to achieve linear time and space complexity. Performers (Choromanski et al. 2021) also adopts random feature attention with improvements on the kernel construction to further reduce the kernel approximation error.\nThe main theorem behind RFA is from Rahimi \u0026amp; Recht, 2007:\n Let $\\phi: \\mathbb{R}^d \\to \\mathbb{R}^{2D}$ be a nonlinear transformation:\n $$ \\phi(\\mathbf{x}) = \\frac{1}{\\sqrt{D}}[\\sin(\\mathbf{w}_1^\\top \\mathbf{x}), \\dots, \\sin(\\mathbf{w}_D^\\top \\mathbf{x}), \\cos(\\mathbf{w}_1^\\top \\mathbf{x}), \\dots, \\cos(\\mathbf{w}_D^\\top \\mathbf{x})]^\\top $$ When $d$-dimensional random vectors $\\mathbf{w}_i$ are i.i.d. from $\\mathcal{N}(\\mathbf{0}, \\sigma^2\\mathbf{I}_d)$, $$ \\mathbb{E}_{\\mathbf{w}_i} [\\phi(\\mathbf{x}) \\cdot \\phi(\\mathbf{y})] = \\exp(-\\frac{\\| \\mathbf{x} - \\mathbf{y} \\|^2}{2\\sigma^2}) $$ An unbiased estimation of $\\exp(\\mathbf{x} \\cdot \\mathbf{y})$ is:\n $$ \\begin{aligned} \\exp(\\mathbf{x} \\cdot \\mathbf{y} / \\sigma^2) \u0026= \\exp(\\frac{1}{2\\sigma^2}(\\|\\mathbf{x}\\|^2 + \\|\\mathbf{y}\\|^2 - \\|\\mathbf{x} - \\mathbf{y}\\|^2) \\\\ \u0026= \\exp(\\frac{\\|\\mathbf{x}\\|^2}{2\\sigma^2}) \\exp(\\frac{\\|\\mathbf{y}\\|^2}{2\\sigma^2}) ( - \\frac{\\|\\mathbf{x} - \\mathbf{y}\\|^2}{2\\sigma^2}) \\\\ \u0026\\approx \\exp(\\frac{\\|\\mathbf{x}\\|^2}{2\\sigma^2}) \\exp(\\frac{\\|\\mathbf{y}\\|^2}{2\\sigma^2})\\;\\phi(\\mathbf{x})\\cdot\\phi(\\mathbf{y}) \\\\ \u0026= \\exp(\\frac{1}{\\sigma^2})\\;\\phi(\\mathbf{x})\\cdot\\phi(\\mathbf{y}) \u0026 \\text{; unit vectors} \\end{aligned} $$ Then we can write the attention function as follows, where $\\otimes$ is outer product operation and $\\sigma^2$ is the temperature:\n $$ \\begin{aligned} \\text{attn}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}, \\{\\mathbf{v}_i\\}) \u0026= \\sum_i \\frac{\\exp(\\mathbf{q}_t\\cdot\\mathbf{k}_i/\\sigma^2)}{\\sum_j \\exp(\\mathbf{q}_t\\cdot\\mathbf{k}_j/\\sigma^2)}\\mathbf{v}_i^\\top \\approx \\sum_i \\frac{\\phi(\\mathbf{q}_t)\\phi(\\mathbf{k}_i)\\mathbf{v}_i^\\top}{\\sum_j \\phi(\\mathbf{q}_t)\\phi(\\mathbf{k}_j)} \\\\ \u0026= \\color{green}{\\frac{\\phi(\\mathbf{q}_t)^\\top \\sum_i \\phi(\\mathbf{k}_i)\\otimes\\mathbf{v}_i}{\\phi(\\mathbf{q}_t)^\\top \\sum_j \\phi(\\mathbf{k}_j)} = \\text{RFA}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}, \\{\\mathbf{v}_i\\})} \\end{aligned} $$ Fig. 23. (Left) The order of computation for default softmax operation. (Right) The order of computation when using random feature attention, a lot cheaper than default softmax. (Image source: Peng et al. 2021). Causal Attention RFA has token at time step $t$ only attend to earlier keys and values $\\{\\mathbf{k}_i\\}_{i \\leq t}, \\{\\mathbf{v}_i\\}_{i \\leq t}$. Let us use a tuple of variables, $(\\mathbf{S}_t \\in \\mathbb{R}^{2D \\times d}, \\mathbf{z} \\in \\mathbb{R}^{2D})$, to track the hidden state history at time step $t$, similar to RNNs:\n $$ \\begin{aligned} \u0026\\text{causal-RFA}(\\mathbf{q}_t, \\{\\mathbf{k}_i\\}_{i \\leq t}, \\{\\mathbf{v}_i\\}_{i \\leq t}) = \\frac{\\phi(\\mathbf{q}_t)^\\top \\mathbf{S}_t}{\\phi(\\mathbf{q}_t) \\cdot \\mathbf{z}_t} \\\\ \u0026\\text{where } \\mathbf{S}_t = \\mathbf{S}_{t-1} + \\phi(\\mathbf{k}_t)\\otimes\\mathbf{v}_t, \\quad \\mathbf{z}_t = \\mathbf{z}_{t-1} + \\phi(\\mathbf{k}_t) \\end{aligned} $$ where $2D$ is the size of $\\phi(.)$ and $D$ should be no less than the model size $d$ for reasonable approximation.\nRFA leads to significant speedup in autoregressive decoding and the memory complexity mainly depends on the choice of $D$ when constructing the kernel $\\phi(.)$.\nPerformer modifies the random feature attention with positive random feature maps to reduce the estimation error. It also keeps the randomly sampled $\\mathbf{w}_1, \\dots, \\mathbf{w}_D$ to be orthogonal to further reduce the variance of the estimator.\nFig. 24. Comparison of approximation error when using (Left) i.i.d vs orthogonal features and (Right) sin/cos vs positive random features. (Image source: Choromanski et al. 2021). Transformers for Reinforcement Learning The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can for sure benefit from these traits. However, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. It could be quite challenging to stabilize and train a LSTM agent by itself, after all.\nThe Gated Transformer-XL (GTrXL; Parisotto, et al. 2019) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:\n The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism. $$ \\begin{aligned} r \u0026= \\sigma(W_r^{(l)} y + U_r^{(l)} x) \\\\ z \u0026= \\sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\\\ \\hat{h} \u0026= \\tanh(W_g^{(l)} y + U_g^{(l)} (r \\odot x)) \\\\ g^{(l)}(x, y) \u0026= (1-z)\\odot x + z\\odot \\hat{h} \\end{aligned} $$ The gating function parameters are explicitly initialized to be close to an identity map - this is why there is a $b_g$ term. A $b_g \u0026gt; 0$ greatly helps with the learning speedup.\nFig. 25. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in Parisotto, et al. 2019) Decision Transformer (DT; Chen et al 2021) formulates Reinforcement Learning problems as a process of conditional sequence modeling, outputting the optimal actions conditioned on the desired return, past states and actions. It therefore becomes straightforward to use Transformer architecture. Decision Transformer is for off-policy RL, where the model only has access to a fixed collection of trajectories collected by other policies.\nTo encourage the model to learn how to act in order to achieve a desired return, it feeds the model with desired future return $\\hat{R} = \\sum_{t'=t}^T r_{t'}$ instead of the current reward. The trajectory consists of a list of triplets, (return-to-go $\\hat{R}_t, state $s_t$, action $a_t$), and it is used as an input sequence for Transformer:\n $$ \\tau = (\\hat{R}_1, s_1, a_1, \\hat{R}_2, s_2, a_2, \\dots, \\hat{R}_T, s_T, a_T) $$ Three linear layers are added and trained for return-to-go, state and action respectively to extract token embeddings. The prediction head learns to predict $a_t$ corresponding to the input token $s_t$. The training uses cross-entropy loss for discrete actions or MSE for continuous actions. Predicting the states or return-to-go was not found to help improve the performance in their experiments.\nThe experiments compared DT with several model-free RL algorithm baselines and showed that:\n DT is more efficient than behavior cloning in low data regime; DT can model the distribution of returns very well; Having a long context is crucial for obtaining good results; DT can work with sparse rewards. Citation Cited as:\n Weng, Lilian. (Jan 2023). The transformer family version 2.0. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/.\n Or\n@article{weng2023transformer, title = \u0026quot;The Transformer Family Version 2.0\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/\u0026quot; } References [1] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[2] Rami Al-Rfou, et al. \u0026ldquo;Character-level language modeling with deeper self-attention.\u0026quot; AAAI 2019.\n[3] Olah \u0026amp; Carter, \u0026ldquo;Attention and Augmented Recurrent Neural Networks\u0026rdquo;, Distill, 2016.\n[4] Sainbayar Sukhbaatar, et al. \u0026ldquo;Adaptive Attention Span in Transformers\u0026rdquo;. ACL 2019.\n[5] Rewon Child, et al. \u0026ldquo;Generating Long Sequences with Sparse Transformers\u0026rdquo; arXiv:1904.10509 (2019).\n[6] Nikita Kitaev, et al. \u0026ldquo;Reformer: The Efficient Transformer\u0026rdquo; ICLR 2020.\n[7] Alex Graves. (\u0026ldquo;Adaptive Computation Time for Recurrent Neural Networks\u0026rdquo;)[https://arxiv.org/abs/1603.08983]\n[8] Niki Parmar, et al. \u0026ldquo;Image Transformer\u0026rdquo; ICML 2018.\n[9] Zihang Dai, et al. \u0026ldquo;Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.\u0026quot; ACL 2019.\n[10] Aidan N. Gomez, et al. \u0026ldquo;The Reversible Residual Network: Backpropagation Without Storing Activations\u0026rdquo; NIPS 2017.\n[11] Mostafa Dehghani, et al. \u0026ldquo;Universal Transformers\u0026rdquo; ICLR 2019.\n[12] Emilio Parisotto, et al. \u0026ldquo;Stabilizing Transformers for Reinforcement Learning\u0026rdquo; arXiv:1910.06764 (2019).\n[13] Rae et al. “Compressive Transformers for Long-Range Sequence Modelling.” 2019.\n[14] Press et al. “Train Short, Test Long: Attention With Linear Biases Enables Input Length Extrapolation.” ICLR 2022.\n[15] Wu, et al. “DA-Transformer: Distance Aware Transformer” 2021.\n[16] Elabyad et al. “Depth-Adaptive Transformer.” ICLR 2020.\n[17] Schuster et al. “Confident Adaptive Language Modeling” 2022.\n[18] Qiu et al. “Blockwise self-attention for long document understanding” 2019\n[19] Roy et al. “Efficient Content-Based Sparse Attention with Routing Transformers.” 2021.\n[20] Ainslie et al. “ETC: Encoding Long and Structured Inputs in Transformers.” EMNLP 2019.\n[21] Beltagy et al. “Longformer: The long-document transformer.” 2020.\n[22] Zaheer et al. “Big Bird: Transformers for Longer Sequences.” 2020.\n[23] Wang et al. “Linformer: Self-Attention with Linear Complexity.” arXiv preprint arXiv:2006.04768 (2020).\n[24] Tay et al. 2020 “Sparse Sinkhorn Attention.” ICML 2020.\n[25] Peng et al. “Random Feature Attention.” ICLR 2021.\n[26] Choromanski et al. “Rethinking Attention with Performers.” ICLR 2021.\n[27] Khandelwal et al. “Generalization through memorization: Nearest neighbor language models.” ICLR 2020.\n[28] Yogatama et al. “Adaptive semiparametric language models.” ACL 2021.\n[29] Wu et al. “Memorizing Transformers.” ICLR 2022.\n[30] Su et al. “Roformer: Enhanced transformer with rotary position embedding.” arXiv preprint arXiv:2104.09864 (2021).\n[31] Shaw et al. “Self-attention with relative position representations.” arXiv preprint arXiv:1803.02155 (2018).\n[32] Tay et al. \u0026ldquo;Efficient Transformers: A Survey.\u0026quot; ACM Computing Surveys 55.6 (2022): 1-28.\n[33] Chen et al., \u0026ldquo;Decision Transformer: Reinforcement Learning via Sequence Modeling\u0026rdquo; arXiv preprint arXiv:2106.01345 (2021).\n","permalink":"https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/","summary":"Many new Transformer architecture improvements have been proposed since my last post on \u0026ldquo;The Transformer Family\u0026rdquo; about three years ago. Here I did a big refactoring and enrichment of that 2020 post \u0026mdash; restructure the hierarchy of sections and improve many sections with more recent papers. Version 2.0 is a superset of the old version, about twice the length.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size.","title":"The Transformer Family Version 2.0"},{"content":"[Updated on 2023-01-24: add a small section on Distillation.]\nLarge transformer models are mainstream nowadays, creating SoTA results for a variety of tasks. They are powerful but very expensive to train and use. The extremely high inference cost, in both time and memory, is a big bottleneck for adopting a powerful transformer for solving real-world tasks at scale.\nWhy is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al. 2022):\n Large memory footprint. Both model parameters and intermediate states are needed in memory at inference time. For example, The KV cache should be stored in memory during decoding time; E.g. For a batch size of 512 and context length of 2048, the KV cache totals 3TB, that is 3x the model size (!). Inference cost from the attention mechanism scales quadratically with input sequence length. Low parallelizability. Inference generation is executed in an autoregressive fashion, making the decoding process hard to parallel. In this post, we will look into several approaches for making transformer inference more efficient. Some are general network compression methods, while others are specific to transformer architecture.\nMethods Overview We in general consider the following as goals for model inference optimization:\n Reduce the memory footprint of the model by using fewer GPU devices and less GPU memory; Reduce the desired computation complexity by lowering the number of FLOPs needed; Reduce the inference latency and make things run faster. Several methods can be used to make inference cheaper in memory or/and faster in time.\n Apply various parallelism to scale up the model across a large number of GPUs. Smart parallelism of model components and data makes it possible to run a model of trillions of parameters. Memory offloading to offload temporarily unused data to the CPU and read them back when needed later. This helps with memory usage but causes higher latency. Smart batching strategy; E.g. EffectiveTransformer packs consecutive sequences together to remove padding within one batch. Network compression techniques, such as pruning, quantization, distillation. A model of smaller size, in terms of parameter count or bitwidth, should demand less memory and run faster. Improvement specific to a target model architecture. Many architectural changes, especially those for attention layers, help with transformer decoding speed. Check the previous post on large model training on different types of training parallelism and memory saving designs including CPU memory offloading. This post focuses on network compression techniques and architecture-specific improvement for transformer models.\nDistillation Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (\u0026ldquo;student model\u0026rdquo;) to speed up inference by transferring skills from a pre-trained expensive model (\u0026ldquo;teacher model\u0026rdquo;) into the student. There is no much restriction on how the student architecture should be constructed, except for a matched output space with the teacher in order to construct a proper learning objective.\nFig. 1. The generic framework of teacher-student knowledge distillation training. (Image source: Gou et al. 2020) Given a dataset, a student model is trained to mimic outputs of a teacher via distillation loss. Usually a neural network has a softmax layer; For example, a LLM outputs a probability distribution over tokens. Let\u0026rsquo;s denote the logits layer right before softmax as $\\mathbf{z}_t$ and $\\mathbf{z}_s$ for teacher and student models, respectively. The distillation loss minimizes the difference between two softmax outputs with a high temperature $T$. When ground truth labels $\\mathbf{y}$ are known, we can combine it with a supervised learning objective between ground truth and the student\u0026rsquo;s soft logits using e.g. cross-entropy.\n $$ \\mathcal{L}_\\text{KD} = \\mathcal{L}_\\text{distll}(\\text{softmax}(\\mathbf{z}_t, T), \\text{softmax}(\\mathbf{z}_s, T)) + \\lambda\\mathcal{L}_\\text{CE}(\\mathbf{y}, \\mathbf{z}_s) $$ where $\\lambda$ is a hyperparameter to balance between soft and hard learning objectives. A common choice for $\\mathcal{L}_\\text{distll}$ is KL divergence / cross entropy.\nA successful early trial is DistilBERT (Sanh et al. 2019) that is able to reduce the parameters of a BERT by 40% while maintaining 97% performance of BERT on fine-tuned downstream tasks and running 71% faster. The loss of pre-training DistilBERT is a combination of soft distillation loss, supervised training loss (i.e. Masked language modeling loss $\\mathcal{L}_\\text{MLM}$ in the case of BERT) and a special cosine embedding loss to align the hidden state vectors between teacher and student.\nDistillation can be easily combined with quantization, pruning or sparsification techniques, where the teacher model is the original full-precision, dense model and the student is quantized, pruned, or trimmed to have higher sparsity level.\nQuantization There are two common approaches for applying quantization on a deep neural network:\n Post-Training Quantization (PTQ): A model is first trained to convergence and then we convert its weights to lower precision without more training. It is usually quite cheap to implement, in comparison to training. Quantization-Aware Training (QAT): Quantization is applied during pre-training or further fine-tuning. QAT is able to attain better performance but requires extra computation resources and access to representative training data. We should be aware of the gap between theoretical optimal quantization strategy and the hardware kernel support. Due to the lack of GPU kernel support for certain types of matrix multiplication (e.g. INT4 x FP16), not all the methods below result in speedup for the actual inference.\nChallenges for Transformer Quantization Many studies on Transformer model quantization have the same observation: A simple low-precision (e.g. 8-bit) post-training quantization leads to significant performance drop mainly due to the high dynamic ranges of activation and a naive activation quantization strategy fails to maintain the capacity.\nFig. 2. Only quantizing model weights to 8-bit while keeping activation at full precision (`W8A32`) achieves much better results when activations are quantized to 8-bit irrespective of whether weights are in lower precision (`W8A8` and `W32A8`). (Image source: Bondarenko et al. 2021) Bondarenko et al. (2021) observed in a small BERT model that FFN’s input and output have very different dynamic ranges due to strong outliers in the output tensor. Therefore per-tensor quantization for the FFN’s residual sum is likely to cause a notable error.\nAs the model size continues to grow to billions of parameters, outlier features of high magnitude start to emerge in all transformer layers, causing failure of simple low-bit quantization. Dettmers et al. (2022) observed such a phenomenon for OPT models larger than 6.7B parameters. Larger models have more layers with extreme outliers and these outlier features have a significant impact on the model performance. The scale of activation outliers in a few dimensions can be ~100× larger than most of the other values.\nFig. 3. The mean zero-shot accuracy over a set of language tasks (WinoGrande, HellaSwag, PIQA, LAMBADA) of OPT models of increasing sizes. (Image source: Dettmers et al. 2022) Post-training quantization (PTQ) Mixed-precision quantization The most straightforward approach for resolving the above quantization challenge is to implement quantization at different precision for weights vs activation.\nGOBO (Zadeh et al. 2020) is one of the first models to apply post-training quantization on transformers (i.e. a small BERT model). It assumes that model weights of each layer follow a Gaussian distribution and therefore detects outliers by tracking mean and standard deviation per layer. Outlier features remain in original form, while other values are split into multiple bins and only corresponding bin indices of weights and the centroid values are stored.\nBased on the observation that only certain activation layers (e.g. residual connections after FFN) in BERT cause big performance drop, Bondarenko et al. (2021) adopted mixed-precision quantization by using 16-bit quantization on problematic activations but 8-bit on others.\nMixed-precision quantization in LLM.int8() (Dettmers et al. 2022) is implemented via two mixed-precision decompositions:\n Because matrix multiplication contains a set of independent inner products between row and column vectors, we can impose independent quantization per inner product: Each row and column are scaled by the absolution maximum values and then quantized to INT8. Outlier activation features (e.g. 20x larger than other dimensions) remain in FP16 but they represent only a tiny fraction of total weights. How to identify outliers is empirical. Fig. 4. Two mixed-precision decompositions of `LLM.int8()`. (Image source: Dettmers et al. 2022) Quantization at fine-grained granularity Fig. 5. Comparison of quantization at different granularity. $d$ is the model size / hidden state dimension and $h$ is the number of heads in one MHSA (multi-head self-attention) component. Naively quantizing the entire weight matrix in one layer (\u0026ldquo;per-tensor\u0026rdquo; or \u0026ldquo;per-layer\u0026rdquo; quantization) is easiest to implement but does not lead to good granularity of quantization.\nQ-BERT (Shen, Dong \u0026amp; Ye, et al. 2020) applied group-wise quantization to a fine-tuned BERT model, treating an individual matrix $W$ with respect to each head in MHSA (multi-head self-attention) as one group and then applies Hessian based mixed precision quantization.\nPer-embedding group (PEG) activation quantization was motivated by the observation that outlier values only appear in a few out of $d$ (hidden state / model size) dimensions (Bondarenko et al. 2021). Per-embedding is pretty computationally expensive. In comparison, PEG quantization splits the activation tensor into several evenly sized groups along the embedding dimension where elements in the same group share quantization parameters. To ensure all outliers are grouped together, they apply a deterministic range-based permutation of embedding dimensions, where dimensions are sorted by their value ranges.\nZeroQuant (Yao et al. 2022) uses group-wise quantization for weights, same as in Q-BERT, and token-wise quantization for activation. To avoid expensive quantization and de-quantization computation, ZeroQuant built customized kernel to fuse quantization operation with its previous operator.\nSecond order information for quantization Q-BERT (Shen, Dong \u0026amp; Ye, et al. 2020) developed Hessian AWare Quantization (HAWQ) for its mixed-precision quantization. The motivation is that parameters with higher Hessian spectrum (i.e., larger top eigenvalues) are more sensitive to quantization and thus require higher precision. It is essentially a way to identify outliers.\nIn another viewpoint, the problem of quantization is an optimization problem. Given a weight matrix $\\mathbf{W}$ and an input matrix $\\mathbf{X}$ , we want to find a quantized weight matrix $\\hat{\\mathbf{W}}$ to minimize the MSE:\n$$ \\hat{\\mathbf{W}}^* = {\\arg\\min}_{\\hat{\\mathbf{W}}} | \\mathbf{W}\\mathbf{X} - \\hat{\\mathbf{W}}\\mathbf{X}| $$\nGPTQ (Frantar et al. 2022) treats the weight matrix $\\mathbf{W}$ as a collection of row vectors ${\\mathbf{w}}$ and applies quantization to each row independently. GPTQ iteratively quantizes more weights that are selected greedily to minimize the quantization error. The update on selected weights has a closed-form formula, utilizing Hessian matrices. Read more details in the paper and the OBQ (Optimal Brain Quantization; Frantar \u0026amp; Alistarh 2022) method if interested. GPTQ can reduce the bitwidth of weights in OPT-175B down to 3 or 4 bits without much performance loss, but it only applies to model weights not activation.\nOutlier smoothing It is known that activations are harder to quantize than weights in transformer models. SmoothQuant (Xiao \u0026amp; Lin 2022) proposed a smart solution to smooth outlier features from activations to weights via mathematically equivalent transformation and then enable quantization on both weights and activations (W8A8). Because of this, SmoothQuant has better hardware efficiency than mixed-precision quantization.\nFig. 6. SmoothQuant migrates the scale variance from activations to weights offline to reduce the difficulty of activation quantization. Both the resulting new weight and activation matrices are easy to quantize. (Image source: Xiao \u0026 Lin 2022) Considering a per-channel smooth factor $\\mathbf{s}$, SmoothQuant scales the weights according to:\n$$ \\mathbf{Y} = (\\mathbf{X} \\text{diag}(\\mathbf{s})^{-1}) \\cdot (\\text{diag}(\\mathbf{s})\\mathbf{W}) = \\hat{\\mathbf{X}}\\hat{\\mathbf{W}} $$\nThe smoothing factor can be easily fused into previous layers' parameters offline. A hyperparameter $\\alpha$ controls how much we migrate the quantization difficulty from activations to weights: $\\mathbf{s} = \\max (\\vert \\mathbf{X}_j \\vert)^\\alpha / \\max( \\vert \\mathbf{W}_j \\vert )^{1-\\alpha}$. The paper found that $\\alpha=0.5$ is a sweet spot for many LLMs in the experiments. For models with more significant outliers in activation, $\\alpha$ can be adjusted to be larger.\nQuantization-aware training (QAT) Quantization-aware training fuses the quantization operation into the pre-training or fine-tuning process. It learns model weights in low-bit representation directly and leads to better performance at the cost of additional training time and computation.\nThe most straightforward approach is to fine-tune the model after quantization on a training dataset that is the same as or representative of the pre-training dataset. The training objective can be the same as the one for pre-training (e.g. NLL/MLM in general language model training) or specific to a downstream task that we care about (e.g. Cross entropy for classification).\nAnother approach is to consider the full-precision model as the teacher and the lower-precision model as the student, and then optimize the low-precision model with distillation loss. Distillation usually doesn\u0026rsquo;t need to use the original dataset; E.g. Wikipedia dataset is a good choice and even random tokens can give decent performance gain. The Layer-by-layer Knowledge Distillation (LKD; Yao et al. 2022) method quantizes the network layer by layer and uses its original, unquantized version as the teacher model. Given the same inputs, LKD minimizes the MSE between the multiplication with layer weights and the multiplication of quantized layer weights.\nPruning Network pruning is to reduce the model size by trimming unimportant model weights or connections while the model capacity remains. It may or may not require re-training. Pruning can be unstructured or structured.\n Unstructured pruning is allowed to drop any weight or connection, so it does not retain the original network architecture. Unstructured pruning often does not work well with modern hardware and doesn\u0026rsquo;t lead to actual inference speedup. Structured pruning aims to maintain the dense matrix multiplication form where some elements are zeros. They may need to follow certain pattern restrictions to work with what hardware kernel supports. Here we focus on structured pruning to achieve high sparsity in transformer models. A routine workflow to construct a pruned network has three steps:\n Train a dense network until convergence; Prune the network to remove unwanted structure; Optionally retrain the network to recover the performance with new weights. The idea of discovering a sparse structure within a dense model via network pruning while the sparse network can still maintain similar performance is motivated by Lottery Ticket Hypothesis (LTH): A randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset (a sparse network) are \u0026ldquo;winning tickets\u0026rdquo; which can achieve the optimal performance when trained in isolation.\nHow to prune? Magnitude pruning is simplest yet quite effective pruning method - weights with smallest absolute values are trimmed. In fact, some studies (Gale et al. 2019) found that simple magnitude pruning approaches can achieve comparable or better results than complicated pruning methods, such as variational dropout (Molchanov et al. 2017) and $l_0$ regularization (Louizos et al. 2017). Magnitude pruning is simple to apply to large models and achieves reasonably consistent performance across a wide range of hyperparameters.\nZhu \u0026amp; Gupta (2017) found that large sparse models were able to achieve better performance than their small but dense counterparts. They proposed Gradual Magnitude Pruning (GMP) algorithm that increases the sparsity of a network gradually over the course of training. At each training step, weights with smallest absolute values are masked to be zeros to achieve a desired sparsity level $s$ and masked weights do not get gradient update during back-propagation. The desired sparsity level $s$ goes up with more training steps. The process of GMP is sensitive to the learning rate schedule, which should be higher than what\u0026rsquo;s used in dense network training, but not too high to prevent convergence.\nIterative pruning (Renda et al. 2020) iterates step 2 (prune) \u0026amp; step 3 (retrain) multiple times: Only a small fraction of weights are pruned and the model is retrained in each iteration. The process repeats until a desired sparsity level is reached.\nHow to retrain? The retraining step can be simple fine-tuning using the same pre-training data or other task-specific datasets.\nLottery Ticket Hypothesis proposed a weight rewinding retraining technique: After pruning, the unpruned weights are reinitialized back to original values earlier in the training and then retrain with the same learning rate schedule.\nLearning rate rewinding (Renda et al. 2020) only resets the learning rate back to its early value, while the unpruned weights stay unchanged since the end of the last train stage. They observed that (1) retraining with weight rewinding outperforms retraining with fine-tuning across networks and datasets and (2) learning rate rewinding matches or outperforms weight rewinding in all tested scenarios.\nSparsity Sparsity is an effective way to scale up model capacity while keeping model inference computationally efficient. Here we consider two types of sparsity for transformers:\n Sparsified dense layers, including both self-attention and FFN layers. Sparse model architecture; i.e. via incorporating the Mixture-of-Experts (MoE) component. N:M Sparsity via Pruning N:M sparsity is a structured sparsity pattern that works well with modern GPU hardware optimization, in which $N$ out of every $M$ consecutive elements are zeros. For example, the sparse tensor core of Nvidia A100 GPU has support for 2:4 sparsity for faster inference (Nvidia 2020).\nFig. 7. A matrix of 2:4 structured sparsity and its compressed representation. (Image source: Nvidia blog) To sparsify a dense neural network to follow a N:M structured sparsity pattern, Nvidia (2020) suggested using the three-step routine workflow for training a pruned network: train \u0026ndash;\u0026gt; prune to satisfy 2:4 sparsity \u0026ndash;\u0026gt; retrain.\nPermuting columns can provide more options in the pruning process to maintain parameters of large magnitude or to satisfy a special restriction like N:M sparsity (Pool \u0026amp; Yu 2021). As long as paired axes of two matrices are permuted in the same order, the results of matrix multiplication would not change. For example,\n(1) Within the self-attention module, if the same permutation order is applied on the axis 1 of query embedding matrix $\\mathbf{Q}$ and the axis 0 of key embedding matrix $\\mathbf{K}^\\top$, the final result of matrix multiplication of $\\mathbf{Q}\\mathbf{K}^\\top$ would stay the same.\nFig. 8. Illustration of same permutation on $\\mathbf{Q}$ (axis 1) and $\\mathbf{K}^\\top$ (axis 0) to keep the results of a self-attention module unchanged. (2) Within the FFN layer that contains two MLP layers and one ReLU non-linear layer, we can permute the first linear weight matrix $\\mathbf{W}_1$ along the axis 1 and the second linear weight matrix $\\mathbf{W}_2$ along the axis 0 in the same order.\nFig. 9. Illustration of the same permutation on $\\mathbf{W}_1$ (axis 1) and $\\mathbf{W}_2$ (axis 0) to keep the FFN layer's output unchanged. For simplicity, the bias terms are skipped but the same permutation should be applied on them too. To enforce N:M structured sparsity, let\u0026rsquo;s split the columns of one matrix into multiple slides of $M$ columns (named \u0026ldquo;stripe\u0026rdquo;) and we can easily observe that both the order of columns within each stripe and the order of stripes have no effect on the N:M sparsity restriction.\nPool \u0026amp; Yu (2021) proposed an iterative greedy algorithm to find optimal permutation that maximizes the weight magnitude for N:M sparsity. All pairs of channels are speculatively swapped and only the swap that leads to the greatest increase in magnitude is adopted, generating a new permutation and concluding a single iteration. Greedy algorithm may only find local minima, so they introduced two techniques to escape local minima:\n Bounded regressions: In practice two random channels are swapped, up to a fixed number of times. The solution search is limited to a depth of only one channel swap to keep the search space broad and shallow. Narrow, deep search: Choose multiple stripes and optimize them at the same time. Fig. 10. Algorithm of finding the best permutation for N:M sparsity greedily and iteratively. (Image source: Pool \u0026 Yu 2021) The network can achieve better performance if it was permuted before pruning, compared to pruning the network in its default channel order.\nTo train a model with N:M sparsity from scratch, Zhou \u0026amp; Ma, et al. (2021) extended STE (Straight-Through Estimator; Bengio et al. 2013), which is commonly used for back-propagation update in model quantization, to work for magnitude pruning and sparse parameter update.\nSTE computes the gradients of dense parameters wrt the pruned network $\\widetilde{W}$, $\\partial \\mathcal{L}/\\partial \\widetilde{W}$, and applies that to the dense network $W$ as an approximation:\n$$ W_{t+1} \\gets W_t - \\gamma \\frac{\\partial\\mathcal{L}}{\\partial\\widetilde{W}} $$\nThe extended version, SR-STE (Sparse-refined STE), updates the dense weights $W$ by:\n$$ W_{t+1} \\gets W_t - \\gamma \\frac{\\partial\\mathcal{L}}{\\partial\\widetilde{W}} + \\lambda_W (\\bar{\\mathcal{E}} \\odot W_t) $$ where $\\bar{\\mathcal{E}}$ is the mask matrix for $\\widetilde{W}$ and $\\odot$ is element-wise multiplication. SR-STE is proposed to prevent large change in the binary mask by (1) restricting the values of weights pruned in $\\widetilde{W}_t$, and (2) promoting the non-pruned weights in $\\widetilde{W}_t$.\nFig. 11. Comparison of STE and SR-STE. $\\odot$ is element-wise product; $\\otimes$ is matrix multiplication. (Image source: Zhou \u0026 Ma, et al. 2021) Different from STE or SR-STE, the Top-KAST (Jayakumar et al. 2021) method can preserve constant sparsity throughout training in both the forward and backward-passes but does not require forward passes with dense parameters or dense gradients.\nAt one training step $t$, Top-KAST processes as follows:\n Sparse forward pass: Select a subset of parameters $A^t \\subset \\Theta$, containing top-$K$ parameters by magnitude by each layer, restricted to top $D$-proportion of weights. The parameterization $\\alpha^t$ at time $t$ has parameters zeroed out if it is not in $A^t$ (active weights). $$ \\alpha^t_i = \\begin{cases} \\theta^t_i \u0026 \\text{ if } i \\in A^t = \\{i \\mid \\theta^t_i \\in \\text{TopK}(\\theta^t, D) \\}\\\\ 0 \u0026 \\text{ otherwise} \\end{cases} $$ where $\\text{TopK}(\\theta, x)$ selected top $x$ proportion of weights from $\\theta$ based on magnitude.\nSparse backward pass: Then apply gradients to a larger parameter subset $B \\subset \\Theta$ where $B$ contains $(D+M)$-proportion of weights and $A \\subset B$. Updating a larger proportion of weights enables more effective exploration of different pruning masks, making it more likely to cause permutations in the top $D$-proportion active weights. $$ \\Delta_{\\theta^t_i} = \\begin{cases} -\\eta \\nabla_{\\alpha_t} \\mathcal{L}(y, x, \\alpha^t)_i \u0026 \\text{ if } i\\in B^t = \\{i \\mid \\theta^t_i \\in \\text{TopK}(\\theta^t, D+M) \\} \\\\ 0 \u0026 \\text{ otherwise } \\end{cases} $$ Training is split into two stages and the additional coordinates in the set $B \\setminus A$ controls how much exploration is brought in. The amount of exploration is expected to diminish gradually through the training process and the mask eventually stabilizes.\nFig. 12. The pruning mask of Top-KAST stabilizes in time. (Image source: Jayakumar et al. 2021) To prevent rich-get-richer phenomenon, Top-KAST penalizes the magnitude of active weights via a L2 regularization loss to encourage more exploration of new items. Parameters in $B \\setminus A$ are penalized more than $A$ for a higher selection bar during updates to stabilize the mask.\n $$ L_\\text{penalty}(\\alpha^t_i) = \\begin{cases} \\vert \\theta^t_i\\vert \u0026 \\text{ if } i \\in A^t \\\\ \\vert \\theta^t_i\\vert / D \u0026 \\text{ if } i \\in B^t \\setminus A^t \\\\ 0 \u0026 \\text{ otherwise} \\end{cases} $$ Sparsified Transformer Scaling Transformer (Jaszczur et al. 2021) sparsifies both self-attention and FFN layers in transformer architecture, achieving 37x speedup for single-example inference.\nFig. 13. The speed of decoding a single token (unbatched inference) by a transformer model when sparsification is applied on different layers. (Image source: Jaszczur et al. 2021) Sparse FFN layer: Each FFN layer contains 2 MLP and one ReLU in-between. Because ReLU will introduce a lot of zeros, they implement a fixed structure on activations to enforce only 1 non-zero value in one block of $N$ elements. The sparsity pattern is dynamic, different for each token.\n $$ \\begin{aligned} Y_\\text{sparse} \u0026= \\max(0, xW_1 + b_1) \\odot \\text{Controller}(x) \\\\ \\text{SparseFFN}(x) \u0026= Y_\\text{sparse} W_2 + b_2 \\\\ \\text{Controller}(x) \u0026= \\arg\\max(\\text{Reshape}(x C_1 C_2, (-1, N))) \\end{aligned} $$ where each activation in $Y_\\text{sparse}$ corresponds to one column in $W_1$ and one row in $W_2$. The controller is implemented as a low-rank bottleneck dense layer, $C_1 \\in \\mathbb{R}^{d_\\text{model} \\times d_\\text{lowrank}}, C_2 \\in \\mathbb{R}^{d_\\text{lowrank} \\times d_\\text{ff}}$ and $d_\\text{lowrank} = d_\\text{model} / N$. It uses $\\arg\\max$ for inference to select which columns should be non-zero and Gumbel-softmax trick (Jang et al. 2016) during training. Because we can compute $\\text{Controller}(x)$ before loading FFN weight matrices, we know which columns will be zeroed out and thus choose not to load them into memory for inference speedup.\nFig. 14. (a) Sparse FFN layer; columns in red are not loaded in memory for faster inference. (b) Sparse FFN controller for 1:4 sparsity. (Image source: Jaszczur et al. 2021) *Lilian's side note*: Fig (a) in the illustration from the paper is actually $Y_\\text{sparse} = \\max\\big(0, (xW_1 + b_1) \\odot \\text{Controller}(x)\\big)$, but it doesn't change the results. Sparse QKV (attention) layer: In the attention layer, the dimensionality $d_\\text{model}$ is divided into $S$ modules, each of size $M=d_\\text{model} /S$. To make sure each subdivision can access any part of the embedding, Scaling Transformer introduces a multiplicative layer (i.e., a multiplication layer multiplies inputs from multiple neural network layers element-wise) which can represent arbitrary permutation but contains fewer parameters than a dense layer.\nGiven an input vector $x \\in \\mathbb{R}^{d_\\text{model}}$, the multiplicative layer outputs $y \\in \\mathbb{R}^{S \\times M}$:\n $$ y_{s,m} = \\sum_i x_i D_{i,s} E_{i,m} \\quad\\text{where }D \\in \\mathbb{R}^{d_\\text{model} \\times S}, D \\in \\mathbb{R}^{d_\\text{model} \\times M} $$ The output of the multiplicative layer is a tensor of size $\\in \\mathbb{R}^{\\text{batch size}\\times \\text{length} \\times S \\times M}$. It then gets processed by a two-dimensional convolutional layer, where $\\text{length}$ and $S$ are treated as the height and width of an image. Such a convolution layer further reduces the parameter count and computation time of attention layer.\nFig. 15. (a) A multiplicative layer is introduced to enable partitions to access any part of an embedding. (b) Combination of multiplicative dense layer and 2-D convolutional layer reduces the number of parameters and computation time of the attention layer. (Image source: Jaszczur et al. 2021) To better work with long sequences, Scaling Transformer is further equipped with LSH (locality-sensitive hashing) attention from Reformer (Kitaev, et al. 2020) and FFN block recurrence, resulting in Terraformer.\nMixture-of-Experts Mixture-of-experts (MoE) models depend on a collection of \u0026ldquo;expert\u0026rdquo; networks and each example only activates a subset of networks to get predictions. The idea originated back to the 1990s (Jacobs et al. 1991) and is strongly related to ensemble methods. For details on how to incorporate MoE module into transformer, please check my previous post on large model training techniques and a survey paper on MoE by Fedus et al. 2022.\nWith MoE architecture, only partial parameters are utilized at decoding time and therefore it saves inference cost. The capacity of each expert can be adjusted with a hyperparameter, capacity factor $C$, and the expert capacity is defined as:\n $$ \\text{Expert capacity} = \\text{round}(C \\cdot k \\cdot \\frac{\\text{total # tokens in one batch}}{\\text{# experts}}) $$ where top-$k$ experts are selected per token. Larger $C$ leads to higher expert capacity and improved performance but more expensive computationally. When $C\u0026gt;1$, a slack capacity is added; otherwise, when $C\u0026lt;1$, the routing network needs to ignore some tokens.\nRouting Strategy Improvement MoE layer has a routing network to assign a subset of experts for each input token. The routing strategy in vanilla MoE models is to route each token toward preferred experts differently as they come up in the natural order. If a token is routed to experts that have reached their capacity, the token would be marked \u0026ldquo;overflowed\u0026rdquo; and skipped.\nV-MoE (Vision MoE; Riquelme et al. 2021) adds MoE layers into ViT (Vision Transformer). It matches the performance of previous SoTA but only requires half of inference compute. V-MoE can be scaled up to 15B parameters. Their experiments used $k=2$, 32 experts and every-2 expert placement (meaning that MoEs are placed in every other layer).\nSince each expert has a limited capacity, some important and informative tokens may have to be discarded if they come up too late in the predefined sequence order (e.g. the order of words in a sentence, or the order of image patches). To avoid such a drawback in the vanilla routing scheme, V-MoE adopts BPR (Batch Priority Routing) to assign experts to tokens with a high priority score first. BPR computes a priority score (max or sum of top-$k$ router scores) per token before expert assignment and alters the order of tokens accordingly. This guarantees that the expert capacity buffer would be fulfilled with key tokens first.\nFig. 16. How image patches are discarded according to priority scores when $C Riquelme et al. 2021) BPR works much better than vanilla routing when $C\\leq 0.5$, where the model starts dropping a significant amount of tokens. It capacitates the model to be competitive with the dense network even at quite low capacities.\nWhen looking into how to interpret image class-expert association, they observed that early MoE layers are more general, while later MoE layers could be specialized for a few image classes.\nTask MoE (Task-level Mixture-of-Experts; Kudugunta et al. 2021 ) takes the task information into consideration and routes tokens at the task level instead of the word or token level for machine translation. They used MNMT (multilingual neural machine translation) as an example and group translation tasks based on the target language or language pairs.\nToken level routing is dynamic and the routing decision for each token is made disjointly. Hence, at inference time, the server needs to preload all the experts. In comparison, task level routing is static given a fixed task, so the inference server for one task only needs to preload $k$ experts (assuming top-$k$ routing). According to their experiments, Task MoE can achieve similar performance gain as token MoE compared to dense model baseline with 2.6x higher peak throughput and 1.6% of the decoder size.\nTask level MoE is essentially to categorize a distribution of tasks according to predefined heuristics and incorporate such human knowledge into the router. When such heuristics do not exist (e.g. consider a general sentence continuation task), it would not be straightforward how to utilize Task MoE.\nPR-MoE (Pyramid residual MoE; Rajbhandari et al. 2022) has each token pass one fixed MLP and one chosen expert. Due to the observation that MoE at later layers is more beneficial, PR-MoE adopts more exports at later layers. DeepSpeed library implements a flexible multi-expert, multi-data parallelism to enable training PR-MoE with different numbers of experts across layers.\nFig. 17. Illustration of PR-MoE architecture in comparison with a standard MoE. (Image source: Rajbhandari et al. 2022) Kernel Improvement Expert networks can be hosted on different devices. However, when the number of GPUs increases, the number of experts per GPU decreases and the communication between experts (\u0026ldquo;All-to-all\u0026rdquo;) grows to be more expensive. All-to-all communication between experts across a number of GPUs relies on P2P APIs of NCCL, which cannot saturate the bandwidth of high-speed links (e.g. NVLink, HDR InfiniBand) at a large scale, as individual chunk gets smaller with more nodes used. The existing all-to-all algorithm performs poorly at large scale with a small workload. There are a variety of kernel improvements to enable more efficient MoE computation, such as making all-to-all communication cheaper/faster.\nBoth the DeepSpeed library (Rajbhandari et al. 2022) and TUTEL (Hwang et al. 2022) implemented a tree-based hierarchical all-to-all algorithm, which runs an intra-node all-to-all followed by an inter-node all-to-all. It reduces the communication hops from $O(G)$ to $O(G_\\text{node} + G / G_\\text{node})$, where $G$ is the total number of GPU nodes and $G_\\text{node}$ is the number of GPU cores per node. Although the communication volume is doubled in such implementation, it enables better scaling with small batches at large scale as the bottleneck is on latency instead of communication bandwidth when the batch size is small.\nDynaMoE (Kossmann et al. 2022) uses dynamic recompilation to adapt the computational resources to dynamic workloads among experts. The RECOMPILE mechanism compiles the computation graph from scratch and only reallocates resources when needed. It measures how many samples are assigned to each expert and adjusts their capacity factors $C$ dynamically, in order to reduce the memory and computation requirements at run time. Based on the observation that sample-expert assignments converge early in training, sample assignment caching is introduced after convergence and then RECOMPILE is used to eliminate the dependency between the gating network and experts.\nArchitectural Optimization The survey paper on Efficient Transformers (Tay et al. 2020) reviewed a collection of new transformer architectures with improvement for better computational and memory efficiency. Strongly recommend a read. You can also check out my post \u0026ldquo;The Transformer Family Version 2.0\u0026rdquo; for introduction to a diverse set of transformer archiecture improvements in depth, including changes to make the model cheaper to run.\nFig. 18. Categorization of efficient transformer models.(Image source: Tay et al. 2020) Since the self-attention mechanism has quadratic time and memory complexity and that is the main bottleneck for better transformer decoding efficiency, all the efficient transformer models have applied some form of sparsity to the otherwise dense attention layer. Here only lists a high-level overview, several derived from Tay et al. 2020.\nSparse Attention Patterns Fixed Patterns limit the field of view for the attention matrix, using predefined, fixed patterns.\n Chunk input sequences into fixed blocks, such as Blockwise Attention; Image Transformer uses local attention; Sparse Transformer uses strided attention patterns. Combined Patterns learn to sort/cluster the input tokens - enabling a more optimal global view of the sequence while maintaining the efficiency benefits of fixed patterns.\n Sparse Transformer combines strided and local attention; Given a high dimensional input tensor, instead of applying attention to the flattened version of the input, Axial Transformer applies multiple attentions, each along a single axis of the input tensor. ETC, Longformer and Big Bird combines local and global context, as well as strided or random attention. Learnable Patterns identify the optimal attention pattern via learning.\n Reformer clusters tokens into clusters based on hash-based similarity (LSH); Routing Transformer runs $k$-means clustering on tokens; Sinkhorn Sorting Network learns to sort blocks of input sequence. Recurrence Recurrence mechanism connects multiple blocks/segments via recurrence.\n Transformer-XL makes use of longer context by reusing hidden states between segments. Universal Transformer combines self-attention with the recurrent mechanism in RNN. Compressive Transformer is an extension of Transformer-XL with additional memory, containing a set of memory slots for past activiations and compressive memory slots for compressed activations. Whenever the model accepts a new input segment, the oldest activations in the primary memory are moved to the compressed memory where a compression function is applied. Memory Saving Designs Memory saving designs refer to changes of the architecture to use less memory.\n Linformer projects the length dimension of keys and values to a lower-dimensional representation ($N \\to k$) and thus the memory complexity is reduced from $N \\times N$ to $N \\times k$. Shazeer (2019) proposed multi-query attention which has the keys and values shared across different attention \u0026ldquo;heads\u0026rdquo;, greatly reducing the size of these tensors and the memory cost. Random feature attention and Performer use kernel methods to achieve a cheaper mathematical format of the self-attention mechanism. Adaptive Attention Adaptive attention enables the model to learn the optimal attention span or decide on when to do early exiting for different input tokens.\n Adaptive Attention Span trains the model to learn the optimal attention span per token per head via a soft mask between the token and other keys. Universal Transformer incorporates recurrent mechanism and uses ACT (Adaptive computation time) to dynamically decide the number of recurrent steps. Depth-Adaptive Transformer and CALM learns when to early exit the computation layers per token using some confidence measures to achieve good performance-efficiency tradeoffs. Citation Cited as:\n Weng, Lilian. (Jan 2023). Large Transformer Model Inference Optimization. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2023-01-10-inference-optimization/.\n Or\n@article{weng2023inference, title = \u0026quot;Large Transformer Model Inference Optimization\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2023\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2023-01-10-inference-optimization/\u0026quot; } References [1] Bondarenko et al. \u0026ldquo;Understanding and overcoming the challenges of efficient transformer quantization\u0026rdquo; ACL 2021.\n[2] Dettmers et al. \u0026ldquo;LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale\u0026rdquo; NeuriPS 2022\n[3] Zadeh et al. \u0026ldquo;Gobo: Quantizing attention-based NLP models for low latency and energy efficient inference.\u0026quot; MICRO 2020\n[4] Shen, Dong \u0026amp; Ye, et al. \u0026ldquo;Q-BERT: Hessian based ultra low precision quantization of BERT\u0026rdquo; AAAI 2020.\n[5] Yao et al. \u0026ldquo;ZeroQuant: Efficient and affordable post-training quantization for large-scale transformers\u0026rdquo; arXiv preprint arXiv:2206.01861 (2022).\n[6] Frantar et al. \u0026ldquo;GPTQ: Accurate Quantization for Generative Pre-trained Transformers\u0026rdquo; arXiv preprint arXiv:2210.17323 (2022).\n[7] Xiao \u0026amp; Lin \u0026ldquo;SmoothQuant: Accelerated sparse neural training: A provable and efficient method to find N:M transposable masks.\u0026quot; arXiv preprint arXiv:2211.10438 (2022). | code\n[8] Pool \u0026amp; Yu. \u0026ldquo;Channel Permutations for N:M Sparsity.\u0026quot; NeuriPS 2021. | code\n[9] Zhou \u0026amp; Ma, et al. \u0026ldquo;Learning N:M fine-grained structured sparse neural networks from scratch.\u0026quot; arXiv preprint arXiv:2102.04010 (2021).\n[10] Jayakumar et al. \u0026ldquo;Top-KAST: Top-K Always Sparse Training.\u0026quot; NeuriPS 2020.\n[11] Nvidia. \u0026ldquo;Nvidia A100 tensor core GPU architecture.\u0026quot; 2020.\n[12] Gale, Elsen \u0026amp; Hooker \u0026ldquo;The State of Sparsity in Deep Neural Networks.\u0026quot; arXiv preprint arXiv:1902.09574 (2019).\n[13] Zhu \u0026amp; Gupta. \u0026ldquo;To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.\u0026quot; arXiv preprint arXiv:1710.01878 (2017).\n[14] Renda et al. \u0026ldquo;Comparing rewinding and fine-tuning in neural network pruning.\u0026quot; arXiv preprint arXiv:2003.02389 (2020).\n[15] Zhou \u0026amp; Ma, et al. \u0026ldquo;Learning N:M fine-grained structured sparse neural networks from scratch.\u0026quot; arXiv preprint arXiv:2102.04010 (2021).\n[16] Pool \u0026amp; Yu. \u0026ldquo;Channel Permutations for N:M Sparsity.\u0026quot; NeuriPS 2021. | code\n[17] Jaszczur et al. \u0026ldquo;Sparse is Enough in Scaling Transformers.\u0026quot; NeuriPS 2021.\n[18] Mishra et al. \u0026ldquo;An Survey of Neural Network Compression.\u0026quot; arXiv preprint arXiv:1710.09282 (2017).\n[19] Fedus et al. \u0026ldquo;A Review of Sparse Expert Models in Deep Learning.\u0026quot; arXiv preprint arXiv:2209.01667 (2022)..\n[20] Riquelme et al. \u0026ldquo;Scaling vision with sparse mixture of experts.\u0026quot; NeuriPS 2021.\n[21] Kudugunta et al. \u0026ldquo;Beyond Distillation: Task-level Mixture-of-Experts for Efficient Inference.\u0026quot; arXiv preprint arXiv:2110.03742 (2021).\n[22] Rajbhandari et al. \u0026ldquo;DeepSpeed-MoE: Advancing mixture-of-experts inference and training to power next-generation ai scale.\u0026quot; arXiv preprint arXiv:2201.05596 (2022).\n[23] Kossmann et al. \u0026ldquo;Optimizing mixture of experts using dynamic recompilations.\u0026quot; arXiv preprint arXiv:2205.01848 (2022).\n[24] Hwang et al. \u0026ldquo;Tutel: Adaptive mixture-of-experts at scale.\u0026quot; arXiv preprint arXiv:2206.03382 (2022). | code\n[25] Noam Shazeer. \u0026ldquo;Fast Transformer Decoding: One Write-Head is All You Need.\u0026quot; arXiv preprint arXiv:1911.02150 (2019).\n[26] Tay et al. \u0026ldquo;Efficient Transformers: A Survey.\u0026quot; ACM Computing Surveys 55.6 (2022): 1-28.\n[27] Pope et al. \u0026ldquo;Efficiently Scaling Transformer Inference.\u0026quot; arXiv preprint arXiv:2211.05102 (2022).\n[28] Frankle \u0026amp; Carbin. \u0026ldquo;The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks\u0026rdquo; ICLR 2019.\n[29] Elabyad et al. \u0026ldquo;Depth-Adaptive Transformer\u0026rdquo; ICLR 2020.\n[30] Schuster et al. \u0026ldquo;Confident Adaptive Language Modeling\u0026rdquo; arXiv preprint arXiv:2207.07061 (2022).\n[31] Gou et al. \u0026ldquo;https://arxiv.org/abs/2006.05525\u0026rdquo; arXiv preprint arXiv:2006.05525 (2020).\n[32] Hinton et al. \u0026ldquo;Distilling the Knowledge in a Neural Network\u0026rdquo; NIPS 2014.\n[33] Sanh et al. \u0026ldquo;DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter\u0026rdquo; Workshop on Energy Efficient Machine Learning and Cognitive Computing @ NeuriPS 2019.\n","permalink":"https://lilianweng.github.io/posts/2023-01-10-inference-optimization/","summary":"[Updated on 2023-01-24: add a small section on Distillation.]\nLarge transformer models are mainstream nowadays, creating SoTA results for a variety of tasks. They are powerful but very expensive to train and use. The extremely high inference cost, in both time and memory, is a big bottleneck for adopting a powerful transformer for solving real-world tasks at scale.\nWhy is it hard to run inference for large transformer models? Besides the increasing size of SoTA models, there are two main factors contributing to the inference challenge (Pope et al.","title":"Large Transformer Model Inference Optimization"},{"content":"Neural networks are well known to be over-parameterized and can often easily fit data with near-zero training loss with decent generalization performance on test dataset. Although all these parameters are initialized at random, the optimization process can consistently lead to similarly good outcomes. And this is true even when the number of model parameters exceeds the number of training data points.\nNeural tangent kernel (NTK) (Jacot et al. 2018) is a kernel to explain the evolution of neural networks during training via gradient descent. It leads to great insights into why neural networks with enough width can consistently converge to a global minimum when trained to minimize an empirical loss. In the post, we will do a deep dive into the motivation and definition of NTK, as well as the proof of a deterministic convergence at different initializations of neural networks with infinite width by characterizing NTK in such a setting.\n 🤓 Different from my previous posts, this one mainly focuses on a small number of core papers, less on the breadth of the literature review in the field. There are many interesting works after NTK, with modification or expansion of the theory for understanding the learning dynamics of NNs, but they won\u0026rsquo;t be covered here. The goal is to show all the math behind NTK in a clear and easy-to-follow format, so the post is quite math-intensive. If you notice any mistakes, please let me know and I will be happy to correct them quickly. Thanks in advance!\n Basics This section contains reviews of several very basic concepts which are core to understanding of neural tangent kernel. Feel free to skip.\nVector-to-vector Derivative Given an input vector $\\mathbf{x} \\in \\mathbb{R}^n$ (as a column vector) and a function $f: \\mathbb{R}^n \\to \\mathbb{R}^m$, the derivative of $f$ with respective to $\\mathbf{x}$ is a $m\\times n$ matrix, also known as Jacobian matrix:\n $$ J = \\frac{\\partial f}{\\partial \\mathbf{x}} = \\begin{bmatrix} \\frac{\\partial f_1}{\\partial x_1} \u0026 \\dots \u0026\\frac{\\partial f_1}{\\partial x_n} \\\\ \\vdots \u0026 \u0026 \\\\ \\frac{\\partial f_m}{\\partial x_1} \u0026 \\dots \u0026\\frac{\\partial f_m}{\\partial x_n} \\\\ \\end{bmatrix} \\in \\mathbb{R}^{m \\times n} $$ Throughout the post, I use integer subscript(s) to refer to a single entry out of a vector or matrix value; i.e. $x_i$ indicates the $i$-th value in the vector $\\mathbf{x}$ and $f_i(.)$ is the $i$-th entry in the output of the function.\nThe gradient of a vector with respect to a vector is defined as $\\nabla_\\mathbf{x} f = J^\\top \\in \\mathbb{R}^{n \\times m}$ and this formation is also valid when $m=1$ (i.e., scalar output).\nDifferential Equations Differential equations describe the relationship between one or multiple functions and their derivatives. There are two main types of differential equations.\n (1) ODE (Ordinary differential equation) contains only an unknown function of one random variable. ODEs are the main form of differential equations used in this post. A general form of ODE looks like $(x, y, \\frac{dy}{dx}, \\dots, \\frac{d^ny}{dx^n}) = 0$. (2) PDE (Partial differential equation) contains unknown multivariable functions and their partial derivatives. Let\u0026rsquo;s review the simplest case of differential equations and its solution. Separation of variables (Fourier method) can be used when all the terms containing one variable can be moved to one side, while the other terms are all moved to the other side. For example,\n $$ \\begin{aligned} \\text{Given }a\\text{ is a constant scalar:}\\quad\\frac{dy}{dx} \u0026= ay \\\\ \\text{Move same variables to the same side:}\\quad\\frac{dy}{y} \u0026= adx \\\\ \\text{Put integral on both sides:}\\quad\\int \\frac{dy}{y} \u0026= \\int adx \\\\ \\ln (y) \u0026= ax + C' \\\\ \\text{Finally}\\quad y \u0026= e^{ax + C'} = C e^{ax} \\end{aligned} $$ Central Limit Theorem Given a collection of i.i.d. random variables, $x_1, \\dots, x_N$ with mean $\\mu$ and variance $\\sigma^2$, the Central Limit Theorem (CTL) states that the expectation would be Gaussian distributed when $N$ becomes really large.\n $$ \\bar{x} = \\frac{1}{N}\\sum_{i=1}^N x_i \\sim \\mathcal{N}(\\mu, \\frac{\\sigma^2}{n})\\quad\\text{when }N \\to \\infty $$ CTL can also apply to multidimensional vectors, and then instead of a single scale $\\sigma^2$ we need to compute the covariance matrix of random variable $\\Sigma$.\nTaylor Expansion The Taylor expansion is to express a function as an infinite sum of components, each represented in terms of this function\u0026rsquo;s derivatives. The Tayler expansion of a function $f(x)$ at $x=a$ can be written as: $$ f(x) = f(a) + \\sum_{k=1}^\\infty \\frac{1}{k!} (x - a)^k\\nabla^k_xf(x)\\vert_{x=a} $$ where $\\nabla^k$ denotes the $k$-th derivative.\nThe first-order Taylor expansion is often used as a linear approximation of the function value:\n $$ f(x) \\approx f(a) + (x - a)\\nabla_x f(x)\\vert_{x=a} $$ Kernel \u0026amp; Kernel Methods A kernel is essentially a similarity function between two data points, $K: \\mathcal{X} \\times \\mathcal{X} \\to \\mathbb{R}$. It describes how sensitive the prediction for one data sample is to the prediction for the other; or in other words, how similar two data points are. The kernel should be symmetric, $K(x, x') = K(x', x)$.\nDepending on the problem structure, some kernels can be decomposed into two feature maps, one corresponding to one data point, and the kernel value is an inner product of these two features: $K(x, x') = \\langle \\varphi(x), \\varphi(x') \\rangle$.\nKernel methods are a type of non-parametric, instance-based machine learning algorithms. Assuming we have known all the labels of training samples $\\{x^{(i)}, y^{(i)}\\}$, the label for a new input $x$ is predicted by a weighted sum $\\sum_{i} K(x^{(i)}, x)y^{(i)}$.\nGaussian Processes Gaussian process (GP) is a non-parametric method by modeling a multivariate Gaussian probability distribution over a collection of random variables. GP assumes a prior over functions and then updates the posterior over functions based on what data points are observed.\nGiven a collection of data points $\\{x^{(1)}, \\dots, x^{(N)}\\}$, GP assumes that they follow a jointly multivariate Gaussian distribution, defined by a mean $\\mu(x)$ and a covariance matrix $\\Sigma(x)$. Each entry at location $(i,j)$ in the covariance matrix $\\Sigma(x)$ is defined by a kernel $\\Sigma_{i,j} = K(x^{(i)}, x^{(j)})$, also known as a covariance function. The core idea is \u0026ndash; if two data points are deemed similar by the kernel, the function outputs should be close, too. Making predictions with GP for unknown data points is equivalent to drawing samples from this distribution, via a conditional distribution of unknown data points given observed ones.\nCheck this post for a high-quality and highly visualization tutorial on what Gaussian Processes are.\nNotation Let us consider a fully-connected neural networks with parameter $\\theta$, $f(.;\\theta): \\mathbb{R}^{n_0} \\to \\mathbb{R}^{n_L}$. Layers are indexed from 0 (input) to $L$ (output), each containing $n_0, \\dots, n_L$ neurons, including the input of size $n_0$ and the output of size $n_L$. There are $P = \\sum_{l=0}^{L-1} (n_l + 1) n_{l+1}$ parameters in total and thus we have $\\theta \\in \\mathbb{R}^P$.\nThe training dataset contains $N$ data points, $\\mathcal{D}=\\{\\mathbf{x}^{(i)}, y^{(i)}\\}_{i=1}^N$. All the inputs are denoted as $\\mathcal{X}=\\{\\mathbf{x}^{(i)}\\}_{i=1}^N$ and all the labels are denoted as $\\mathcal{Y}=\\{y^{(i)}\\}_{i=1}^N$.\nNow let\u0026rsquo;s look into the forward pass computation in every layer in detail. For $l=0, \\dots, L-1$, each layer $l$ defines an affine transformation $A^{(l)}$ with a weight matrix $\\mathbf{w}^{(l)} \\in \\mathbb{R}^{n_{l} \\times n_{l+1}}$ and a bias term $\\mathbf{b}^{(l)} \\in \\mathbb{R}^{n_{l+1}}$, as well as a pointwise nonlinearity function $\\sigma(.)$ which is Lipschitz continuous.\n $$ \\begin{aligned} A^{(0)} \u0026= \\mathbf{x} \\\\ \\tilde{A}^{(l+1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}} {\\mathbf{w}^{(l)}}^\\top A^{(l)} + \\beta\\mathbf{b}^{(l)}\\quad\\in\\mathbb{R}^{n_{l+1}} \u0026 \\text{; pre-activations}\\\\ A^{(l+1)}(\\mathbf{x}) \u0026= \\sigma(\\tilde{A}^{(l+1)}(\\mathbf{x}))\\quad\\in\\mathbb{R}^{n_{l+1}} \u0026 \\text{; post-activations} \\end{aligned} $$ Note that the NTK parameterization applies a rescale weight $1/\\sqrt{n_l}$ on the transformation to avoid divergence with infinite-width networks. The constant scalar $\\beta \\geq 0$ controls how much effort the bias terms have.\nAll the network parameters are initialized as an i.i.d Gaussian $\\mathcal{N}(0, 1)$ in the following analysis.\nNeural Tangent Kernel Neural tangent kernel (NTK) (Jacot et al. 2018) is an important concept for understanding neural network training via gradient descent. At its core, it explains how updating the model parameters on one data sample affects the predictions for other samples.\nLet\u0026rsquo;s start with the intuition behind NTK, step by step.\nThe empirical loss function $\\mathcal{L}: \\mathbb{R}^P \\to \\mathbb{R}_+$ to minimize during training is defined as follows, using a per-sample cost function $\\ell: \\mathbb{R}^{n_0} \\times \\mathbb{R}^{n_L} \\to \\mathbb{R}_+$:\n $$ \\mathcal{L}(\\theta) =\\frac{1}{N} \\sum_{i=1}^N \\ell(f(\\mathbf{x}^{(i)}; \\theta), y^{(i)}) $$ and according to the chain rule. the gradient of the loss is:\n $$ \\nabla_\\theta \\mathcal{L}(\\theta)= \\frac{1}{N} \\sum_{i=1}^N \\underbrace{\\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta)}_{\\text{size }P \\times n_L} \\underbrace{\\nabla_f \\ell(f, y^{(i)})}_{\\text{size } n_L \\times 1} $$ When tracking how the network parameter $\\theta$ evolves in time, each gradient descent update introduces a small incremental change of an infinitesimal step size. Because of the update step is small enough, it can be approximately viewed as a derivative on the time dimension:\n $$ \\frac{d\\theta}{d t} = - \\nabla_\\theta\\mathcal{L}(\\theta) = -\\frac{1}{N} \\sum_{i=1}^N \\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta) \\nabla_f \\ell(f, y^{(i)}) $$ Again, by the chain rule, the network output evolves according to the derivative:\n $$ \\frac{df(\\mathbf{x};\\theta)}{dt} = \\frac{df(\\mathbf{x};\\theta)}{d\\theta}\\frac{d\\theta}{dt} = -\\frac{1}{N} \\sum_{i=1}^N \\color{blue}{\\underbrace{\\nabla_\\theta f(\\mathbf{x};\\theta)^\\top \\nabla_\\theta f(\\mathbf{x}^{(i)}; \\theta)}_\\text{Neural tangent kernel}} \\color{black}{\\nabla_f \\ell(f, y^{(i)})} $$ Here we find the Neural Tangent Kernel (NTK), as defined in the blue part in the above formula, $K: \\mathbb{R}^{n_0}\\times\\mathbb{R}^{n_0} \\to \\mathbb{R}^{n_L \\times n_L}$ :\n $$ K(\\mathbf{x}, \\mathbf{x}'; \\theta) = \\nabla_\\theta f(\\mathbf{x};\\theta)^\\top \\nabla_\\theta f(\\mathbf{x}'; \\theta) $$ where each entry in the output matrix at location $(m, n), 1 \\leq m, n \\leq n_L$ is:\n $$ K_{m,n}(\\mathbf{x}, \\mathbf{x}'; \\theta) = \\sum_{p=1}^P \\frac{\\partial f_m(\\mathbf{x};\\theta)}{\\partial \\theta_p} \\frac{\\partial f_n(\\mathbf{x}';\\theta)}{\\partial \\theta_p} $$ The \u0026ldquo;feature map\u0026rdquo; form of one input $\\mathbf{x}$ is $\\varphi(\\mathbf{x}) = \\nabla_\\theta f(\\mathbf{x};\\theta)$.\nInfinite Width Networks To understand why the effect of one gradient descent is so similar for different initializations of network parameters, several pioneering theoretical work starts with infinite width networks. We will look into detailed proof using NTK of how it guarantees that infinite width networks can converge to a global minimum when trained to minimize an empirical loss.\nConnection with Gaussian Processes Deep neural networks have deep connection with gaussian processes (Neal 1994). The output functions of a $L$-layer network, $f_i(\\mathbf{x}; \\theta)$ for $i=1, \\dots, n_L$ , are i.i.d. centered Gaussian process of covariance $\\Sigma^{(L)}$, defined recursively as:\n $$ \\begin{aligned} \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\frac{1}{n_0}\\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 \\\\ \\lambda^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\begin{bmatrix} \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) \u0026 \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}') \\\\ \\Sigma^{(l)}(\\mathbf{x}', \\mathbf{x}) \u0026 \\Sigma^{(l)}(\\mathbf{x}', \\mathbf{x}') \\end{bmatrix} \\\\ \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}_{f \\sim \\mathcal{N}(0, \\lambda^{(l)})}[\\sigma(f(\\mathbf{x})) \\sigma(f(\\mathbf{x}'))] + \\beta^2 \\end{aligned} $$ Lee \u0026amp; Bahri et al. (2018) showed a proof by mathematical induction:\n(1) Let\u0026rsquo;s start with $L=1$, when there is no nonlinearity function and the input is only processed by a simple affine transformation:\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) = \\tilde{A}^{(1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_0}}{\\mathbf{w}^{(0)}}^\\top\\mathbf{x} + \\beta\\mathbf{b}^{(0)} \\\\ \\text{where }\\tilde{A}_m^{(1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{im}x_i + \\beta b^{(0)}_m\\quad \\text{for }1 \\leq m \\leq n_1 \\end{aligned} $$ Since the weights and biases are initialized i.i.d., all the output dimensions of this network ${\\tilde{A}^{(1)}_1(\\mathbf{x}), \\dots, \\tilde{A}^{(1)}_{n_1}(\\mathbf{x})}$ are also i.i.d. Given different inputs, the $m$-th network outputs $\\tilde{A}^{(1)}_m(.)$ have a joint multivariate Gaussian distribution, equivalent to a Gaussian process with covariance function (We know that mean $\\mu_w=\\mu_b=0$ and variance $\\sigma^2_w = \\sigma^2_b=1$)\n $$ \\begin{aligned} \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}[\\tilde{A}_m^{(1)}(\\mathbf{x})\\tilde{A}_m^{(1)}(\\mathbf{x}')] \\\\ \u0026= \\mathbb{E}\\Big[\\Big( \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{i,m}x_i + \\beta b^{(0)}_m \\Big) \\Big( \\frac{1}{\\sqrt{n_0}}\\sum_{i=1}^{n_0} w^{(0)}_{i,m}x'_i + \\beta b^{(0)}_m \\Big)\\Big] \\\\ \u0026= \\frac{1}{n_0} \\sigma^2_w \\sum_{i=1}^{n_0} \\sum_{j=1}^{n_0} x_i{x'}_j + \\frac{\\beta \\mu_b}{\\sqrt{n_0}} \\sum_{i=1}^{n_0} w_{im}(x_i + x'_i) + \\sigma^2_b \\beta^2 \\\\ \u0026= \\frac{1}{n_0}\\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 \\end{aligned} $$ (2) Using induction, we first assume the proposition is true for $L=l$, a $l$-layer network, and thus $\\tilde{A}^{(l)}_m(.)$ is a Gaussian process with covariance $\\Sigma^{(l)}$ and $\\{\\tilde{A}^{(l)}_i\\}_{i=1}^{n_l}$ are i.i.d.\nThen we need to prove the proposition is also true for $L=l+1$. We compute the outputs by:\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) = \\tilde{A}^{(l+1)}(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}}{\\mathbf{w}^{(l)}}^\\top \\sigma(\\tilde{A}^{(l)}(\\mathbf{x})) + \\beta\\mathbf{b}^{(l)} \\\\ \\text{where }\\tilde{A}^{(l+1)}_m(\\mathbf{x}) \u0026= \\frac{1}{\\sqrt{n_l}}\\sum_{i=1}^{n_l} w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x})) + \\beta b^{(l)}_m \\quad \\text{for }1 \\leq m \\leq n_{l+1} \\end{aligned} $$ We can infer that the expectation of the sum of contributions of the previous hidden layers is zero:\n $$ \\begin{aligned} \\mathbb{E}[w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] \u0026= \\mathbb{E}[w^{(l)}_{im}]\\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] = \\mu_w \\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))] = 0 \\\\ \\mathbb{E}[\\big(w^{(l)}_{im}\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))\\big)^2] \u0026= \\mathbb{E}[{w^{(l)}_{im}}^2]\\mathbb{E}[\\sigma(\\tilde{A}^{(l)}_i(\\mathbf{x}))^2] = \\sigma_w^2 \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) = \\Sigma^{(l)}(\\mathbf{x}, \\mathbf{x}) \\end{aligned} $$ Since $\\{\\tilde{A}^{(l)}_i(\\mathbf{x})\\}_{i=1}^{n_l}$ are i.i.d., according to central limit theorem, when the hidden layer gets infinitely wide $n_l \\to \\infty$, $\\tilde{A}^{(l+1)}_m(\\mathbf{x})$ is Gaussian distributed with variance $\\beta^2 + \\text{Var}(\\tilde{A}_i^{(l)}(\\mathbf{x}))$. Note that ${\\tilde{A}^{(l+1)}_1(\\mathbf{x}), \\dots, \\tilde{A}^{(l+1)}_{n_{l+1}}(\\mathbf{x})}$ are still i.i.d.\n$\\tilde{A}^{(l+1)}_m(.)$ is equivalent to a Gaussian process with covariance function:\n $$ \\begin{aligned} \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \u0026= \\mathbb{E}[\\tilde{A}^{(l+1)}_m(\\mathbf{x})\\tilde{A}^{(l+1)}_m(\\mathbf{x}')] \\\\ \u0026= \\frac{1}{n_l} \\sigma\\big(\\tilde{A}^{(l)}_i(\\mathbf{x})\\big)^\\top \\sigma\\big(\\tilde{A}^{(l)}_i(\\mathbf{x}')\\big) + \\beta^2 \\quad\\text{;similar to how we get }\\Sigma^{(1)} \\end{aligned} $$ When $n_l \\to \\infty$, according to central limit theorem,\n $$ \\Sigma^{(l+1)}(\\mathbf{x}, \\mathbf{x}') \\to \\mathbb{E}_{f \\sim \\mathcal{N}(0, \\Lambda^{(l)})}[\\sigma(f(\\mathbf{x}))^\\top \\sigma(f(\\mathbf{x}'))] + \\beta^2 $$ The form of Gaussian processes in the above process is referred to as the Neural Network Gaussian Process (NNGP) (Lee \u0026amp; Bahri et al. (2018)).\nDeterministic Neural Tangent Kernel Finally we are now prepared enough to look into the most critical proposition from the NTK paper:\nWhen $n_1, \\dots, n_L \\to \\infty$ (network with infinite width), the NTK converges to be:\n (1) deterministic at initialization, meaning that the kernel is irrelevant to the initialization values and only determined by the model architecture; and (2) stays constant during training. The proof depends on mathematical induction as well:\n(1) First of all, we always have $K^{(0)} = 0$. When $L=1$, we can get the representation of NTK directly. It is deterministic and does not depend on the network initialization. There is no hidden layer, so there is nothing to take on infinite width.\n $$ \\begin{aligned} f(\\mathbf{x};\\theta) \u0026= \\tilde{A}^{(1)}(\\mathbf{x}) = \\frac{1}{\\sqrt{n_0}} {\\mathbf{w}^{(0)}}^\\top\\mathbf{x} + \\beta\\mathbf{b}^{(0)} \\\\ K^{(1)}(\\mathbf{x}, \\mathbf{x}';\\theta) \u0026= \\Big(\\frac{\\partial f(\\mathbf{x}';\\theta)}{\\partial \\mathbf{w}^{(0)}}\\Big)^\\top \\frac{\\partial f(\\mathbf{x};\\theta)}{\\partial \\mathbf{w}^{(0)}} + \\Big(\\frac{\\partial f(\\mathbf{x}';\\theta)}{\\partial \\mathbf{b}^{(0)}}\\Big)^\\top \\frac{\\partial f(\\mathbf{x};\\theta)}{\\partial \\mathbf{b}^{(0)}} \\\\ \u0026= \\frac{1}{n_0} \\mathbf{x}^\\top{\\mathbf{x}'} + \\beta^2 = \\Sigma^{(1)}(\\mathbf{x}, \\mathbf{x}') \\end{aligned} $$ (2) Now when $L=l$, we assume that a $l$-layer network with $\\tilde{P}$ parameters in total, $\\tilde{\\theta} = (\\mathbf{w}^{(0)}, \\dots, \\mathbf{w}^{(l-1)}, \\mathbf{b}^{(0)}, \\dots, \\mathbf{b}^{(l-1)}) \\in \\mathbb{R}^\\tilde{P}$, has a NTK converging to a deterministic limit when $n_1, \\dots, n_{l-1} \\to \\infty$.\n $$ K^{(l)}(\\mathbf{x}, \\mathbf{x}';\\tilde{\\theta}) = \\nabla_{\\tilde{\\theta}} \\tilde{A}^{(l)}(\\mathbf{x})^\\top \\nabla_{\\tilde{\\theta}} \\tilde{A}^{(l)}(\\mathbf{x}') \\to K^{(l)}_{\\infty}(\\mathbf{x}, \\mathbf{x}') $$ Note that $K_\\infty^{(l)}$ has no dependency on $\\theta$.\nNext let\u0026rsquo;s check the case $L=l+1$. Compared to a $l$-layer network, a $(l+1)$-layer network has additional weight matrix $\\mathbf{w}^{(l)}$ and bias $\\mathbf{b}^{(l)}$ and thus the total parameters contain $\\theta = (\\tilde{\\theta}, \\mathbf{w}^{(l)}, \\mathbf{b}^{(l)})$.\nThe output function of this $(l+1)$-layer network is:\n $$ f(\\mathbf{x};\\theta) = \\tilde{A}^{(l+1)}(\\mathbf{x};\\theta) = \\frac{1}{\\sqrt{n_l}} {\\mathbf{w}^{(l)}}^\\top \\sigma\\big(\\tilde{A}^{(l)}(\\mathbf{x})\\big) + \\beta \\mathbf{b}^{(l)} $$ And we know its derivative with respect to different sets of parameters; let denote $\\tilde{A}^{(l)} = \\tilde{A}^{(l)}(\\mathbf{x})$ for brevity in the following equation:\n $$ \\begin{aligned} \\nabla_{\\color{blue}{\\mathbf{w}^{(l)}}} f(\\mathbf{x};\\theta) \u0026= \\color{blue}{ \\frac{1}{\\sqrt{n_l}} \\sigma\\big(\\tilde{A}^{(l)}\\big)^\\top } \\color{black}{\\quad \\in \\mathbb{R}^{1 \\times n_l}} \\\\ \\nabla_{\\color{green}{\\mathbf{b}^{(l)}}} f(\\mathbf{x};\\theta) \u0026= \\color{green}{ \\beta } \\\\ \\nabla_{\\color{red}{\\tilde{\\theta}}} f(\\mathbf{x};\\theta) \u0026= \\frac{1}{\\sqrt{n_l}} \\nabla_\\tilde{\\theta}\\sigma(\\tilde{A}^{(l)}) \\mathbf{w}^{(l)} \\\\ \u0026= \\color{red}{ \\frac{1}{\\sqrt{n_l}} \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_1} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_1} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_\\tilde{P}} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_\\tilde{P}}\\\\ \\end{bmatrix} \\mathbf{w}^{(l)} \\color{black}{\\quad \\in \\mathbb{R}^{\\tilde{P} \\times n_{l+1}}} } \\end{aligned} $$ where $\\dot{\\sigma}$ is the derivative of $\\sigma$ and each entry at location $(p, m), 1 \\leq p \\leq \\tilde{P}, 1 \\leq m \\leq n_{l+1}$ in the matrix $\\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)$ can be written as\n $$ \\frac{\\partial f_m(\\mathbf{x};\\theta)}{\\partial \\tilde{\\theta}_p} = \\sum_{i=1}^{n_l} w^{(l)}_{im} \\dot{\\sigma}\\big(\\tilde{A}_i^{(l)} \\big) \\nabla_{\\tilde{\\theta}_p} \\tilde{A}_i^{(l)} $$ The NTK for this $(l+1)$-layer network can be defined accordingly:\n $$ \\begin{aligned} \u0026 K^{(l+1)}(\\mathbf{x}, \\mathbf{x}'; \\theta) \\\\ =\u0026 \\nabla_{\\theta} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\theta} f(\\mathbf{x};\\theta) \\\\ =\u0026 \\color{blue}{\\nabla_{\\mathbf{w}^{(l)}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\mathbf{w}^{(l)}} f(\\mathbf{x};\\theta)} + \\color{green}{\\nabla_{\\mathbf{b}^{(l)}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\mathbf{b}^{(l)}} f(\\mathbf{x};\\theta)} + \\color{red}{\\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)^\\top \\nabla_{\\tilde{\\theta}} f(\\mathbf{x};\\theta)} \\\\ =\u0026 \\frac{1}{n_l} \\Big[ \\color{blue}{\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)})^\\top} + \\color{green}{\\beta^2} \\\\ \u0026+ \\color{red}{ {\\mathbf{w}^{(l)}}^\\top \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_1^{(l)}}{\\partial \\tilde{\\theta}_p} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\sum_{p=1}^\\tilde{P} \\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p}\\frac{\\partial \\tilde{A}_{n_l}^{(l)}}{\\partial \\tilde{\\theta}_p} \\\\ \\end{bmatrix} \\mathbf{w}^{(l)} } \\color{black}{\\Big]} \\\\ =\u0026 \\frac{1}{n_l} \\Big[ \\color{blue}{\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)})^\\top} + \\color{green}{\\beta^2} \\\\ \u0026+ \\color{red}{ {\\mathbf{w}^{(l)}}^\\top \\begin{bmatrix} \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})K^{(l)}_{11} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_1^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})K^{(l)}_{1n_l} \\\\ \\vdots \\\\ \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_1^{(l)})K^{(l)}_{n_l1} \u0026 \\dots \u0026 \\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})\\dot{\\sigma}(\\tilde{A}_{n_l}^{(l)})K^{(l)}_{n_ln_l} \\\\ \\end{bmatrix} \\mathbf{w}^{(l)} } \\color{black}{\\Big]} \\end{aligned} $$ where each individual entry at location $(m, n), 1 \\leq m, n \\leq n_{l+1}$ of the matrix $K^{(l+1)}$ can be written as:\n $$ \\begin{aligned} K^{(l+1)}_{mn} =\u0026 \\frac{1}{n_l}\\Big[ \\color{blue}{\\sigma(\\tilde{A}_m^{(l)})\\sigma(\\tilde{A}_n^{(l)})} + \\color{green}{\\beta^2} + \\color{red}{ \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{ij}^{(l)} } \\Big] \\end{aligned} $$ When $n_l \\to \\infty$, the section in blue and green has the limit (See the proof in the previous section):\n $$ \\frac{1}{n_l}\\sigma(\\tilde{A}^{(l)})\\sigma(\\tilde{A}^{(l)}) + \\beta^2\\to \\Sigma^{(l+1)} $$ and the red section has the limit:\n $$ \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{ij}^{(l)} \\to \\sum_{i=1}^{n_l} \\sum_{j=1}^{n_l} w^{(l)}_{im} w^{(l)}_{in} \\dot{\\sigma}(\\tilde{A}_i^{(l)}) \\dot{\\sigma}(\\tilde{A}_{j}^{(l)}) K_{\\infty,ij}^{(l)} $$ Later, Arora et al. (2019) provided a proof with a weaker limit, that does not require all the hidden layers to be infinitely wide, but only requires the minimum width to be sufficiently large.\nLinearized Models From the previous section, according to the derivative chain rule, we have known that the gradient update on the output of an infinite width network is as follows; For brevity, we omit the inputs in the following analysis:\n $$ \\begin{aligned} \\frac{df(\\theta)}{dt} \u0026= -\\eta\\nabla_\\theta f(\\theta)^\\top \\nabla_\\theta f(\\theta) \\nabla_f \\mathcal{L} \u0026 \\\\ \u0026= -\\eta\\nabla_\\theta f(\\theta)^\\top \\nabla_\\theta f(\\theta) \\nabla_f \\mathcal{L} \u0026 \\\\ \u0026= -\\eta K(\\theta) \\nabla_f \\mathcal{L} \\\\ \u0026= \\color{cyan}{-\\eta K_\\infty \\nabla_f \\mathcal{L}} \u0026 \\text{; for infinite width network}\\\\ \\end{aligned} $$ To track the evolution of $\\theta$ in time, let\u0026rsquo;s consider it as a function of time step $t$. With Taylor expansion, the network learning dynamics can be simplified as:\n $$ f(\\theta(t)) \\approx f^\\text{lin}(\\theta(t)) = f(\\theta(0)) + \\underbrace{\\nabla_\\theta f(\\theta(0))}_{\\text{formally }\\nabla_\\theta f(\\mathbf{x}; \\theta) \\vert_{\\theta=\\theta(0)}} (\\theta(t) - \\theta(0)) $$ Such formation is commonly referred to as the linearized model, given $\\theta(0)$, $f(\\theta(0))$, and $\\nabla_\\theta f(\\theta(0))$ are all constants. Assuming that the incremental time step $t$ is extremely small and the parameter is updated by gradient descent:\n $$ \\begin{aligned} \\theta(t) - \\theta(0) \u0026= - \\eta \\nabla_\\theta \\mathcal{L}(\\theta) = - \\eta \\nabla_\\theta f(\\theta)^\\top \\nabla_f \\mathcal{L} \\\\ f^\\text{lin}(\\theta(t)) - f(\\theta(0)) \u0026= - \\eta\\nabla_\\theta f(\\theta(0))^\\top \\nabla_\\theta f(\\mathcal{X};\\theta(0)) \\nabla_f \\mathcal{L} \\\\ \\frac{df(\\theta(t))}{dt} \u0026= - \\eta K(\\theta(0)) \\nabla_f \\mathcal{L} \\\\ \\frac{df(\\theta(t))}{dt} \u0026= \\color{cyan}{- \\eta K_\\infty \\nabla_f \\mathcal{L}} \u0026 \\text{; for infinite width network}\\\\ \\end{aligned} $$ Eventually we get the same learning dynamics, which implies that a neural network with infinite width can be considerably simplified as governed by the above linearized model (Lee \u0026amp; Xiao, et al. 2019).\nIn a simple case when the empirical loss is an MSE loss, $\\nabla_\\theta \\mathcal{L}(\\theta) = f(\\mathcal{X}; \\theta) - \\mathcal{Y}$, the dynamics of the network becomes a simple linear ODE and it can be solved in a closed form:\n $$ \\begin{aligned} \\frac{df(\\theta)}{dt} =\u0026 -\\eta K_\\infty (f(\\theta) - \\mathcal{Y}) \u0026 \\\\ \\frac{dg(\\theta)}{dt} =\u0026 -\\eta K_\\infty g(\\theta) \u0026 \\text{; let }g(\\theta)=f(\\theta) - \\mathcal{Y} \\\\ \\int \\frac{dg(\\theta)}{g(\\theta)} =\u0026 -\\eta \\int K_\\infty dt \u0026 \\\\ g(\\theta) \u0026= C e^{-\\eta K_\\infty t} \u0026 \\end{aligned} $$ When $t=0$, we have $C=f(\\theta(0)) - \\mathcal{Y}$ and therefore,\n $$ f(\\theta) = (f(\\theta(0)) - \\mathcal{Y})e^{-\\eta K_\\infty t} + \\mathcal{Y} \\\\ = f(\\theta(0))e^{-K_\\infty t} + (I - e^{-\\eta K_\\infty t})\\mathcal{Y} $$ Lazy Training People observe that when a neural network is heavily over-parameterized, the model is able to learn with the training loss quickly converging to zero but the network parameters hardly change. Lazy training refers to the phenomenon. In other words, when the loss $\\mathcal{L}$ has a decent amount of reduction, the change in the differential of the network $f$ (aka the Jacobian matrix) is still very small.\nLet $\\theta(0)$ be the initial network parameters and $\\theta(T)$ be the final network parameters when the loss has been minimized to zero. The delta change in parameter space can be approximated with first-order Taylor expansion:\n $$ \\begin{aligned} \\hat{y} = f(\\theta(T)) \u0026\\approx f(\\theta(0)) + \\nabla_\\theta f(\\theta(0)) (\\theta(T) - \\theta(0)) \\\\ \\text{Thus }\\Delta \\theta \u0026= \\theta(T) - \\theta(0) \\approx \\frac{\\|\\hat{y} - f(\\theta(0))\\|}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\end{aligned} $$ Still following the first-order Taylor expansion, we can track the change in the differential of $f$:\n $$ \\begin{aligned} \\nabla_\\theta f(\\theta(T)) \u0026\\approx \\nabla_\\theta f(\\theta(0)) + \\nabla^2_\\theta f(\\theta(0)) \\Delta\\theta \\\\ \u0026= \\nabla_\\theta f(\\theta(0)) + \\nabla^2_\\theta f(\\theta(0)) \\frac{\\|\\hat{y} - f(\\mathbf{x};\\theta(0))\\|}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\\\ \\text{Thus }\\Delta\\big(\\nabla_\\theta f\\big) \u0026= \\nabla_\\theta f(\\theta(T)) - \\nabla_\\theta f(\\theta(0)) = \\|\\hat{y} - f(\\mathbf{x};\\theta(0))\\| \\frac{\\nabla^2_\\theta f(\\theta(0))}{\\| \\nabla_\\theta f(\\theta(0)) \\|} \\end{aligned} $$ Let $\\kappa(\\theta)$ be the relative change of the differential of $f$ to the change in the parameter space:\n $$ \\kappa(\\theta = \\frac{\\Delta\\big(\\nabla_\\theta f\\big)}{\\| \\nabla_\\theta f(\\theta(0)) \\|} = \\|\\hat{y} - f(\\theta(0))\\| \\frac{\\nabla^2_\\theta f(\\theta(0))}{\\| \\nabla_\\theta f(\\theta(0)) \\|^2} $$ Chizat et al. (2019) showed the proof for a two-layer neural network that $\\mathbb{E}[\\kappa(\\theta_0)] \\to 0$ (getting into the lazy regime) when the number of hidden neurons $\\to \\infty$. Also, recommend this post for more discussion on linearized models and lazy training.\nCitation Cited as:\n Weng, Lilian. (Sep 2022). Some math behind neural tangent kernel. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-09-08-ntk/.\n Or\n@article{weng2022ntk, title = \u0026quot;Some Math behind Neural Tangent Kernel\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Sep\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-09-08-ntk/\u0026quot; } References [1] Jacot et al. \u0026ldquo;Neural Tangent Kernel: Convergence and Generalization in Neural Networks.\u0026quot; NeuriPS 2018.\n[2]Radford M. Neal. \u0026ldquo;Priors for Infinite Networks.\u0026quot; Bayesian Learning for Neural Networks. Springer, New York, NY, 1996. 29-53.\n[3] Lee \u0026amp; Bahri et al. \u0026ldquo;Deep Neural Networks as Gaussian Processes.\u0026quot; ICLR 2018.\n[4] Chizat et al. \u0026ldquo;On Lazy Training in Differentiable Programming\u0026rdquo; NeuriPS 2019.\n[5] Lee \u0026amp; Xiao, et al. \u0026ldquo;Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent.\u0026quot; NeuriPS 2019.\n[6] Arora, et al. \u0026ldquo;On Exact Computation with an Infinitely Wide Neural Net.\u0026quot; NeurIPS 2019.\n[7] (YouTube video) \u0026ldquo;Neural Tangent Kernel: Convergence and Generalization in Neural Networks\u0026rdquo; by Arthur Jacot, Nov 2018.\n[8] (YouTube video) \u0026ldquo;Lecture 7 - Deep Learning Foundations: Neural Tangent Kernels\u0026rdquo; by Soheil Feizi, Sep 2020.\n[9] \u0026ldquo;Understanding the Neural Tangent Kernel.\u0026quot; Rajat\u0026rsquo;s Blog.\n[10] \u0026ldquo;Neural Tangent Kernel.\u0026quot;Applied Probability Notes, Mar 2021.\n[11] \u0026ldquo;Some Intuition on the Neural Tangent Kernel.\u0026quot; inFERENCe, Nov 2020.\n","permalink":"https://lilianweng.github.io/posts/2022-09-08-ntk/","summary":"Neural networks are well known to be over-parameterized and can often easily fit data with near-zero training loss with decent generalization performance on test dataset. Although all these parameters are initialized at random, the optimization process can consistently lead to similarly good outcomes. And this is true even when the number of model parameters exceeds the number of training data points.\nNeural tangent kernel (NTK) (Jacot et al. 2018) is a kernel to explain the evolution of neural networks during training via gradient descent.","title":"Some Math behind Neural Tangent Kernel"},{"content":"Processing images to generate text, such as image captioning and visual question-answering, has been studied for years. Traditionally such systems rely on an object detection network as a vision encoder to capture visual features and then produce text via a text decoder. Given a large amount of existing literature, in this post, I would like to only focus on one approach for solving vision language tasks, which is to extend pre-trained generalized language models to be capable of consuming visual signals.\nI roughly group such vision language models (VLMs) into four buckets:\n Translating images into embedding features that can be jointly trained with token embeddings. Learning good image embeddings that can work as a prefix for a frozen, pre-trained language model. Using a specially designed cross-attention mechanism to fuse visual information into layers of the language model. Combine vision and language models without any training. Jointly Training with Image and Text One straightforward approach to fuse visual information into language models is to treat images as normal text tokens and train the model on a sequence of joint representations of both text and images. Precisely, images are divided into multiple smaller patches and each patch is treated as one \u0026ldquo;token\u0026rdquo; in the input sequence.\nVisualBERT (Li et al. 2019) feeds both text inputs and image regions into BERT such that it is able to discover the internal alignment between images and text with self-attention mechanism.\nFig. 1. VisualBERT is trained on the combination of both text and image embeddings. (Image source: Li et al. 2019) Similar to text embedding in BERT, each visual embedding in VisualBERT also sums up three types of embeddings, tokenized features $f_o$, segmentation embedding $f_s$ and position embedding $f_p$, precisely:\n $f_o$ is a visual feature vector computed for a bounding region of the image by a convolutional neural network; $f_s$ is a segment embedding to indicate whether the embedding is for vision not for text; $f_p$ is a position embedding used for aligning the order of bounding regions. The model is trained on MS COCO image caption dataset with both text and image as inputs to predict text captions, using two visually-grounded language model objectives:\n MLM with the image. The model needs to predict masked text tokens, while image embeddings always stay not masked. Sentence-image prediction. When provided with an image and two associated captions, one of two captions might be a random unrelated caption with 50% probability. The model is asked to distinguish these two situations. According to ablation experiments, the most important configuration is to fuse visual information early on into the transformer layers and to pretrain the model on the COCO caption dataset. Initialization from a pre-trained BERT and the adoption of the sentence-image prediction training objective have relatively small impacts.\nFig. 2. Ablation study results of VisualBERT on NLVR. (Image source: Li et al. 2019) VisualBERT outperforms SoTA at the time on NLVR and Flickr30K, but still has some performance gap with SoTA on VQA.\nSimVLM (Simple Visual Language Model; Wang et al. 2022) is a simple prefix language model, where the prefix sequence is processed with bi-directional attention like BERT, but the main input sequence only has causal attention like GPT. Images are encoded as prefix tokens such that the model can fully consume the visual information and then generates associated text in an autoregressive manner.\nInspired by ViT and CoAtNet, SimVLM splits the image into smaller patches in a flatten 1D sequence of patches. They use the convolutional stage consisting of the first 3 blocks of ResNet to extract contextualized patches and this setup is found to work better than a naive linear projection.\nFig. 3. Training architecture for SimVLM, where the image patches are processed by the cross-attention encoder and the text decoder has causal attention. (Image source: Wang et al. 2022) Training data for SimVLM consists of a large number of image-text pairs from ALIGN (Jia et al. 2021) and text-only data from C4 dataset (Raffel et al. 2019). They mix the two pretraining datasets within each batch, containing 4,096 image-text pairs (ALIGN) and 512 text-only documents (C4).\nAccording to ablation studies, it is important to have both image-text and text-only data for training. The PrefixLM objective outperforms both span corruption and naive LM.\nFig. 4. Ablation study results of SimVLM on VQA. (Image source: Wang et al. 2022) CM3 (Causally-Masked Multimodal Modeling; Aghajanyan, et al. 2022) is a hyper-text language model, learning to generate the content (hypertext markup, hyperlinks and images) of large scale HTML web pages of CC-NEWS and Wikipedia articles. The resulting CM3 models can generate rich structured, multi-modal outputs while conditioning on arbitrary masked document contexts.\nArchitecture-wise, CM3 is an autoregressive model. However, in order to combine causal and masked language modeling, CM3 also masks out a small number of long token spans and tries to generate them at the end of the sequences.\nFig. 5. Illustration of how a causally masked language model works. (Image source: Aghajanyan, et al. 2022) The training dataset for CM3 contains close to 1T Web data. During preprocessing, images are first downloaded from src and resized to 256 x 256 with random cropping. Then they are tokenized by VQVAE-GAN, resulting in 256 tokens per image. These tokens, joined with spaces, are inserted back into the src attribute.\nCM3 can be used to complete several types of tasks by prompt engineering:\n Image in-filling: Infilling Prompt: \u0026lt;img src=\u0026quot;{prefix}\u0026lt;mask:0\u0026gt;{postfix}\u0026quot;\u0026gt;\u0026lt;mask:0\u0026gt; Conditional image in-filling: Conditional Infilling Prompt: \u0026lt;img alt=\u0026quot;Photo: {text}\u0026quot; src=\u0026quot;{prefix}\u0026lt;mask:0\u0026gt;{postfix}\u0026quot;\u0026gt;\u0026lt;mask:0\u0026gt; Conditional image generation: Conditional Generation Prompt: \u0026lt;img alt=\u0026quot;{prompt} Image captions: Captioning Masked Prompt #1: \u0026lt;img alt=\u0026quot;Photo: A photo taken of\u0026lt;mask:0\u0026gt;\u0026quot; src=\u0026quot;{image}\u0026quot;\u0026gt; Captioning Causal Prompt #1: \u0026lt;img src=\u0026quot;{image}\u0026quot; title=\u0026quot;Photo: A photo taken of Entity disambiguation Original: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;Memphis, Egypt\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt; Prompt: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;\u0026lt;mask:0\u0026gt;\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt;...\u0026lt;mask:0\u0026gt; Target: Manetho writes that these kings ruled from \u0026lt;a title=\u0026quot;\u0026lt;mask:0\u0026gt;\u0026quot;\u0026gt;Memphis\u0026lt;/a\u0026gt;...\u0026lt;mask:0\u0026gt; Memphis, Egypt Learned Image Embedding as (Frozen) LM Prefix What if we don’t want to change the language model parameters when adapting it to handle visual signals? Instead we learn such an embedding space for images that it is compatible with the language model’s.\nInspired by prefix or prompt tuning, both Frozen (Tsimpoukelli et al. 2021) and ClipCap (Mokady, Hertz \u0026amp; Hertz, 2021) only update the parameters of the vision module during training to produce image embeddings that can work with a pretrained, frozen language model. Both are trained with aligned image caption datasets to produce the next text token in caption conditioned on the image and previous text tokens. The powerful language capability is retained by freezing LM parameters. In addition, even though such setup is trained with limited image caption data, they can also rely on the encyclopedic knowledge of the language model at test time.\nThe vision encoder of Frozen is based on NF-ResNet-50 and uses the final output vector of the NF-Resnet after the global pooling layer. The Frozen VLM can be used as a multi-model few-shot learner to adapt to new tasks at test time for zero-shot or few-shot transfer with a sequence of interleaved images and text.\nFig. 6. Illustration of Frozen model (left) training architecture and (right) testing pipeline. (Image source: Tsimpoukelli et al. 2021) Experiments showed that fine-tuning the pre-trained LM interestingly leads to worse performance on VQA tasks. It is important to initialize the language model from a pre-trained version, as training from scratch (${Frozen}_\\text{scratch}$) does not show any meaningful progress. The baseline ${Frozen}_\\text{train-blind}$ blacks out the image but still can achieve decent performance because of the innate power of the pre-trained LM.\nFig. 7. Performance of different versions of Frozen on (left) VQAv2 and (right) OKVQA, trained on Conceptual Captions. \"Frozen scratch\" does not load a pre-trained LM and is trained from scratch. \"Frozen finetuned\" has the language model finetuned, while \"Frozen\" keeps LM frozen. \"Frozen train-blind\" blacks out the image. (Image source: Tsimpoukelli et al. 2021) ClipCap relies on CLIP (Radford et al. 2021) for vision encoding, but it needs to be processed by a light mapping network $F$ such that image embedding vectors are translated into the same semantic space as the pre-trained LM. The network $F$ maps CLIP embeddings into a sequence of $k$ embedding vectors, each with the same dimension as a word embedding in GPT2. Increasing the prefix size $k$ helps improve the performance. Both CLIP vision encoder and the LM are frozen during training and only the mapping network $F$ is learned. They found that when LM is frozen, $F$ should be a transformer, with 8 multi-head self-attention layers with 8 heads each, but when LM can be fine-tuned, a MLP is enough.\nEven though ClipCap only trains such a minimum set of parameters, it still achieves decent performance on image captioning tasks, comparable with SoTA at the time (e.g. Oscar, VLP, BUTD). Hence they postulate that \u0026ldquo;the CLIP space already encapsulates the required information, and adapting it towards specific styles does not contribute to flexibility.\u0026rdquo;\nFig. 8. Overview of ClipCap training pipeline where only the mapping network needs to be train to transform CLIP image embedding to work with the pre-trained LM. (Image source: Mokady, Hertz \u0026 Hertz, 2021) The fun fact is - because ClipCap translates CLIP image embeddings into LM space, the processed prefixes can be even interpreted as words.\nFig. 9. The learned image embedding can be interpreted as text, containing words related to the image context. (Image source: Mokady, Hertz \u0026 Hertz, 2021) Text-Image Cross-Attention Fuse Mechanisms To more efficiently fuse visual information into different layers of the language model, we can consider a specially designed cross-attention fuse mechanism to balance the mixture of text generation capacity and visual information.\nVisualGPT (Chen et al. 2021) employs a self-resurrecting encoder-decoder attention mechanism to quickly adapt the pre-trained LM with a small amount of in-domain image-text data.\nFig. 10. Illustration of VisualGPT architecture. (Image source: Chen et al. 2021) Let $I$ be the output of a visual encoder and $H$ be the hidden state of the LM decoder. VisualGPT introduced a self-resurrecting activation unit (SRAU) to control the tradeoff between a mixture of pre-trained linguistic information $H$ and visual component, $\\text{EncDecAttn}(H, I)$ via two complementary gates $B^\\text{vis}$ and $B^\\text{lan}$:\n$$ \\begin{aligned} \u0026amp; B^\\text{vis} \\otimes \\text{EncDecAttn}(H, I) + B^\\text{lan} \\otimes H \\\\ \\text{where } \u0026amp; B^\\text{vis}[i,j] = \\sigma(H[i,j]) \\mathbb{1}[\\sigma(H[i,j]) \u0026gt; \\tau] \\\\ \u0026amp; B^\\text{lan}[i,j] = (1 - \\sigma(H[i,j])) \\mathbb{1}[1 - \\sigma(H[i,j]) \u0026gt; \\tau] \\\\ \\end{aligned} $$ where $\\otimes$ is element-wise multiplication and $[i,j]$ denotes one element in the matrix. $\\tau$ is a predefined threshold hyperparameter.\nFig. 11. Comparison of different models trained on 0.1% and 1% of the MS COCO and Conceptual Caption datasets. (Image source: Chen et al. 2021) VC-GPT (Visual Conditioned GPT; Luo et al. 2022) combines a pretrained visual transformer (CLIP-ViT) as visual encoder and a pretrained LM as language decoder.\nFig. 12. Illustration of VC-GPT training framework. (Image source: Luo et al. 2022) The CLIP-ViT takes a sequence of image patches as inputs and outputs representation for each patch. To avoid catastrophic forgetting, instead of injecting the visual information directly into GPT2, VC-GPT introduces extra cross-attention layers on top of the output of visual encoder and language decoder. Then a self-ensemble module linearly combines the single model language decoder logits $h^G$ and cross-model vision-language fused module logits $h^\\text{fuse}$. The self-ensemble module (see \u0026ldquo;VC-GPT w/o SE\u0026rdquo; in Fig. 13) is important for the performance.\n$$ \\text{logits} = W^G h^G + W^\\text{fuse}h^\\text{fuse} $$\nwhere $W^G$ is a linear projection of the language decoder, initialized by the word embedding matrix of GPT2 and $W^\\text{fuse}$ is a linear projection of the fusion module and initialized randomly.\nFig. 13. Performance of VC-GPT on the MS COCO test set, in comparison with other end-to-end image captioning baseline models. Metric abbreviation: C = CIDEr; B = BLEU; M = METEOR; S = SPICE. (Image source: Luo et al. 2022) MERLOT (Zellers, et al. 2021) is trained with 6 millions of YouTube videos with transcribed speech (YT-Temporal-180M) to learn both spatial (frame-level) and temporal (video-level) objectives and demonstrated strong performance on VQA and visual reasoning tasks when fine-tuned.\nEach video $\\mathcal{V}$ is split into multiple segments $\\{ \\boldsymbol{s}_t \\}$, each segment $\\boldsymbol{s}_t$ containing an image frame $\\mathbf{I}_t$ extracted from the middle timestep and $L=32$ tokens of words associated. Images are encoded by a learned image encoder and words are encoded using a learned embedding. Then both are encoded together within a joint vision-language transformer.\nThere are 3 learning objectives in MERLOT:\n Masked language modeling (MLM) is useful especially because in videos, people tend to ramble, resulting in many repeated keywords or filler words. Contrastive frame-caption matching uses the language-only part from the joint vision-language transformer. Matched representations for each frame $\\mathbf{I}_t$ and caption $\\boldsymbol{w}_t$ are treated as positive examples, while the negative examples come from all other frame-caption pairs in the minibatch. Temporal reordering learns temporal reasoning: scramble random $i$ frames and replace the segment-level position embeddings with a random and unique position embedding. The random position embeddings are learned, allowing the model to unshuffle these \u0026ldquo;\u0026lsquo;shuffled\u0026rsquo;\u0026rdquo; frames conditioned on correctly-ordered ones. The loss is to predict whether $t_i \u0026lt; t_j$ or $t_j \u0026lt; t_i$ for each frame-frame pair. Fig. 14. Illustration of MERLOT training framework: (Left) contrastive frame-caption matching training; (Right) joint vision-language transformer is trained with MLM loss, as well as on the temporal reordering task to unshuffle scrambled video frames. (Image source: Zellers, et al. 2021) Ablation studies showed that it is important to (1) train on videos instead of images, (2) scale up the size and diversity of the training dataset and (3) use diverse objectives to encourage full-stack multimodal reasoning.\nFlamingo (Alayrac et al. 2022) is a visual language model that accepts text interleaved with images/videos and outputs free-form text. Flamingo connects a pretrained LM and a pretrained vision encoder (i.e. CLIP image encoder) via a transformer-based mapper. To more efficiently incorporate vision signals, Flamingo adopts a Perceiver-based architecture to produce a few hundreds of tokens out of a large number of visual input features and then use cross-attention layers interleaved with the LM layers to fuse visual information into the language decoding process. The training objective is an autoregressive, NLL loss.\n The Perceiver resampler receives spatio-temporal features from the vision encoder of image/video inputs to produce fixed-size visual tokens. The frozen LM is equipped with newly initialized cross-attention layers interleaved between the pretrained LM layers. Thus the LM can generate text conditioned on the above visual tokens. Similar to ClipCap, both pretrained models are frozen during training and thus Flamingo is only trained to harmoniously connect existing, powerful language and vision models together. Tha main difference between ClipCap and Flamingo is that the former treats the image embedding as simple prefix for LM, while the latter uses the gated cross-attention-dense layer to fuse image information. In addition, Flamingo incorporates a lot more training data than ClipCap.\nFig. 15. Overview of the Flamingo model. (Image source: Alayrac et al. 2022) Fig. 16. The architecture illustration and pseudo code of the gated cross-attention-dense layer in Flamingo. (Image source: Alayrac et al. 2022) To easily handle text with interleaved images, masking in Flamingo is designed such that text token only cross-attends to visual tokens corresponding to the last preceding image, largely reducing the number of visual tokens that a certain text token can see. They found this works better than allowing text tokens to attend to all preceding images directly. Text still can attend to all previous images because there is a causal self-attention dependency in the text encoder. This design can deal with an arbitrary number of images in the context.\nThey scraped 43 million webpages from the Internet, named MultiModal MassiveWeb (M3W) dataset, containing text with interleaved images. In addition, Flamingo is also trained on paired image/text and video/text datasets, including ALIGN, LTIP and VTP.\nData processing of the Internet dataset includes:\n The input Web page text is processed by inserting \u0026lt;image\u0026gt; tags at the location of visual inputs, as well as special tokens, \u0026lt;BOS\u0026gt; (beginning of sentence) and \u0026lt;EOC\u0026gt; (end of chunks; always at the end of the document, before any image tag). From each document, they sample a random subsequence of $L = 256$ tokens and take up to $N = 5$ images included in the sampled sequence (using only the first $N$ within that sampled subsequence if there are more, or padding to $N$ if fewer) A function $\\phi: [1,L] \\to [0,N]$ is computed to track the text and image interleaving order, which assigns to each text position the index of the last image/video appearing before this position; 0 if no preceding visual data. Since Flamingo is trained on a mixture of three different datasets, it optimizes for a weighted sum of dataset-specific NLL losses. Tuning the dataset weights is very important for the final performance. In practice, instead of round-robin between datasets, they actually sample one batch from each dataset and apply a weighted sum of these gradients in each update. Gradient accumulation across different heterogeneous datasets can be viewed as a mean to stabilize training, as it reduces the gradient variance between each update.\nAt test time, Flamingo naturally supports few-shot learning since it can work with any sequence of interleaved text and images. And more examples in the context contribute to better performance.\nFig. 17. Larger model sizes and more few-shot examples lead to better performance. (Image source: Alayrac et al. 2022) Flamingo outperforms SoTA fine-tuned models on 6 out of the 16 tasks despite even when not using any fine-tuning but only few-shot prompting. Fine-tuning Flamingo is expensive and it is difficult to do hyperparemeter tuning, but it does lead to better results.\nFig. 18. Performance of Flamingo model using different numbers of shots and of different sizes, in comparison with SoTA fine-tuned baseline. (Image source: Alayrac et al. 2022) CoCa (Contrastive Captioner; Yu \u0026amp; Wang et al., 2022) captures both the merits of contrastive learning and image-to-caption generation. It is a model jointly trained with contrastive loss on CLIP-style representation and generative loss on image captioning, achieving SoTA zero-shot transfer on a variety of multi-modal evaluation tasks.\nFig. 19. Overview of CoCa training framework. (Image source: Yu \u0026 Wang et al., 2022) CoCa is pretrained from scratch, using web-scale alt-text data ALIGN and annotated images by treating all labels as texts in JTB-3B.\nThere are two major training components in CoCa. The final loss is a weighted sum of the following two losses, with weight scalars $\\lambda_\\text{cap}=2.0, \\lambda_\\text{con} = 1.0$.:\n $\\mathcal{L}_\\text{con}$ - Dual-encoder contrastive learning optimizes the symmetric contrastive learning loss, similar to CLIP. $\\mathcal{L}_\\text{cap}$ - Encoder-decoder captioning has the decoder predict the caption based on the latent encoded features from the image encoder, by optimizing an autoregressive loss. The text decoder is decoupled into two components, unimodal and multimodal; a good balance is to split the decoder by half for these two components: The bottom unimodal component encodes the input text with causally-masked self-attention. The top multimodal component applies both causally-masked self-attention and cross-attention to the output of the vision encoder. CoCa performs better than the contrastive-only model and on par with the captioning-only model on VQA. Captioning loss is found to be beneficial to the zero-shot classification capacity too.\nFig. 20. Illustration of how CoCa can be used to solve various downstream tasks at test time. (Image source: Yu \u0026 Wang et al., 2022) They use task-specific attention pooling, or attention pooler, as a natural task adapter, as they found that a single pooled image embedding helps visual recognition tasks (e.g. ImageNet classification), while a more fine-grained embedding helps multimodal understanding tasks (e.g. VQA). A pooler is a single multi-head attention layer with $n_\\text{query}$ learnable queries (note that $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$, $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_q}$, and $d_k = d_q$), with the encoder output as both keys and values. CoCa uses attentional poolers in pretraining for generative loss $n_\\text{query} = 256$ and contrastive loss $n_\\text{query} = 1$. This enables the model to obtain strong performance as a frozen encoder where we only learn a new pooler to aggregate features.\nFig. 21. Pseudo code for CoCa architecture and training. (Image source: Yu \u0026 Wang et al., 2022) No Training Finally it is possible to solve vision language tasks by stitching pretrained language and vision models together without training any additional parameters.\nDecoding Guided with Vision-based Scores MAGiC (iMAge-Guided text generatIon with CLIP; Su et al. 2022) does guided decoding according to a CLIP-based score named magic score to sample the next token, without fine-tuning. The generated text is encouraged to be relevant to the given image, while still stay coherent to the previously generated text.\nThe next token $x_t$ at a time step $t$ is selected according to the following equation. Model confidence and degeneration penalty (Su et al. 2022) are added to avoid corrupted generation from LM.\n$$ \\begin{aligned} \u0026amp; x_t = \\arg\\max_{v \\in \\mathcal{V}^{(k)}} \\big\\{ (1-\\alpha) \\underbrace{p(v \\vert \\boldsymbol{x}_{\u0026lt;t})}_\\text{model confidence} - \\alpha \\underbrace{\\max_{1 \\leq j \\leq t-1} { \\text{cosine}(h_v, h_{x_j})}}_\\text{degeneration penalty} + \\beta \\underbrace{f_\\text{magic}(v \\vert \\mathcal{I}, \\boldsymbol{x}_{\u0026lt;t}, \\mathcal{V}^{(k)})}_\\text{magic score} \\big\\} \\\\ \\text{where } \u0026amp; f_\\text{magic} ( v \\vert \\mathcal{I}, \\mathbf{x}_{\u0026lt;t}, \\mathcal{V}^{(k)} ) = \\frac{ \\exp(\\text{CLIP}(\\mathcal{I}, [\\boldsymbol{x}_{\u0026lt;t}:v])) }{ \\sum_{z \\in \\mathcal{V}^{(k)}} \\exp(\\text{CLIP}(\\mathcal{I}, [\\boldsymbol{x}_{\u0026lt;t}:z])) } = \\frac{ \\exp\\big({h^\\text{image}(\\mathcal{I})}^\\top h^\\text{text}([\\boldsymbol{x}_{\u0026lt;t}:v])\\big) }{ \\sum_{z \\in \\mathcal{V}^{(k)}} \\exp\\big({h^\\text{image}(\\mathcal{I})}^\\top h^\\text{text}([\\boldsymbol{x}_{\u0026lt;t}:z])\\big) } \\end{aligned} $$\nwhere $\\mathcal{I}$ is the input image; $\\mathcal{V}^{(k)}$ contains top-$k$ possible tokens predicted by the language model $p$; $\\boldsymbol{x}_{\u0026lt;t}$ refers to the past generated tokens before time step $t$; $h_v$ is the representation of the token $v$ computed by LM conditioned on the concatenation of $\\boldsymbol{x}_{\u0026lt;t}$ and $v$; $h^\\text{image}(.)$ and $h^\\text{text}(.)$ are embeddings generated by CLIP image and text encoders, respectively.\nMAGiC has decent performance compared to other unsupervised approaches, but still has big gaps with supervised methods.\nFig. 22. Image captioning performance on COCO and Flickr30k. (Image source: Su et al. 2022) Language as Communication Interface For knowledge-based VQA tasks, PICa (Prompts GPT-3 via the use of Image Captions; Yang et al. 2021) first converts the images into captions or tags and then uses few-shot examples to prompt GPT3 to provide answers. Image captions or tags are extracted by some existing models (e.g. VinVL) or Azure Tagging API. And GPT3 is considered as an unstructured, implicit knowledge base.\nFig. 23. How PICa works for $n$-shot VQA at inference time. (Image source: Yang et al. 2021) PICa explored two ways to improve few-shot examples to achieve better results:\n In-context examples are selected based on how similar they are to the question using CLIP embedding. Multi-query ensembling is to prompt the model multiple times to get multiple answers and the one with highest logprob is selected. This simple approach with only 16 examples improved SoTA on OK-VQA by +8.6 points and got decent performance on VQAv2.\nFig. 24. Performance of PICa on OK-VQA. \"PICa-Base\" has random in-context examples, while \"PICa-Full\" incorporates both similar in-context example selection and multi-query ensembling. (Image source: Yang et al. 2021) Socratic Models (SM) (Zeng et al. 2022) is a framework to compose multiple pretrained models for different modality via language (prompting) into one model without further training. Here language is considered as the intermediate representation by which different models can exchange information. The key idea is to use multi-model multimodal prompting, in which output of a non-language model is inserted into a language prompt and then it is used for LM for reasoning.\nLet’s examine a concrete example. Given an ego-centric video (images + audio), SM can produce a summary of the person’s activity using text-to-text LM, image-to-text VLM and speech-to-text ALM. They are chained as follows:\n(Image source: Zeng et al. 2022) the VLM detects visual entities; the LM suggests sounds that may be heard; the ALM chooses the most likely sound; the LM suggests possible activities; the VLM ranks the most likely activity; the LM generates a summary of the Socratic interaction. Fig. 25. Illustration of the Socratic Model solution for image captioning. (Image source: Zeng et al. 2022) SM can generate image captions by first using VLM to zero-shot predict different place categories, object categories, image type and the number of people; and then the VLM-filled language prompt is fed into a causal LM to generate caption candidates. The Socratic approach still has performance gap with ClipCap on image captioning but pretty decent given it does not involve any training.\nFig. 26. Comparison of image captioning performance of different models on random 100 COCO text examples. (Image source: Zeng et al. 2022) SM framework is very flexible and can be used on a lot more complicated tasks other than image captions. For example, the egocentric perception (User inputs + VLM + LM + ALM) task is to take as inputs egocentric videos to (1) summarize content; (2) answer free-form reasoning questions; (3) and do forecasting.\nFig. 27. The Socratic Model approach for generating captions and question answering based on the egocentric videos. (Image source: Zeng et al. 2022) Datasets Image Caption Datasets MS COCO (Chen et al. 2015): contains 328K images and each paired with 5 independent captions. NoCaps (Agrawal et al., 2019) is designed to measure generalization to unseen classes and concepts, where in-domain contains images portraying only COCO classes, near-domain contains both COCO and novel classes, and out-of-domain consists of only novel classes. Conceptual Captions (Sharma et al. 2018) contains 3 million pairs of images and captions, mined from the web and post-processed. To focus on the concepts, specific entities in this dataset are replaced with general notions (e.g. a politician’s name is replaced with \u0026ldquo;politician\u0026rdquo;) Crisscrossed Captions (CxC) (Parekh et al. 2021) contains 247,315 human-labeled annotations including positive and negative associations between image pairs, caption pairs and image-caption pairs. Concadia (Kreiss et al. 2021) is a Wikipedia-based dataset containing 96,918 images with corresponding English-language descriptions, captions, and surrounding context. Pair Image-Text Datasets (*) Not a public dataset.\n ALIGN (Jia et al., 2021) contains 1.8 billion images with alt-text. The dataset is large but noisy with only minimal frequency-based filtration. (*) LTIP (Long text \u0026amp; image pairs; Alayrac et al. 2022): 312 million images, paired with descriptive captions. (*) VTP (Video \u0026amp; text pairs; Alayrac et al. 2022): 27 million short videos (~22 seconds on average), paired with descriptive captions. (*) JFT-300M / JFT-3B are internal Google datasets, containing 300M / 3B images annotated with a class-hierarchy of around 30k labels via a semi-automatic pipeline. Thus the data and associated labels are noisy. Evaluation Tasks Visual Question-Answering Given an image and a question, the task is to correctly answer the question.\n VQAv2 (Goyal et al., 2017) contains 1+ million questions about 200K images from COCO. OK-VQA (Marino et al. 2019) contains 14K open-ended questions that require outside knowledge (e.g. from Wikipedia). A-OKVQA: the augmented successor of OK-VQA, with no overlapped questions with OK-VAQ. TextVQA (Singh, et al. 2019) contains 45,336 questions on 28,408 images that require reasoning about text to answer. VizWiz (Gurari, et al. 2018) contains over 31,000 visual questions originating from blind people who each took a picture using a mobile phone and recorded a spoken question about it, together with 10 crowdsourced answers per visual question. Visual Language Reasoning VCR (Visual Commonsense Reasoning; Zellers et al. 2018) contains 290k multiple choice QA questions derived from 110k movie scenes, with focus on visual commonsense. NLVR2 (Natural Language for Visual Reasoning; Suhr et al. 2019) contains 100k+ examples of sentences paired with web images and the task is to determine whether a natural language caption is true about a pair of images, with a focus on semantic diversity. Flickr30K (Jia et al. 2015) contains 30k images collected from Flickr and 250k annotations and the task is to select the bounding regions given spans of a sentence. SNLI-VE (Visual Entailment; Xie et al. 2019) is built on top of SNLI and Flickr30K and the task is to reason about the relationship between an image premise and a text hypothesis. Video QA and Understanding MSR-VTT (MSR Video to Text; Xu et al. 2016) contains 10K web video clips with 41.2 hours and 200K clip-sentence pairs in total; the task is to translate videos to text. ActivityNet-QA (Yu et al. 2019) contains 58,000 human-annotated QA pairs on 5,800 videos derived from the popular ActivityNet dataset. TGIF (Tumblr GIF; Li et al. .2016) contains 100K animated GIFs and 120K sentences describing visual content of the animated GIFs, randomly selected posts published between May and June of 2015 on Tumblr. TGIF-QA contains 165K QA pairs for the animated GIFs from the TGIF dataset. LSMDC (Large Scale Movie Description Challenge; Rohrbach et al. 2015) contains 118,081 short video clips extracted from 202 movies. Each video has a caption, either extracted from the movie script or from transcribed DVS (descriptive video services) for the visually impaired. TVQA (Lei et al. 2018) / TVQA+ (Lei et al. 2019) is a large-scale video QA dataset based on 6 popular TV shows (Friends, The Big Bang Theory, How I Met Your Mother, House M.D., Grey\u0026rsquo;s Anatomy, Castle). It consists of 152.5K QA pairs from 21.8K video clips, spanning over 460 hours of video. DramaQA (Choi et al. 2020) is a large-scale video QA dataset based on a Korean popular TV show, \u0026ldquo;Another Miss Oh\u0026rdquo;. This dataset contains four levels of QA on difficulty and multi-level character-centered story descriptions. VLEP (Video-and-Language Event Prediction; Lei et al. 2020) contains 28,726 future event prediction examples (along with their rationales) from 10,234 diverse TV Show and YouTube Lifestyle Vlog video clips. Citation Cited as:\n Weng, Lilian. (Jun 2022). Generalized visual language models. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-06-09-vlm/.\n Or\n@article{weng2022vlm, title = \u0026quot;Generalized Visual Language Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Jun\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-06-09-vlm/\u0026quot; } References [1] Li et al. \u0026ldquo;VisualBERT: A Simple and Performant Baseline for Vision and Language.\u0026quot; arXiv preprint:1908.03557 (2019).\n[2] Wang et al. \u0026ldquo;SimVLM: Simple Visual Language Model Pretraining with Weak Supervision.\u0026quot; ICLR 2022.\n[3] Aghajanyan, et al. \u0026ldquo;CM3: A Causal Masked Multimodal Model of the Internet.\u0026quot; arXiv preprint arXiv: 2201.07520 (2022).\n[4] Tsimpoukelli et al. \u0026ldquo;Multimodal Few-Shot Learning with Frozen Language Models.\u0026quot; NeuriPS 2021.\n[5] Mokady, Hertz \u0026amp; Hertz. \u0026ldquo;ClipCap: CLIP Prefix for Image Captioning.\u0026quot; 2021.\n[6] Chen et al. \u0026ldquo;VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning.\u0026quot; arXiv preprint arXiv:2111.09734 (2021).\n[7] Luo et al. \u0026ldquo;A Frustratingly Simple Approach for End-to-End Image Captioning.\u0026quot; arXiv preprint arXiv:2201.12723 (2022).\n[8] Zellers et al. \u0026ldquo;MERLOT: Multimodal neural script knowledge models.\u0026quot; NeuriPS 2021.\n[9] Alayrac et al. \u0026ldquo;Flamingo: a Visual Language Model for Few-Shot Learning.\u0026quot; arXiv preprint arXiv:2204.14198 (2022).\n[10] Yu \u0026amp; Wang et al. \u0026ldquo;CoCa: Contrastive Captioners are Image-Text Foundation Models.\u0026quot; arXiv preprint arXiv:2205.01917 (2022).\n[11] Yang et al. \u0026ldquo;An Empirical Study of GPT-3 for Few-Shot Knowledge-Based VQA.\u0026quot; arXiv preprint arXiv:2109.05014 (2021).\n[12] Su et al. \u0026ldquo;Language models can see: Plugging visual controls in text generation.\u0026quot; arXiv preprint arXiv:2205.02655 (2022).\n[13] Zeng et al. \u0026ldquo;Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language.\u0026quot; arXiv preprint arXiv:2204.00598 (2022).\n","permalink":"https://lilianweng.github.io/posts/2022-06-09-vlm/","summary":"Processing images to generate text, such as image captioning and visual question-answering, has been studied for years. Traditionally such systems rely on an object detection network as a vision encoder to capture visual features and then produce text via a text decoder. Given a large amount of existing literature, in this post, I would like to only focus on one approach for solving vision language tasks, which is to extend pre-trained generalized language models to be capable of consuming visual signals.","title":"Generalized Visual Language Models"},{"content":"Here comes the Part 3 on learning with not enough data (Previous: Part 1 and Part 2). Let’s consider two approaches for generating synthetic data for training.\n Augmented data. Given a set of existing training samples, we can apply a variety of augmentation, distortion and transformation to derive new data points without losing the key attributes. We have covered a bunch of augmentation methods on text and images in a previous post on contrastive learning. For the sake of post completeness, I duplicate the section on data augmentation here with some edits. New data. Given few or even no data points, we can rely on powerful pretrained models to generate a number of new data points. This is especially true in recent years given the fast progress in large pretrained language models (LM). Few shot prompting is shown to be effective for LM to learn within context without extra training. Data Augmentation The goal of data augmentation is to modify the input format (e.g. text wording, visual appearance) while the semantic meaning stays unchanged.\nImage Augmentation Basic Image Processing Operations There are several ways to modify an image while retaining its semantic information. We can use any one of the following augmentation or a composition of multiple operations.\n Random cropping and then resize back to the original size. Random color distortions Random Gaussian blur Random color jittering Random horizontal flip Random grayscale conversion And many more. Check PIL.ImageOps for inspiration. Task-Specific Augmentation Strategies If the downstream task is known, it is possible to learn the optimal augmentation strategies (i.e. what processing operations to use and how to combine them in sequence) to maximize the downstream task performance.\n AutoAugment (Cubuk, et al. 2018) is inspired by neural architecture search, AutoAugment frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem and looks for the combination that leads to the highest accuracy on the evaluation set. AutoAugment can be executed in adversarial fashion (Zhang, et al 2019). RandAugment (Cubuk et al., 2019) greatly reduces the search space of AutoAugment by controlling the magnitudes of different transformation operations with a single magnitude parameter. Population based augmentation (PBA; Ho et al., 2019) combines PBT (\u0026ldquo;population based training\u0026rdquo;; Jaderberg et al, 2017) with AutoAugment, using the evolutionary algorithm to train a population of children models in parallel to evolve the best augmentation strategies. Unsupervised Data Augmentation (UDA; Xie et al., 2019), among a set of possible augmentation strategies, selects a subset to minimize the KL divergence between the predicted distribution over an unlabelled example and its unlabelled augmented version. Image Mixture Image mixture methods can construct new training examples from existing data points.\n Mixup (Zhang et al., 2018) runs global-level mixture by creating a weighted pixel-wise combination of two existing images $I_1$ and $I_2$: $I_\\text{mixup} \\gets \\alpha I_1 + (1-\\alpha) I_2$ and $\\alpha \\in [0, 1]$. Cutmix (Yun et al., 2019) does region-level mixture by generating a new example by combining a local region of one image with the rest of the other image. $I_\\text{cutmix} \\gets \\mathbf{M}_b \\odot I_1 + (1-\\mathbf{M}_b) \\odot I_2$, where $\\mathbf{M}_b \\in \\{0, 1\\}^I$ is a binary mask and $\\odot$ is element-wise multiplication. It is equivalent to filling the cutout (DeVries \u0026amp; Taylor 2017) region with the same region from another image. Given a query $\\mathbf{q}$, MoCHi (\u0026ldquo;mixing of contrastive hard negatives\u0026rdquo;; Kalantidis et al. 2020) maintains a queue of $K$ negative features $Q={\\mathbf{n}_1, \\dots, \\mathbf{n}_K }$ and sorts these negative features by similarity to the query, $\\mathbf{q}^\\top \\mathbf{n}$, in descending order. The first $N$ items in the queue are considered as the hardest negatives, $Q^N$. Then synthetic hard examples can be generated by $\\mathbf{h} = \\tilde{\\mathbf{h}} / |\\tilde{\\mathbf{h}}|_2$ where $\\tilde{\\mathbf{h}} = \\alpha\\mathbf{n}_i + (1-\\alpha) \\mathbf{n}_j$ and $\\alpha \\in (0, 1)$. Even harder examples can be created by mixing with the query feature, $\\mathbf{h}' = \\tilde{\\mathbf{h}'} / |\\tilde{\\mathbf{h}'}|_2$ where $\\tilde{\\mathbf{h}'} = \\beta\\mathbf{q} + (1-\\beta) \\mathbf{n}_j$ and $\\beta \\in (0, 0.5)$. Text Augmentation Lexical Edits Easy Data Augmentation (EDA; Wei \u0026amp; Zou 2019) defines a set of simple but powerful operations for text augmentation. Given a sentence, EDA randomly chooses and applies one of four simple operations:\n Synonym replacement (SR): Replace $n$ random non-stop words with their synonyms. Random insertion (RI): Place a random synonym of a randomly selected non-stop word in the sentence at a random position. Random swap (RS): Randomly swap two words and repeat $n$ times. Random deletion (RD): Randomly delete each word in the sentence with probability $p$. where $p=\\alpha$ and $n=\\alpha \\times \\text{sentence_length}$, with the intuition that longer sentences can absorb more noise while maintaining the original label. The hyperparameter $\\alpha$ roughly indicates the percent of words in one sentence that may be changed by one augmentation.\nEDA is shown to improve the classification accuracy on several classification benchmark datasets compared to baseline without EDA. The performance lift is more significant on a smaller training set. All the four operations in EDA help improve the classification accuracy, but get to optimal at different $\\alpha$\u0026rsquo;s.\nFig. 1. EDA leads to performance improvement on several classification benchmarks. (Image source: Wei \u0026 Zou 2019) Contextual Augmentation (Kobayashi, 2018) replaces word $w_i$ at position $i$ by sampling from a probability distribution learned by a bidirectional LM such as BERT, $p(.\\mid S\\setminus{w_i})$. In this way, the words are substituted by synonyms, or similar words suitable for the context. To guarantee such operations do not alter the labels, the LM is fit to be label-conditioned bidirectional LM. Conditional BERT (CBERT; Xing Wu et al. 2018) extends BERT to predict masked tokens conditioned on the class label and can be used for contextual augmentation prediction.\nBack-translation Back-translation produces augmented data by translating text samples to another language and then translating them back. The translation happens in two ways and both directions should have decent enough performance to avoid significant loss of semantic meaning.\nMix-up It is also possible to apply Mixup to text (Guo et al. 2019) but on the embedding space to obtain some performance gain. The proposed method relies on a specially designed model architecture to operate the prediction on the word or sentence embedding. Adding adversarial noise in the embedding space as a way of data augmentation is shown to improve the generalization of model training (Zhu et al. 2019).\nAudio Augmentation Here is a list of several commonly used audio data augmentation methods, operated on raw audio or spectrograms, summarized by Wang \u0026amp; van den Oord (2021).\nAudio mixup. Given two audio clips $\\mathbf{x}_1$ and $\\mathbf{x}_2$, the mixed-up version $\\hat{\\mathbf{x}} = \\alpha \\mathbf{x}_1 + (1-\\alpha)\\mathbf{x}_2$ should be associated with the label of the more dominant input. The audio mixup augments the data with more realistic noise.\nTime masking. A small consecutive chunk of the audio can be masked without losing semantic information.\nFrequency masking. A small amount of frequency components on the spectrogram can be dropped off and it should not change the associated label.\nFrequency shift. The spectrogram can be shifted by an integer between $[-F, F]$, where $F$ is the maximum shift size. It is a cheap augmentation to change the pitch of the audio.\nArchitectural Augmentation Models with dropout layers can create augmented samples by applying different dropout masks on the same input sample. For example, in the contrastive learning model SimCSE (Guo et al. 2021), a sample is simply fed into the encoder twice with different dropout masks and these two versions are the positive pair where the other in-batch samples are considered as negative pairs.\nDropout augments data by adding noise onto the internal representation of the model. It can be applied in a more structured way, such as in cutoff (Shen et al. (2020)), where random chunks of the token embedding matrix are removed.\nData Synthesis Given that generating high-quality, photorealistic images is a lot more difficult than generating human-like natural language text and recent success with large pretrained language models, this section only focuses on text generation. To read more on how to synthesize realistic images, check posts on GAN, VAE, flow and diffusion models.\nLanguage Model as Noisy Annotator Wang et al. (2021) explored ways to leverage GPT-3 as a weak annotator via few-shot prompting, achieving 10x cheaper than human labeling. The paper argues that by using data labeled by GPT-3, it essentially performs self-training: The predictions on unlabeled samples apply entropy regularization on the model to avoid high class overlaps so as to help improve the model performance.\nFig. 2. Illustration of how to use GPT-3 to generate more training data with the human-in-the-loop active learning pipeline to improve the data quality. (Image source: Wang et al. 2021) GPT-3-labeled samples selected by active learning with highest uncertainty are sent to human labelers to be re-annotated. The few-shot prompt contains a small number of human labeled examples and thus the labeling cost is restricted. Synthetic samples are ranked by predicted logits of label $y$ and those with the lowest scores go through relabeling.\nGPT-3 labeling achieves better results in the low-cost regime, but has a gap with human labeling when enough money is spent on data collection. This implies the following inequation, although to what extent \u0026ldquo;a lot\u0026rdquo; or \u0026ldquo;noisy\u0026rdquo; means depends on the task details.\n A lot of high-quality data \u0026gt; A lot of noisy data \u0026gt; A little high quality data.\n Fig. 3. GPT-3 labeling technique improves the classification performance in the low-cost regime. (Image source: Wang et al. 2021) Language Model as Data Generator If enough training dataset for text classification tasks are available, we can fine-tune language models to synthesize more training samples conditioned on labels (Anaby-Tavor et al. 2019, Kumar et al. 2021).\nLanguage-model-based data augmentation (LAMBADA; Anaby-Tavor et al. 2019) takes such an idea, where the process involves fine-tuning both a classifier and a sample generation model.\n Train a baseline classifier using the existing training dataset: $h = \\mathcal{A}(\\mathcal{D}_\\text{train})$. Independently of step 1, a LM $\\mathcal{M}$ is fine-tuned on $\\mathcal{D}_{\\text{train}}$ to obtain $\\mathcal{M}_{\\text{tuned}}$. Synthesize a labeled dataset $\\mathcal{D}^*$ by generating the continuation of the sequence y[SEP] until EOS using $\\mathcal{M}_\\text{tuned}$. Filter synthesized dataset by, (1) Verifying that the predicted label is correct $h(x)=y$; (2) Selecting the top ranked samples when they are ranked by the classifier probability. $\\mathcal{D}_\\text{syn} \\subset \\mathcal{D}^*$. They generate 10x more samples needed for augmentation and only the top 10% synthesized samples with highest confidence scores remain. The final classifier is trained on $\\mathcal{D}_\\text{syn} \\cup \\mathcal{D}_\\text{train}$ . The process can be repeated multiple times, but it is unclear whether the benefit would quickly diminish or the repetitive process would bring in self-bias.\nFig. 4. Accuracy of LAMBADA vs. other generative approaches over all datasets and classifiers. (Image source: Anaby-Tavor et al. 2019) To simplify LAMBADA, we can actually remove the dependency of a fine-tuned generation model and an existing training dataset of a decent size (Step 2 above). Unsupervised data generation (UDG; Wang et al. 2021) relies on few-shot prompting on a large pretrained language model to generate high-quality synthetic data for training. Opposite to the above approach where LM is asked to predict $y$ given $\\mathbf{x}$, UDG instead synthetizes the inputs $\\mathbf{x}$ given labels $y$. Then a task-specific model is trained on this synthetic dataset.\nSchick \u0026amp; Schutze (2021) proposed a similar idea but on the NLI task instead of classification, asking PLM to write sentence pairs that are similar or different while the model is prompted with task-specific instructions.\nFig. 5. Illustration of the unsupervised data generation (UDG) framework. (Image source: Wang et al., 2021) The few-shot prompts of UDG contain a small number of unlabeled examples, as well as a task-specific natural language description of the desired label. Because some generated examples are noisy, they implemented noisy label annealing (NLA) techniques to filter potentially misaligned samples out during the training processes. NLA gradually removes noisy training signals in time during training when the model starts to disagree with its pseudo label with high confidence. At each training step $t$, a given example $(\\mathbf{x}_i, \\hat{y}_i)$ is considered noisy and should be removed if:\n The model predicted probability is higher than a threshold $p(\\bar{y}_i \\vert \\mathbf{x}_i) \u0026gt; \\mu_t$ where $\\bar{y}_i = \\arg\\max_y p(y \\vert \\mathbf{x}_i)$; And the predicted label is different from the synthetic label, $\\bar{y}_i \\neq \\hat{y}_i$. Note that the threshold $\\mu_t$ is time-dependent, initialized as 0.9 and then gradually annealed to $1/\\text{num_of_classes}$ in time.\nAs shown in their experiments, the improvement of UDG over few-shot inference is quit significant, where NLA brings in some extra boost. The results are even comparable with supervised fine-tuning on several cases.\nFig. 6. Comparison of accuracy of UDG and other methods on different classification datasets. (Image source: Wang et al., 2021) Han et al (2021) achieved SOTA results on translation tasks using few-shot data generation, distillation and back-translation. The proposed method contains the following steps, assuming no access to paired translation data:\n Zero-shot Generation. First use the zero-shot translation ability of a pre-trained LM to generate translations for a small set of unlabeled sentences. Few-shot Generation. Then amplify these zero-shot translations by using them as few-shot demonstrations to gather an even larger synthetic dataset. Distillation. Fine-tune the model on this dataset. The translation task is formulated as a language modeling task [L1] \u0026lt;seq1\u0026gt; [[TRANSLATE]] [L2] \u0026lt;seq2\u0026gt;. given a pair of two sequences \u0026lt;seq1, seq2\u0026gt; in two different languages. At test-time, the LM is prompted with [L1] \u0026lt;seq\u0026gt; [[TRANSLATE]] [L2] and a candidate translation \u0026lt;sampledSeq\u0026gt; is parsed from the sampled completion. Back-translation. Continue fine-tuning on the back-translation dataset where the order of samples is reversed, \u0026lt;sampledSeq, seq\u0026gt;. Step 1-4 can be repeated. Fig. 7. Algorithm of using distillation and back-translation to train a language model on translation tasks. (Image source: Han et al. 2021) The success of the above method depends on a good pretrained LM to kick off the initial translation dataset. Iterative few-shot generation and distillation with back-translation is an effective way to extract and refine the translation capability out of a pretrained LM and further to distill that into a new model.\nFig. 8. Comparison of BLEU scores of the translation models of different training runs using: only distillation, back-translation, both and with more monolingual training data. (Image source: Han et al. 2021) How to Quantify Generated Data Quality? Given all the generated data, either by data augmentation or data synthesis, how can we quantify data quality in terms of how they improve model generalization? Gontijo-Lopes et al. (2020) introduced two dimensions to track, affinity and diversity.\n Affinity is a model-sensitive metric for distribution shift, quantifying how much an augmentation shifts the training data distribution from what a model learned. Definition: The performance difference between the model tested on clean data vs augmented data, while the model is trained on clean data. As a comparison, KL can also measure distribution shift but does not consider the model performance. Diversity is a measure of augmentation complexity, measuring the complexity of the augmented data with respect to the model and learning procedure. Definition: The final training loss of a model trained with a given augmentation. Another potential diversity measure is the entropy of the transformed data. A third potential diversity measure is the training time needed for a model to reach a given training accuracy threshold. All three metrics above are correlated. The final model performance is dependent on both metrics to be high enough.\nFig. 9. (a) Left: A scatter plot of affinity vs diversity metric, where each point represents a different augmentation method and its color indicates the final test accuracy. (b) Right: The conceptual illustration of the relationship between clean and augmented data in different regions of affinity and diversity metrics. (Image source: Gontijo-Lopes et al. 2020) There are many quantitative metrics on relevancy and diversity, in different formations depending on whether a reference is available, such as perplexity, BLEU for text and inception score for images. I\u0026rsquo;m skipping the list of concrete quantitative metrics on quality here, given it could be very long.\nTraining with Noisy Data It is convenient to collect a large amount of noisy data via model generation or data augmentation, but it is hard to guarantee that augmented and generated data can be 100% accurate. Knowing that deep neural networks can easily overfit noisy labels and \u0026ldquo;memotize\u0026rdquo; corrupted labels, we can apply the techniques for training on noisy labels (noise-robust training) when using generated data to stabilize and optimize the performance. Please check this survey paper (Song et al. 2021) on learning from noisy labels for a more thorough coverage of related work.\nRegularization and Robust Architecture Generally speaking, mechanisms designed for avoiding overfitting should help improve training robustness when working with moderately noisy data, such as weight decay, dropout, batch normalization. In fact, good data augmentation (i.e. only non-essential attributes are modified) can be considered as a way of regularization as well.\nA different approach is to enhance the network with a dedicated noisy adaptation layer to approximate the unknown projection of label corruption (Sukhbaatar et al. 2015, Goldberger \u0026amp; Ben-Reuven, 2017).\nSukhbaatar et al. (2015) introduced an extra linear layer $Q$ into the network architecture to adapt the predictions to match the noisy label distribution. The noise matrix $Q$ is initially fixed to the identity function while only the base model parameters is updated. After some time, $Q$ starts to be updated and expected to capture the noise in the data. The noise matrix is trained with regularization to encourage it to match the noise distribution while keeping the base model prediction accurate for true labels.\nFig. 10. (a) Left: A noise matrix $Q$ is added between softmax and the final output for the loss. (b) Right: The noise matrix $Q$ is fixed at the identity function initially and only gets updated with regularization after some training. (Image source: Sukhbaatar et al. 2015) However, it is hard to guarantee such a noise matrix layer would only capture the noise transition distribution and it is actually non-trivial to learn. Goldberger \u0026amp; Ben-Reuven (2017)) proposed to add an additional softmax layer end-to-end with the base model and apply the EM algorithm by treating the correct labels as latent random variable and the noise processes as a communication channel with unknown parameters.\nRobust Learning Objective Besides the most commonly used cross entropy loss, some other choices of learning objectives are shown to be more robust to noisy labels.\nFor example, MAE (mean absolute error) is more robust to noisy labels than CCE (categorical cross entropy), as it treats every sample equally (Ghosh et al. 2017). Lack of different weighting among training samples of MAE lead to significantly longer training time. Motivated by the tradeoff between MAE and CCE, Zhang \u0026amp; Sabuncu (2018) proposed generalized cross entropy (GCE), a generalization of CCE loss to be robust to noisy data.\nTo exploit the benefits of both the noise-robustness provided by MAE and the implicit weighting scheme of CCE, GCE adopts the the negative Box-Cox transformation as a loss function:\n$$ \\mathcal{L}_q(f(\\mathbf{x}_i, y_i = j)) = \\frac{1 - f^{(j)}(\\mathbf{x}_i)^q}{q} $$\nwhere $f^{(j)}$ denotes the $j$-th element of $f(.)$ and $q \\in (0, 1]$. $\\mathcal{L}_q$ is equivalent to CCE when $q \\to 0$ and becomes MAE when $q=1$. Empirical experiments show that there exists a threshold of $q$ with which overfitting never emerges and the noisier the data the higher such a threshold should be.\nGiven true and predicted labels, $y_i, \\hat{y}_i \\in \\{0, 1\\}$ and let $u_i=y_i \\cdot \\hat{y}_i$, the zero-one loss, $\\mathcal{L}_{01}(\\mathbf{u}) = \\sum_{i=1}^n \\mathbb{1}[u_i \u0026lt; 0]$, is another learning subjective shown to be robust to noisy data. Minimizing the empirical risk with the zero-one loss is shown to be equivalent to minimizing the empirical adversarial (worse-case) risk (Hu et al 2018). Because the worst-case risk is the upper bound of the classification risk of the clean data distribution, minimizing the worst-case risk can lead to decreased true risk, which makes the zero-one loss especially robust. However, the zero-one loss is non-differentiable and cannot be optimized directly. One solution is to approximate an upper bound of the zero-one loss and to minimize the upper bound loss instead.\nThe hinge loss, $\\mathcal{L}_\\text{hinge}(\\mathbf{u}) = \\sum_{i=1}^n \\max(0, 1 - u_i)$, defines a rough upper bound of the zero-one loss. Lyu \u0026amp; Tsang (2020) proposed a curriculum loss (CL), which is a tighter upper bound compared to a conventional surrogate loss like the hinge loss, $\\mathcal{L}_\\text{01}(\\mathbf{u}) \\leq \\mathcal{L}_\\text{CL}(\\mathbf{u}) \\leq \\mathcal{L}_\\text{hinge}(\\mathbf{u})$.\n$$ \\mathcal{L}_\\text{CL}(\\mathbf{u}) = \\min_{\\mathbf{w}\\in\\{0,1\\}^n}\\max(\\sum_{i=1}^n w_i \\ell(u_i), n - \\sum_{i=1}^n w_i + \\sum_{i=1}^n\\mathbb{1}[u_i \u0026lt; 0]) $$\nwhere $\\ell(u_i)$ is a base surrogate loss for the zero-one loss (e.g. hinge loss) and the optimal weighting variable $\\mathbf{w}$ is to be learned.\nGiven a label corruption rate $\\rho$, the noise pruned curriculum loss (NPCL) is constructed based on the intuition that an ideal model should correctly classify $n(1-\\rho)$ samples with clean labels but misclassify $n\\rho$ corrupted labels. If $\\rho$ is a known prior, we would know how many samples (with largest losses) to be pruned. Assuming $\\ell(u_1) \\leq \\dots \\leq \\ell(u_n)$, then $u_{n(1-\\rho)+1} = \\dots = u_n =0$ and the following NPCL is the basic CL for only $n(1-\\rho)$ samples:\n$$ \\text{NPCL}(\\mathbf{u}) = \\min_{\\mathbf{w}\\in\\{0,1\\}^{n(1-\\rho)}} \\max(\\sum_{i=1}^{n(1-\\rho)} w_i \\ell(u_i), n(1-\\rho) - \\sum_{i=1}^{n(1-\\rho)} w_i) $$\nWhen experimenting on CIFAR-10, NPCL is comparable with GCE and performs better when the noise rate increases.\nLabel Correction Since it is known some labels are incorrect, noise-robust training can explicitly take the label correction into consideration.\nOne approach is to rely on the estimation of a noise transition matrix and use that to correct the forward or backward loss, named F-correction (Patrini et al. 2017). Let’s first assume that there are $k$ classes and the noise transition matrix $C \\in [0, 1]^{k\\times k}$ is observable and the label flipping probability does not depend on the sample input but only the label (i.e. known as random classification noise, RCN). Let $\\tilde{y}$ denote a corrupted label. Each entry of $C$ represents the probability of one label flipping to another1,\n$$ C_{ij} = p(\\tilde{y}= j \\vert y =i, \\mathbf{x}) \\approx p(\\tilde{y}= j \\vert y =i) $$\nThen we can proceed a forward label correction procedure to incorporate the prior knowledge of noisy transition matrix into the prediction.\n$$ \\begin{aligned} \\mathcal{L}(\\hat{p}(\\tilde{y}\\vert\\mathbf{x}), y) \u0026amp;= - \\log \\hat{p}(\\tilde{y}=i\\vert\\mathbf{x}) \\\\ \u0026amp;= - \\log \\sum_{j=1}^k p(\\tilde{y}=i\\vert y=j) \\hat{p}(y=j\\vert\\mathbf{x}) \\\\ \u0026amp;= - \\log \\sum_{j=1}^k C_{ji} \\hat{p}(y=j\\vert\\mathbf{x}) \\end{aligned} $$\nIn matrix form, we have $\\mathcal{L}(\\hat{p}(y \\vert \\mathbf{x})) = - \\log C^\\top \\hat{p}(y \\vert \\mathbf{x})$. However, such a noise transition matrix is usually unknown. If we have access to a clean dataset, the noise matrix $C$ can be estimated (Hendrycks et al. 2018) by calculating confusion matrix on the clean data. Let’s denote a clean trusted dataset as $\\mathcal{D}_c$ and a noisy dataset as $\\mathcal{D}_n$ going forward.\n$$ \\hat{C}_{ij} = \\frac{1}{\\vert \\mathcal{A}_i\\vert} \\sum_{\\mathbf{x} \\in \\mathcal{A}_i} \\hat{p}(\\tilde{y}=j \\vert y=i, \\mathbf{x}) \\approx p(\\tilde{y}=j \\vert y=i) $$\nwhere $\\mathcal{A}_i$ is a subset of data points from $\\mathcal{D}_c$ with label $i$.\nLet $f(x) = \\hat{p}(\\tilde{y} \\vert \\mathbf{x}; \\theta)$ and this model should be trained with $\\mathcal{L}(f(\\mathbf{x}), y)$ on clean data $\\mathcal{D}_c$ and with $\\mathcal{L}(\\hat{C}^\\top f(\\mathbf{x}), \\hat{y})$ on noisy data $\\mathcal{D}_n$.\nFig. 11. Algorithm of gold loss correction (GLC), estimating the noise transition matrix with a trusted dataset. (Image source: Hendrycks et al. 2018) If the trusted training dataset $\\mathcal{D}_c$ gets large, we can train a neural network only on clean data and distill its knowledge into the primary model (i.e. the final model to make predictions at test time) using corrected pseudo labels (Li et al. 2017). The primary model is trained on the entire dataset, $\\mathcal{D} = \\mathcal{D}_c \\cup \\mathcal{D}_n$. Optionally the \u0026ldquo;side\u0026rdquo; information of label relations in the knowledge graph, if available, can be incorporated into distillation to help the robustness of the predictions of the network that is trained on limited data.\nThe label correction distillation works as following:\n First train an auxiliary model $f_c$ from the small clean dataset $\\mathcal{D}_c$ to provide a soft label for each sample $x_i$, $s_i = \\delta(f_c(\\mathbf{x}_i)/T)$ is the sigmoid activation with temperature $T$. Because the clean dataset is not large, $f_c$ is likely to overfit, Li et al. (2017) turn to a knowledge graph $\\mathcal{G}$ that defines the relations in the label space and propagate the prediction among labels accordingly. The new soft label is donated as $\\hat{s}_i = \\mathcal{G}(s_i)$. The primary model $f$ is trained with predictions from $f_c$ to imitate, $$ \\mathcal{L}(y_i, f(\\mathbf{x}_i)) = \\text{CE}(\\underbrace{\\lambda y_i + (1 - \\lambda) \\hat{s}_i}_\\text{pseudo label}, f(\\mathbf{x}_i)) $$\nSample Reweighting and Selection Some samples may be more likely to have inaccurate labels than others. Such estimation gives us intuition on which samples should be weighted less or more in the loss function. However, considering two types of biases in training data, class imbalance and noisy labels, there is actually a contradictory preference \u0026mdash; We would prefer samples with larger loss to balance the label distribution but those with smaller loss for mitigating the potential noise. Some work (Ren et al. 2018) thus argue that in order to learn general forms of training data biases, it is necessary to have a small unbiased validation to guide training. The sample reweighting methods presented in this section all assume access to a small trusted set of clean data.\nConsidering a binary classification task with random classification noise, $y, \\hat{y} \\in \\{-1, +1\\}$, the label flipping probabilities, $\\rho_{-1}, \\rho_{+1} \\in [0, 0.5)$, are defined as:\n$$ \\rho_{-1} = P(\\tilde{y} = +1 \\vert y=-1)\\quad\\rho_{+1} = P(\\tilde{y}=-1 \\vert y =+1) $$\nLiu \u0026amp; Tao (2015) applies importance reweighting to adjust the weighted distribution of observed $\\hat{y}$ to match the distribution of unobservable $y$. Let $\\mathcal{D}$ be the true data distribution and $\\mathcal{D}_\\rho$ be the corrupted version.\n$$ \\begin{aligned} \\mathcal{L}_{\\ell,\\mathcal{D}}(f) \u0026amp;= \\mathbb{E}_{(\\mathbf{x},y)\\sim \\mathcal{D}}[\\ell(f(\\mathbf{x}), y)] \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} \\Big[ \\frac{P_\\mathcal{D}(\\mathbf{x}, y=\\tilde{y})}{P_{\\mathcal{D}_\\rho}(\\mathbf{x}, \\tilde{y})} \\ell(f(\\mathbf{x}), \\tilde{y}) \\Big] \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} \\Big[ \\frac{P_\\mathcal{D}(y=\\tilde{y} \\vert \\mathbf{x})}{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} \\ell(f(\\mathbf{x}), \\tilde{y}) \\Big] \u0026amp; \\text{; because }P_\\mathcal{D}(\\mathbf{x})=P_{\\mathcal{D}_\\rho}(\\mathbf{x}) \\\\ \u0026amp;= \\mathbb{E}_{(\\mathbf{x},\\tilde{y})\\sim \\mathcal{D}_\\rho} [ w(\\mathbf{x}, \\hat{y})\\ell(f(\\mathbf{x}), \\tilde{y}) ] = \\mathcal{L}_{w\\ell,\\mathcal{D}}(f) \\end{aligned} $$\nBecause,\n$$ \\begin{aligned} P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x}) \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y=\\tilde{y}) + P_\\mathcal{D}(y = - \\tilde{y} \\vert \\mathbf{x}) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y = - \\tilde{y}) \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) (1 - P_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert y=\\tilde{y})) + (1 - P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert y = - \\tilde{y}) \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x}) (1 - \\rho_{\\tilde{y}}) + (1 - P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})) \\rho_{-\\tilde{y}} \\\\ \u0026amp;= P_\\mathcal{D}(y = \\tilde{y} \\vert \\mathbf{x})(1 - \\rho_{\\tilde{y}} - \\rho_{-\\tilde{y}}) + \\rho_{-\\tilde{y}} \\end{aligned} $$\nThus the weight assigned to a noisy sample is,\n$$ w(x, \\tilde{y}) = \\frac{P_\\mathcal{D}(y=\\tilde{y} \\vert \\mathbf{x})}{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} = \\frac{P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x}) - \\rho_{-\\tilde{y}}}{(1-\\rho_0-\\rho_1) P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})} $$\nwhere $P_{\\mathcal{D}_\\rho}(\\tilde{y} \\vert \\mathbf{x})$ can be estimated using a simple logistic regression, but estimating the note rates is more challenging. Naive cross-validation can work out but is costly as the quality depends on the amount of trusted labels available. The paper approximates the upper bounds for noise rates first, $\\rho_\\tilde{y} \\leq P_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert \\mathbf{x})$ and then use a mild assumption to efficiently estimate them, $\\hat{\\rho}_{\\tilde{y}} = \\min_{\\mathbf{x} \\in {\\mathbf{x}_1, \\dots, \\mathbf{x}_n}} \\hat{P}_{\\mathcal{D}_\\rho}(- \\tilde{y} \\vert \\mathbf{x})$. In their experiments, the advantage of importance reweighting only varies across datasets and is more beneficial when the noise rates are high in general.\nSample reweighting schemes can be learned by a separate network. Learning to reweight (L2R; Ren et al. 2018) is a meta-learning approach to directly optimize the weights in pursuit of best validation performance on a known set of clean data. Each example gets assigned with the weight based on its gradient direction. The weighted loss to minimize $\\theta^*(\\mathbf{w})$ involves a set of training weights $\\{w_i\\}_{i=1}^n$ as unknown hyperparameters. These sample training weights $w_i$ are learned to minimize the loss on this unbiased validate set, $\\mathcal{D}_c = \\{x^\\text{valid}_j\\}_{j=1}^m$.\n$$ \\begin{aligned} \\theta^{*}(\\mathbf{w}) \u0026amp;= \\arg\\min_\\theta \\sum_{i=1}^n w_i f(x_i; \\theta) \\\\ \\text{where optimal }\\mathbf{w}^{*} \u0026amp;= \\arg\\min_{\\mathbf{w}, \\mathbf{w} \\geq \\mathbf{0}} \\frac{1}{m} \\sum_{j=1}^m f(\\mathbf{x}^\\text{valid}_j; \\theta^{*}(\\mathbf{w})) \\end{aligned} $$\nThe learning process involves two nested loops of optimization, so pretty expensive, 3x training time.\nFig. 12. Illustration of updates implemented by second order automatic differentiation. (Image source: Ren et al. 2018) They ran experiments on (1) two-class MNIST to test the robustness of L2R when the class distribution is imbalanced and (2) CIFAR-10 with noisy labels. L2R is shown to be better than other baseline methods at the time on both tasks.\nFig. 13. Left: Imbalanced classes on MNIST (class 4 and 9); Right: Effect of the number of clean samples. Task is on CIFAR-10 with 40% of data flipped to label 3. (Image source: Ren et al. 2018) MentorNet (Jiang et al. 2018) uses teach-student curriculum learning to weight data. It incorporates two different networks, a mentor and a student. The mentor network provides a data-driven curriculum (i.e. sample training weighting scheme) for the student to focus on learning likely correct labels.\nLet $g_\\psi$ be the MentorNet parameterized by $\\psi$ , $f_\\theta$ be the StudentNet parametrized by $\\theta$ and $G$ be a predefined curriculum parameterized by $\\lambda$. Given the training data $\\mathcal{D} = \\{(\\mathbf{x}_i, y_i)\\}_{i=1}^n$ for a $k$-class classification task, the MentorNet needs to predict a time-varying latent weight variable $\\mathbf{w} \\in [0, 1]^{n \\times k}$ to guide the learning of StudentNet, taking an intermediate feature processed by StudentNet $f$ , $\\mathbf{z}_i = \\phi_{f_\\theta}(\\mathbf{x}_i, y_i)$:\n$$ g_{\\psi^{*}}(\\mathbf{z}_i) = \\arg\\min_{w_i \\in [0,1]} \\mathcal{L}(\\theta, \\mathbf{w}), \\forall i \\in [1, n] $$\nStudentNet learns to minimize the following learning objective,\n$$ \\begin{aligned} \\mathcal{L}(\\theta, \\mathbf{w}) \u0026amp;= \\frac{1}{n}\\sum_{i=1}^n \\mathbf{w}_i^\\top \\ell(y_i, f_\\theta(\\mathbf{x}_i)) + G_\\lambda(\\mathbf{w}) + \\alpha |\\theta|^2_2 \\\\ \u0026amp;= \\frac{1}{n}\\sum_{i=1}^n g_\\psi(\\mathbf{z}_i)^\\top \\ell_i + G_\\lambda(\\mathbf{w}) + \\alpha |\\theta|^2_2 \u0026amp; \\text{; Let }\\ell_i = \\ell(y_i, f_\\theta(\\mathbf{x}_i)) \\\\ \\end{aligned} $$\nThe mentor network $g_\\psi$ is trained with cross entropy on the input $(\\phi_{f_\\theta}(\\mathbf{x}_i, y_i), w^{*}_i)$ , where $v^*_i=1$ if $y_i$ is known to be a correct label, otherwise 0. The architecture of MentorNet does not have to be very complicated. In the paper, they adopted a LSTM layer to capture the prediction variance in time.\nFig. 14. Model architecture of MentorNet and StudentNet which are trained simultaneously, where MentorNet predicts the sample weights for StudentNet to train on. (Image source: Jiang et al. 2018) Different from MentorNet where one network explicitly learns weighting scheme and curriculum for the other network, Co-teaching (Han et al. 2018) trains two neural networks, $f_1$ and $f_2$, simultaneously and lets them teach each other by feeding data to each other selectively. Co-teaching consists of three steps:\n First, each network feeds forward the current mini-batch and selects samples with potentially clean labels; Then two networks exchange information on which samples in the batch should be used for training. Small-loss instances are selected as they are more likely to be associated with correct labels. The percentage of the batch to select is determined by a time-dependent function $R(T)$. The value of $R(T)$ decreases in time because the network is more likely to overfit and memorize noisy labels as training progresses and thus we use a smaller sampling percentage to keep the selected data quality high. Finally, each network runs back-propagation updates with the data selected by its peer. According to their experiments, co-teaching performs better than F-correction where the noise rates are high or the corruption transition matrix is not symmetric.\nFig. 15. Algorithm of co-teaching in which two networks are trained separately in parallel and each selects samples for the other to train on. (Image source: Han et al. 2018) Citation Cited as:\n Weng, Lilian. (Apr 2022). Learning with not enough data part 3: data generation. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-04-15-data-gen/.\n Or\n@article{weng2022datagen, title = \u0026quot;Learning with not Enough Data Part 3: Data Generation\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;Lil'Log\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Apr\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-04-15-data-gen/\u0026quot; } Reference [1] Zhang et al. \u0026ldquo;Adversarial AutoAgument\u0026rdquo; ICLR 2020.\n[2] Kumar et al. \u0026ldquo;Data Augmentation using Pre-trained Transformer Models.\u0026quot; AACL 2020 Workshop.\n[3] Anaby-Tavor et al. \u0026ldquo;Not enough data? Deep learning to rescue!\u0026quot; AAAI 2020.\n[4] Wang et al. \u0026ldquo;Want To Reduce Labeling Cost? GPT-3 Can Help.\u0026quot; EMNLP 2021.\n[5] Wang et al. \u0026ldquo;Towards Zero-Label Language Learning.\u0026quot; arXiv preprint arXiv:2109.09193 (2021).\n[6] Schick \u0026amp; Schutze. Generating Datasets with Pretrained Language Models.\u0026quot; EMNLP 2021.\n[7] Han et al. \u0026ldquo;Unsupervised Neural Machine Translation with Generative Language Models Only.\u0026quot; arXiv preprint arXiv:2110.05448 (2021).\n[8] Guo et al. \u0026ldquo;Augmenting data with mixup for sentence classification: An empirical study.\u0026quot; arXiv preprint arXiv:1905.08941 (2019).\n[9] Ekin D. Cubuk et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv preprint arXiv:1805.09501 (2018).\n[10] Daniel Ho et al. \u0026ldquo;Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules.\u0026quot; ICML 2019.\n[11] Cubuk \u0026amp; Zoph et al. \u0026ldquo;RandAugment: Practical automated data augmentation with a reduced search space.\u0026quot; arXiv preprint arXiv:1909.13719 (2019).\n[12] Zhang et al. \u0026ldquo;mixup: Beyond Empirical Risk Minimization.\u0026quot; ICLR 2017.\n[13] Yun et al. \u0026ldquo;CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features.\u0026quot; ICCV 2019.\n[14] Kalantidis et al. \u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo; NeuriPS 2020.\n[15] Wei \u0026amp; Zou. \u0026ldquo;EDA: Easy data augmentation techniques for boosting performance on text classification tasks.\u0026quot; EMNLP-IJCNLP 2019.\n[16] Kobayashi. \u0026ldquo;Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations.\u0026quot; NAACL 2018\n[17] Fang et al. \u0026ldquo;CERT: Contrastive self-supervised learning for language understanding.\u0026quot; arXiv preprint arXiv:2005.12766 (2020).\n[18] Gao et al. \u0026ldquo;SimCSE: Simple Contrastive Learning of Sentence Embeddings.\u0026quot; arXiv preprint arXiv:2104.08821 (2020). [code]\n[19] Shen et al. \u0026ldquo;A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation.\u0026quot; arXiv preprint arXiv:2009.13818 (2020) [code]\n[20] Wang \u0026amp; van den Oord. \u0026ldquo;Multi-Format Contrastive Learning of Audio Representations.\u0026quot; NeuriPS Workshop 2020.\n[21] Wu et al. \u0026ldquo;Conditional BERT Contextual Augmentation\u0026rdquo; arXiv preprint arXiv:1812.06705 (2018).\n[22 Zhu et al. \u0026ldquo;FreeLB: Enhanced Adversarial Training for Natural Language Understanding.\u0026quot; ICLR 2020.\n[23] Affinity and Diversity: Quantifying Mechanisms of Data Augmentation Gontijo-Lopes et al. 2020 (https://arxiv.org/abs/2002.08973)\n[24] Song et al. \u0026ldquo;Learning from Noisy Labels with Deep Neural Networks: A Survey.\u0026quot; TNNLS 2020.\n[25] Zhang \u0026amp; Sabuncu. \u0026ldquo;Generalized cross entropy loss for training deep neural networks with noisy labels.\u0026quot; NeuriPS 2018.\n[26] Goldberger \u0026amp; Ben-Reuven. \u0026ldquo;Training deep neural-networks using a noise adaptation layer.\u0026quot; ICLR 2017.\n[27] Sukhbaatar et al. \u0026ldquo;Training convolutional networks with noisy labels.\u0026quot; ICLR Workshop 2015.\n[28] Patrini et al. \u0026ldquo;Making Deep Neural Networks Robust to Label Noise: a Loss Correction Approach\u0026rdquo; CVPR 2017.\n[29] Hendrycks et al. \u0026ldquo;Using trusted data to train deep networks on labels corrupted by severe noise.\u0026quot; NeuriPS 2018.\n[30] Zhang \u0026amp; Sabuncu. \u0026ldquo;Generalized cross entropy loss for training deep neural networks with noisy labels.\u0026quot; NeuriPS 2018.\n[31] Lyu \u0026amp; Tsang. \u0026ldquo;Curriculum loss: Robust learning and generalization against label corruption.\u0026quot; ICLR 2020.\n[32] Han et al. \u0026ldquo;Co-teaching: Robust training of deep neural networks with extremely noisy labels.\u0026quot; NeuriPS 2018. (code)\n[33] Ren et al. \u0026ldquo;Learning to reweight examples for robust deep learning.\u0026quot; ICML 2018.\n[34] Jiang et al. \u0026ldquo;MentorNet: Learning data-driven curriculum for very deep neural networks on corrupted labels.\u0026quot; ICML 2018.\n[35] Li et al. \u0026ldquo;Learning from noisy labels with distillation.\u0026quot; ICCV 2017.\n[36] Liu \u0026amp; Tao. \u0026ldquo;Classification with noisy labels by importance reweighting.\u0026quot; TPAMI 2015.\n[37] Ghosh, et al. \u0026ldquo;Robust loss functions under label noise for deep neural networks.\u0026quot; AAAI 2017.\n[38] Hu et al. \u0026ldquo;Does Distributionally Robust Supervised Learning Give Robust Classifiers? \u0026ldquo; ICML 2018.\n $y=i$ is not a technically correct way to annotate a label being a certain value, since we usually use one-hot encoding (i.e. $\\mathbf{y} = \\mathbf{e}_i$). We use this form for simplicity.\u0026#160;\u0026#x21a9;\u0026#xfe0e;\n ","permalink":"https://lilianweng.github.io/posts/2022-04-15-data-gen/","summary":"Here comes the Part 3 on learning with not enough data (Previous: Part 1 and Part 2). Let’s consider two approaches for generating synthetic data for training.\n Augmented data. Given a set of existing training samples, we can apply a variety of augmentation, distortion and transformation to derive new data points without losing the key attributes. We have covered a bunch of augmentation methods on text and images in a previous post on contrastive learning.","title":"Learning with not Enough Data Part 3: Data Generation"},{"content":"This is part 2 of what to do when facing a limited amount of labeled data for supervised learning tasks. This time we will get some amount of human labeling work involved, but within a budget limit, and therefore we need to be smart when selecting which samples to label.\nNotations Symbol Meaning $K$ Number of unique class labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^K$ Labeled dataset. $y$ is a one-hot representation of the true label. $\\mathbf{u} \\sim \\mathcal{U}$ Unlabeled dataset. $\\mathcal{D} = \\mathcal{X} \\cup \\mathcal{U}$ The entire dataset, including both labeled and unlabeled examples. $\\mathbf{x}$ Any sample which can be either labeled or unlabeled. $\\mathbf{x}_i$ The $i$-th sample. $U(\\mathbf{x})$ Scoring function for active learning selection. $P_\\theta(y \\vert \\mathbf{x})$ A softmax classifier parameterized by $\\theta$. $\\hat{y} = \\arg\\max_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x})$ The most confident prediction by the classifier. $B$ Labeling budget (the maximum number of samples to label). $b$ Batch size. What is Active Learning? Given an unlabeled dataset $\\mathcal{U}$ and a fixed amount of labeling cost $B$, active learning aims to select a subset of $B$ examples from $\\mathcal{U}$ to be labeled such that they can result in maximized improvement in model performance. This is an effective way of learning especially when data labeling is difficult and costly, e.g. medical images. This classical survey paper in 2010 lists many key concepts. While some conventional approaches may not apply to deep learning, discussion in this post mainly focuses on deep neural models and training in batch mode.\nFig. 1. Illustration of a cyclic workflow of active learning, producing better models more efficiently by smartly choosing which samples to label. To simplify the discussion, we assume that the task is a $K$-class classification problem in all the following sections. The model with parameters $\\theta$ outputs a probability distribution over the label candidates, which may or may not be calibrated, $P_\\theta(y \\vert \\mathbf{x})$ and the most likely prediction is $\\hat{y} = \\arg\\max_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x})$.\nAcquisition Function The process of identifying the most valuable examples to label next is referred to as \u0026ldquo;sampling strategy\u0026rdquo; or \u0026ldquo;query strategy\u0026rdquo;. The scoring function in the sampling process is named \u0026ldquo;acquisition function\u0026rdquo;, denoted as $U(\\mathbf{x})$. Data points with higher scores are expected to produce higher value for model training if they get labeled.\nHere is a list of basic sampling strategies.\nUncertainty Sampling Uncertainty sampling selects examples for which the model produces most uncertain predictions. Given a single model, uncertainty can be estimated by the predicted probabilities, although one common complaint is that deep learning model predictions are often not calibrated and not correlated with true uncertainty well. In fact, deep learning models are often overconfident.\n Least confident score, also known as variation ratio: $U(\\mathbf{x}) = 1 - P_\\theta(\\hat{y} \\vert \\mathbf{x})$. Margin score: $U(\\mathbf{x}) = P_\\theta(\\hat{y}_1 \\vert \\mathbf{x}) - P_\\theta(\\hat{y}_2 \\vert \\mathbf{x})$, where $\\hat{y}_1$ and $\\hat{y}_2$ are the most likely and the second likely predicted labels. Entropy: $U(\\mathbf{x}) = \\mathcal{H}(P_\\theta(y \\vert \\mathbf{x})) = - \\sum_{y \\in \\mathcal{Y}} P_\\theta(y \\vert \\mathbf{x}) \\log P_\\theta(y \\vert \\mathbf{x})$. Another way to quantify uncertainty is to rely on a committee of expert models, known as Query-By-Committee (QBC). QBC measures uncertainty based on a pool of opinions and thus it is critical to keep a level of disagreement among committee members. Given $C$ models in the committee pool, each parameterized by $\\theta_1, \\dots, \\theta_C$.\n Voter entropy: $U(\\mathbf{x}) = \\mathcal{H}(\\frac{V(y)}{C})$, where $V(y)$ counts the number of votes from the committee on the label $y$. Consensus entropy: $U(\\mathbf{x}) = \\mathcal{H}(P_\\mathcal{C})$, where $P_\\mathcal{C}$ is the prediction averaging across the committee. KL divergence: $U(\\mathbf{x}) = \\frac{1}{C} \\sum_{c=1}^C D_\\text{KL} (P_{\\theta_c} | P_\\mathcal{C})$ Diversity Sampling Diversity sampling intend to find a collection of samples that can well represent the entire data distribution. Diversity is important because the model is expected to work well on any data in the wild, just not on a narrow subset. Selected samples should be representative of the underlying distribution. Common approaches often rely on quantifying the similarity between samples.\nExpected Model Change Expected model change refers to the impact that a sample brings onto the model training. The impact can be the influence on the model weights or the improvement over the training loss. A later section reviews several works on how to measure model impact triggered by selected data samples.\nHybrid Strategy Many methods above are not mutually exclusive. A hybrid sampling strategy values different attributes of data points, combining different sampling preferences into one. Often we want to select uncertain but also highly representative samples.\nDeep Acquisition Function Measuring Uncertainty The model uncertainty is commonly categorized into two buckets (Der Kiureghian \u0026amp; Ditlevsen 2009, Kendall \u0026amp; Gal 2017):\n Aleatoric uncertainty is introduced by noise in the data (e.g. sensor data, noise in the measurement process) and it can be input-dependent or input-independent. It is generally considered as irreducible since there is missing information about the ground truth. Epistemic uncertainty refers to the uncertainty within the model parameters and therefore we do not know whether the model can best explain the data. This type of uncertainty is theoretically reducible given more data Ensemble and Approximated Ensemble There is a long tradition in machine learning of using ensembles to improve model performance. When there is a significant diversity among models, ensembles are expected to yield better results. This ensemble theory is proved to be correct by many ML algorithms; for example, AdaBoost aggregates many weak learners to perform similar or even better than a single strong learner. Bootstrapping ensembles multiple trials of resampling to achieve more accurate estimation of metrics. Random forests or GBM is also a good example for the effectiveness of ensembling.\nTo get better uncertainty estimation, it is intuitive to aggregate a collection of independently trained models. However, it is expensive to train a single deep neural network model, let alone many of them. In reinforcement learning, Bootstrapped DQN (Osband, et al. 2016) is equipped with multiple value heads and relies on the uncertainty among an ensemble of Q value approximation to guide exploration in RL.\nIn active learning, a commoner approach is to use dropout to \u0026ldquo;simulate\u0026rdquo; a probabilistic Gaussian process (Gal \u0026amp; Ghahramani 2016). We thus ensemble multiple samples collected from the same model but with different dropout masks applied during the forward pass to estimate the model uncertainty (epistemic uncertainty). The process is named MC dropout (Monte Carlo dropout), where dropout is applied before every weight layer, is approved to be mathematically equivalent to an approximation to the probabilistic deep Gaussian process (Gal \u0026amp; Ghahramani 2016). This simple idea has been shown to be effective for classification with small datasets and widely adopted in scenarios when efficient model uncertainty estimation is needed.\nDBAL (Deep Bayesian active learning; Gal et al. 2017) approximates Bayesian neural networks with MC dropout such that it learns a distribution over model weights. In their experiment, MC dropout performed better than random baseline and mean standard deviation (Mean STD), similarly to variation ratios and entropy measurement.\nFig. 2. Active learning results of DBAL on MNIST. (Image source: Gal et al. 2017). Beluch et al. (2018) compared ensemble-based models with MC dropout and found that the combination of naive ensemble (i.e. train multiple models separately and independently) and variation ratio yields better calibrated predictions than others. However, naive ensembles are very expensive, so they explored a few alternative cheaper options:\n Snapshot ensemble: Use a cyclic learning rate schedule to train an implicit ensemble such that it converges to different local minima. Diversity encouraging ensemble (DEE): Use a base network trained for a small number of epochs as initialization for $n$ different networks, each trained with dropout to encourage diversity. Split head approach: One base model has multiple heads, each corresponding to one classifier. Unfortunately all the cheap implicit ensemble options above perform worse than naive ensembles. Considering the limit on computational resources, MC dropout is still a pretty good and economical choice. Naturally, people also try to combine ensemble and MC dropout (Pop \u0026amp; Fulop 2018) to get a bit of additional performance gain by stochastic ensemble.\nUncertainty in Parameter Space Bayes-by-backprop (Blundell et al. 2015) measures weight uncertainty in neural networks directly. The method maintains a probability distribution over the weights $\\mathbf{w}$, which is modeled as a variational distribution $q(\\mathbf{w} \\vert \\theta)$ since the true posterior $p(\\mathbf{w} \\vert \\mathcal{D})$ is not tractable directly. The loss is to minimize the KL divergence between $q(\\mathbf{w} \\vert \\theta)$ and $p(\\mathbf{w} \\vert \\mathcal{D})$,\n $$ \\begin{aligned} \\mathcal{L}(\\theta) \u0026= \\text{KL}[q(\\mathbf{w}\\vert\\theta) \\| p(\\mathbf{w} \\vert \\mathcal{D})] \\\\ \u0026= \\int q(\\mathbf{w}\\vert\\theta) \\log \\frac{q(\\mathbf{w}\\vert\\theta)}{p(\\mathbf{w}) p(\\mathcal{D}\\vert \\mathbf{w})} d\\mathbf{w} \\\\ \u0026= \\text{KL}[q(\\mathbf{w}\\vert\\theta) \\| p(w)] - \\mathbb{E}_{q(\\mathbf{w}\\vert\\theta)} [\\log p(\\mathcal{D} \\vert \\mathbf{w})] \\\\ \u0026\\approx \\log q(\\mathbf{w} \\vert \\theta) - \\log p(\\mathbf{w}) p(\\mathcal{D}\\vert \\mathbf{w}) \u0026 \\text{; monte carlo sampling; }q(\\mathbf{w} \\vert \\theta)\\text{ \u0026 }p(\\mathbf{w})\\text{ are close.} \\end{aligned} $$ The variational distribution $q$ is typically a Gaussian with diagonal covariance and each weight is sampled from $\\mathcal{N}(\\mu_i, \\sigma_i^2)$. To ensure non-negativity of $\\sigma_i$, it is further parameterized via softplus, $\\sigma_i = \\log(1 + \\exp(\\rho_i))$ where the variational parameters are $\\theta = \\{\\mu_i , \\rho_i\\}^d_{i=1}$.\nThe process of Bayes-by-backprop can be summarized as:\n Sample $\\epsilon \\sim \\mathcal{N}(0, I)$ Let $\\mathbf{w} = \\mu + \\log(1+ \\exp(\\rho)) \\circ \\epsilon$ Let $\\theta = (\\mu, \\rho)$ Let $f(\\mathbf{w}, \\theta) = \\log q(\\mathbf{w} \\vert \\theta) - \\log p(\\mathbf{w})p(\\mathcal{D}\\vert \\mathbf{w})$ Calculate the gradient of $f(\\mathbf{w}, \\theta)$ w.r.t. to $\\mu$ and $\\rho$ and then update $\\theta$. Uncertainty is measured by sampling different model weights during inference. Loss Prediction The loss objective guides model training. A low loss value indicates that a model can make good and accurate predictions. Yoo \u0026amp; Kweon (2019) designed a loss prediction module to predict the loss value for unlabeled inputs, as an estimation of how good a model prediction is on the given data. Data samples are selected if the loss prediction module makes uncertain predictions (high loss value) for them. The loss prediction module is a simple MLP with dropout, that takes several intermediate layer features as inputs and concatenates them after a global average pooling.\nFig. 3. Use the model with a loss prediction module to do active learning selection. (Image source: Yoo \u0026 Kweon 2019) Let $\\hat{l}$ be the output of the loss prediction module and $l$ be the true loss. When training the loss prediction module, a simple MSE loss $=(l - \\hat{l})^2$ is not a good choice, because the loss decreases in time as the model learns to behave better. A good learning objective should be independent of the scale changes of the target loss. They instead rely on the comparison of sample pairs. Within each batch of size $b$, there are $b/2$ pairs of samples $(\\mathbf{x}_i, \\mathbf{x}_j)$ and the loss prediction model is expected to correctly predict which sample has a larger loss.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loss}(\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\max\\big( 0, -\\mathbb{1}(l(\\mathbf{x}_i), l(\\mathbf{x}_j)) \\cdot (\\hat{l}(\\mathbf{x}_i) - \\hat{l}(\\mathbf{x}_j)) + \\epsilon \\big) \\\\ \\text{where } \\mathbb{1}(l_i, l_j) \u0026= \\begin{cases} +1 \u0026 \\text{if }l_i l_j \\\\ -1 \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ where $\\epsilon$ is a predefined positive margin constant.\nIn experiments on three vision tasks, active learning selection based on the loss prediction performs better than random baseline, entropy based acquisition and core-set.\nFig. 4. Active learning results of loss prediction module based selection, in comparison with other approaches. (Image source: Yoo \u0026 Kweon 2019) Adversarial Setup Sinha et al. (2019) proposed a GAN-like setup, named VAAL (Variational Adversarial Active Learning), where a discriminator is trained to distinguish unlabeled data from labeled data. Interestingly, active learning acquisition criteria does not depend on the task performance in VAAL.\nFig. 5. Illustration of VAAL (Variational adversarial active learning). (Image source: Sinha et al. 2019) The $\\beta$-VAE learns a latent feature space $\\mathbf{z}^l \\cup \\mathbf{z}^u$, for labeled and unlabeled data respectively, aiming to trick the discriminator $D(.)$ that all the data points are from the labeled pool; The discriminator $D(.)$ predicts whether a sample is labeled (1) or not (0) based on a latent representation $\\mathbf{z}$. VAAL selects unlabeled samples with low discriminator scores, which indicates that those samples are sufficiently different from previously labeled ones. The loss for VAE representation learning in VAAL contains both a reconstruction part (minimizing the ELBO of given samples) and an adversarial part (labeled and unlabeled data is drawn from the same probability distribution $q_\\phi$):\n $$ \\begin{aligned} \\mathcal{L}_\\text{VAE} \u0026= \\lambda_1 \\mathcal{L}^\\text{rec}_\\text{VAE} + \\lambda_2 \\mathcal{L}^\\text{adv}_\\text{VAE} \\\\ \\mathcal{L}^\\text{rec}_\\text{VAE} \u0026= \\mathbb{E}[\\log p_\\theta(\\mathbf{x}^l \\vert \\mathbf{z}^l)] - \\beta \\text{KL}(q_\\phi(\\mathbf{z}^l \\vert \\mathbf{x}^l) \\| p(\\mathbf{\\tilde{z}})) + \\mathbb{E}[\\log p_\\theta(\\mathbf{u} \\vert \\mathbf{z}^u)] - \\beta \\text{KL}(q_\\phi(\\mathbf{z}^u \\vert \\mathbf{u}) \\| p(\\mathbf{\\tilde{z}})) \\\\ \\mathcal{L}^\\text{adv}_\\text{VAE} \u0026= - \\mathbb{E}[\\log D(q_\\phi (\\mathbf{z}^l \\vert \\mathbf{x}^l))] - \\mathbb{E}[\\log D(q_\\phi(\\mathbf{z}^u \\vert \\mathbf{u}))] \\end{aligned} $$ where $p(\\mathbf{\\tilde{z}})$ is a unit Gaussian as a predefined prior and $\\beta$ is the Lagrangian parameter.\nThe discriminator loss is:\n $$ \\mathcal{L}_D = -\\mathbb{E}[\\log D(q_\\phi (\\mathbf{z}^l \\vert \\mathbf{x}^l))] - \\mathbb{E}[\\log (1 - D(q_\\phi (\\mathbf{z}^u \\vert \\mathbf{u})))] $$ Fig. 6. Experiment results of VAAL (variational adversarial active learning) on several image classification tasks. (Image source: Sinha et al. 2019 Ablation studies showed that jointly training VAE and discriminator is critical. Their results are robust to the biased initial labeled pool, different labeling budgets and noisy oracle.\nMAL (Minimax Active Learning; Ebrahimiet al. 2021) is an extension of VAAL. The MAL framework consists of an entropy minimizing feature encoding network $F$ followed by an entropy maximizing classifier $C$. This minimax setup reduces the distribution gap between labeled and unlabeled data.\nFig. 7. Illustration of the MAL (minimax active learning) framework. (Image source: Ebrahimiet al. 2021) A feature encoder $F$ encodes a sample into a $\\ell_2$-normalized $d$-dimensional latent vector. Assuming there are $K$ classes, a classifier $C$ is parameterized by $\\mathbf{W} \\in \\mathbb{R}^{d \\times K}$.\n(1) First $F$ and $C$ are trained on labeled samples by a simple cross entropy loss to achieve good classification results,\n $$ \\mathcal{L}_\\text{CE} = -\\mathbb{E}_{(\\mathbf{x}^l, y) \\sim \\mathcal{X}} \\sum_{k=1}^K \\mathbb{1}[k=y] \\log\\Big( \\sigma(\\frac{1}{T} \\frac{\\mathbf{W}^\\top F\\big(\\mathbf{x}^l)}{\\|F(\\mathbf{x}^l)\\|}\\big) \\Big) $$ (2) When training on the unlabeled examples, MAL relies on a minimax game setup\n $$ \\begin{aligned} \\mathcal{L}_\\text{Ent} \u0026= -\\sum^K_{k=1} p(y=k \\vert \\mathbf{u}) \\log p(y=k\\vert \\mathbf{u}) \\\\ \\theta^*_F, \\theta^*_C \u0026= \\min_F\\max_C \\mathcal{L}_\\text{Ent} \\\\ \\theta_F \u0026\\gets \\theta_F - \\alpha_1 \\nabla \\mathcal{L}_\\text{Ent} \\\\ \\theta_C \u0026\\gets \\theta_C + \\alpha_2 \\nabla \\mathcal{L}_\\text{Ent} \\end{aligned} $$ where,\n First, minimizing the entropy in $F$ encourages unlabeled samples associated with similar predicted labels to have similar features. Maximizing the entropy in $C$ adversarially makes the prediction to follow a more uniform class distribution. (My understanding here is that because the true label of an unlabeled sample is unknown, we should not optimize the classifier to maximize the predicted labels just yet.) The discriminator is trained in the same way as in VAAL.\nSampling strategy in MAL considers both diversity and uncertainty:\n Diversity: the score of $D$ indicates how similar a sample is to previously seen examples. A score closer to 0 is better to select unfamiliar data points. Uncertainty: use the entropy obtained by $C$. A higher entropy score indicates that the model cannot make a confident prediction yet. The experiments compared MAL to random, entropy, core-set, BALD and VAAL baselines, on image classification and segmentation tasks. The results look pretty strong.\nFig. 8. Performance of MAL on ImageNet. (Table source: Ebrahimiet al. 2021) CAL (Contrastive Active Learning; Margatina et al. 2021) intends to select contrastive examples. If two data points with different labels share similar network representations $\\Phi(.)$, they are considered as contrastive examples in CAL. Given a pair of contrastive examples $(\\mathbf{x}_i, \\mathbf{x}_j)$, they should\n $$ d(\\Phi(\\mathbf{x}_i), \\Phi(\\mathbf{x}_j)) Given an unlabeled sample $\\mathbf{x}$, CAL runs the following process:\n Select the top $k$ nearest neighbors in the model feature space among the labeled samples, $\\{(\\mathbf{x}^l_i, y_i\\}_{i=1}^M \\subset \\mathcal{X}$. Compute the KL divergence between the model output probabilities of $\\mathbf{x}$ and each in $\\{\\mathbf{x}^l\\}$. The contrastive score of $\\mathbf{x}$ is the average of these KL divergence values: $s(\\mathbf{x}) = \\frac{1}{M} \\sum_{i=1}^M \\text{KL}(p(y \\vert \\mathbf{x}^l_i | p(y \\vert \\mathbf{x}))$. Samples with high contrastive scores are selected for active learning. On a variety of classification tasks, the experiment results of CAL look similar to the entropy baseline.\nMeasuring Representativeness Core-sets Approach A core-set is a concept in computational geometry, referring to a small set of points that approximates the shape of a larger point set. Approximation can be captured by some geometric measure. In the active learning, we expect a model that is trained over the core-set to behave comparably with the model on the entire data points.\nSener \u0026amp; Savarese (2018) treats active learning as a core-set selection problem. Let’s say, there are $N$ samples in total accessible during training. During active learning, a small set of data points get labeled at every time step $t$, denoted as $\\mathcal{S}^{(t)}$. The upper bound of the learning objective can be written as follows, where the core-set loss is defined as the difference between average empirical loss over the labeled samples and the loss over the entire dataset including unlabelled ones.\n $$ \\begin{aligned} \\mathbb{E}_{(\\mathbf{x}, y) \\sim p} [\\mathcal{L}(\\mathbf{x}, y)] \\leq\u0026 \\bigg\\vert \\mathbb{E}_{(\\mathbf{x}, y) \\sim p} [\\mathcal{L}(\\mathbf{x}, y)] - \\frac{1}{N} \\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) \\bigg\\vert \u0026 \\text{; Generalization error}\\\\ +\u0026 \\frac{1}{\\vert \\mathcal{S}^{(t)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \u0026 \\text{; Training error}\\\\ +\u0026 \\bigg\\vert \\frac{1}{N} \\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) - \\frac{1}{\\vert \\mathcal{S}^{(t)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \\bigg\\vert \u0026 \\text{; Core-set error} \\end{aligned} $$ Then the active learning problem can be redefined as:\n $$ \\min_{\\mathcal{S}^{(t+1)} : \\vert \\mathcal{S}^{(t+1)} \\vert \\leq b} \\bigg\\vert \\frac{1}{N}\\sum_{i=1}^N \\mathcal{L}(\\mathbf{x}_i, y_i) - \\frac{1}{\\vert \\mathcal{S}^{(t)} \\cup \\mathcal{S}^{(t+1)} \\vert} \\sum_{j=1}^{\\vert \\mathcal{S}^{(t)} \\cup \\mathcal{S}^{(t+1)} \\vert} \\mathcal{L}(\\mathbf{x}^l_j, y_j) \\bigg\\vert $$ It is equivalent to the $k$-Center problem: choose $b$ center points such that the largest distance between a data point and its nearest center is minimized. This problem is NP-hard. An approximate solution depends on the greedy algorithm.\nFig. 9. Active learning results of core-sets algorithm in comparison with several common baselines on CIFAR-10, CIFAR-100, SVHN. (Image source: Sener \u0026 Savarese 2018) It works well on image classification tasks when there is a small number of classes. When the number of classes grows to be large or the data dimensionality increases (\u0026ldquo;curse of dimensionality\u0026rdquo;), the core-set method becomes less effective (Sinha et al. 2019).\nBecause the core-set selection is expensive, Coleman et al. (2020) experimented with a weaker model (e.g. smaller, weaker architecture, not fully trained) and found that empirically using a weaker model as a proxy can significantly shorten each repeated data selection cycle of training models and selecting samples, without hurting the final error much. Their method is referred to as SVP (Selection via Proxy).\nDiverse Gradient Embedding BADGE (Batch Active learning by Diverse Gradient Embeddings; Ash et al. 2020) tracks both model uncertainty and data diversity in the gradient space. Uncertainty is measured by the gradient magnitude w.r.t. the final layer of the network and diversity is captured by a diverse set of samples that span in the gradient space.\n Uncertainty. Given an unlabeled sample $\\mathbf{x}$, BADGE first computes the prediction $\\hat{y}$ and the gradient $g_\\mathbf{x}$ of the loss on $(\\mathbf{x}, \\hat{y})$ w.r.t. the last layer’s parameters. They observed that the norm of $g_\\mathbf{x}$ conservatively estimates the example\u0026rsquo;s influence on the model learning and high-confidence samples tend to have gradient embeddings of small magnitude. Diversity. Given many gradient embeddings of many samples, $g_\\mathbf{x}$, BADGE runs $k$-means++ to sample data points accordingly. Fig. 10. Algorithm of BADGE (batch active learning by diverse gradient embeddings). (Image source: Ash et al. 2020) Measuring Training Effects Quantify Model Changes Settles et al. (2008) introduced an active learning query strategy, named EGL (Expected Gradient Length). The motivation is to find samples that can trigger the greatest update on the model if their labels are known.\nLet $\\nabla \\mathcal{L}(\\theta)$ be the gradient of the loss function with respect to the model parameters. Specifically, given an unlabeled sample $\\mathbf{x}_i$, we need to calculate the gradient assuming the label is $y \\in \\mathcal{Y}$, $\\nabla \\mathcal{L}^{(y)}(\\theta)$. Because the true label $y_i$ is unknown, EGL relies on the current model belief to compute the expected gradient change:\n $$ \\text{EGL}(\\mathbf{x}_i) = \\sum_{y_i \\in \\mathcal{Y}} p(y=y_i \\vert \\mathbf{x}) \\|\\nabla \\mathcal{L}^{(y_i)}(\\theta)\\| $$ BALD (Bayesian Active Learning by Disagreement; Houlsby et al. 2011) aims to identify samples to maximize the information gain about the model weights, that is equivalent to maximize the decrease in expected posterior entropy.\n $$ \\begin{aligned} I[\\boldsymbol{\\theta}, y \\vert x,\\mathcal{D}] \u0026= H(\\boldsymbol{\\theta} \\vert \\mathcal{D}) - \\mathbb{E}_{y \\sim p(y \\vert \\boldsymbol{x}, \\mathcal{D})} \\big[ H(\\boldsymbol{\\theta} \\vert y, \\boldsymbol{x}, \\mathcal{D}) \\big] \u0026 \\text{; Decrease in expected posterior entropy}\\\\ \u0026= H(y \\vert \\boldsymbol{x}, \\mathcal{D}) - \\mathbb{E}_{\\boldsymbol{\\theta} \\sim p(\\boldsymbol{\\theta} \\vert \\mathcal{D})} \\big[ H(y \\vert \\boldsymbol{x}, \\mathcal{\\theta}) \\big] \\end{aligned} $$ The underlying interpretation is to \u0026ldquo;seek $\\mathbf{x}$ for which the model is marginally most uncertain about $y$ (high $H(y \\vert \\mathbf{x}, \\mathcal{D})$), but for which individual settings of the parameters are confident (low $H(y \\vert \\mathbf{x}, \\boldsymbol{\\theta})$).\u0026rdquo; In other words, each individual posterior draw is confident but a collection of draws carry diverse opinions.\nBALD was originally proposed for an individual sample and Kirsch et al. (2019) extended it to work in batch mode.\nForgetting Events To investigate whether neural networks have a tendency to forget previously learned information, Mariya Toneva et al. (2019) designed an experiment: They track the model prediction for each sample during the training process and count the transitions for each sample from being classified correctly to incorrectly or vice-versa. Then samples can be categorized accordingly,\n Forgettable (redundant) samples: If the class label changes across training epochs. Unforgettable samples: If the class label assignment is consistent across training epochs. Those samples are never forgotten once learned. They found that there are a large number of unforgettable examples that are never forgotten once learnt. Examples with noisy labels or images with \u0026ldquo;uncommon\u0026rdquo; features (visually complicated to classify) are among the most forgotten examples. The experiments empirically validated that unforgettable examples can be safely removed without compromising model performance.\nIn the implementation, the forgetting event is only counted when a sample is included in the current training batch; that is, they compute forgetting across presentations of the same example in subsequent mini-batches. The number of forgetting events per sample is quite stable across different seeds and forgettable examples have a small tendency to be first-time learned later in the training. The forgetting events are also found to be transferable throughout the training period and between architectures.\nForgetting events can be used as a signal for active learning acquisition if we hypothesize a model changing predictions during training is an indicator of model uncertainty. However, ground truth is unknown for unlabeled samples. Bengar et al. (2021) proposed a new metric called label dispersion for such a purpose. Let’s see across the training time, $c^*$ is the most commonly predicted label for the input $\\mathbf{x}$ and the label dispersion measures the fraction of training steps when the model does not assign $c^**$ to this sample:\n $$ \\text{Dispersion}(\\mathbf{x}) = 1 - \\frac{f_\\mathbf{x}}{T} \\text{ where } f_\\mathbf{x} = \\sum_{t=1}^T \\mathbb{1}[\\hat{y}_t = c^*], c^* = \\arg\\max_{c=1,\\dots,C}\\sum_{t=1}^T \\mathbb{1}[\\hat{y}_t = c] $$ In their implementation, dispersion is computed at every epoch. Label dispersion is low if the model consistently assigns the same label to the same sample but high if the prediction changes often. Label dispersion is correlated with network uncertainty, as shown in Fig. 11.\nFig. 11. Label dispersion is correlated with network uncertainty. On the x-axis, data points are sorted by label dispersion scores. The y-axis is the model prediction accuracy when the model trys to infer the labels for those samples. (Image source: Bengar et al. 2021) Hybrid When running active learning in batch mode, it is important to control diversity within a batch. Suggestive Annotation (SA; Yang et al. 2017) is a two-step hybrid strategy, aiming to select both high uncertainty \u0026amp; highly representative labeled samples. It uses uncertainty obtained from an ensemble of models trained on the labeled data and core-sets for choosing representative data samples.\n First, SA selects top $K$ images with high uncertainty scores to form a candidate pool $\\mathcal{S}_c \\subseteq \\mathcal{S}_U$. The uncertainty is measured as disagreement between multiple models training with bootstrapping. The next step is to find a subset $\\mathcal{S}_a \\subseteq \\mathcal{S}_c$ with highest representativeness. The cosine similarity between feature vectors of two inputs approximates how similar they are. The representativeness of $\\mathcal{S}_a$ for $\\mathcal{S}_U$ reflects how well $\\mathcal{S}_a$ can represent all the samples in $\\mathcal{S}_u$, defined as: $$ F(\\mathcal{S}_a, \\mathcal{S}_u) = \\sum_{\\mathbf{x}_j \\in \\mathcal{S}_u} f(\\mathcal{S}_a, \\mathbf{x}_j) = \\sum_{\\mathbf{x}_j \\in \\mathcal{S}_u} \\max_{\\mathbf{x}_i \\in \\mathcal{S}_a} \\text{sim}(\\mathbf{x}_i, \\mathbf{x}_j) $$ Formulating $\\mathcal{S}_a \\subseteq \\mathcal{S}_c$ with $k$ data points that maximizes $F(\\mathcal{S}_a, \\mathcal{S}_u)$ is a generalized version of the maximum set cover problem. It is NP-hard and its best possible polynomial time approximation algorithm is a simple greedy method.\n Initially, $\\mathcal{S}_a = \\emptyset$ and $F(\\mathcal{S}_a, \\mathcal{S}_u) = 0$. Then, iteratively add $\\mathbf{x}_i \\in \\mathcal{S}_c$ that maximizes $F(\\mathcal{S}_a \\cup I_i, \\mathcal{S}_u)$ over $\\mathcal{S}_a$, until $\\mathcal{S}_s$ contains $k$ images. Zhdanov (2019) runs a similar process as SA, but at step 2, it relies on $k$-means instead of core-set, where the size of the candidate pool is configured relative to the batch size. Given batch size $b$ and a constant $beta$ (between 10 and 50), it follows these steps:\n Train a classifier on the labeled data; Measure informativeness of every unlabeled example (e.g. using uncertainty metrics); Prefilter top $\\beta b \\geq b$ most informative examples; Cluster $\\beta b$ examples into $B$ clusters; Select $b$ different examples closest to the cluster centers for this round of active learning. Active learning can be further combined with semi-supervised learning to save the budget. CEAL (Cost-Effective Active Learning; Yang et al. 2017) runs two things in parallel:\n Select uncertain samples via active learning and get them labeled; Select samples with the most confident prediction and assign them pseudo labels. The confidence prediction is judged by whether the prediction entropy is below a threshold $\\delta$. As the model is getting better in time, the threshold $\\delta$ decays in time as well. Fig. 12. Illustration of CEAL (cost-effective active learning). (Image source: Yang et al. 2017) Citation Cited as:\n Weng, Lilian. (Feb 2022). Learning with not enough data part 2: active learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2022-02-20-active-learning/.\n Or\n@article{weng2022active, title = \u0026quot;Learning with not Enough Data Part 2: Active Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2022\u0026quot;, month = \u0026quot;Feb\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2022-02-20-active-learning/\u0026quot; } References [1] Burr Settles. Active learning literature survey. University of Wisconsin, Madison, 52(55-66):11, 2010.\n[2] https://jacobgil.github.io/deeplearning/activelearning\n[3] Yang et al. \u0026ldquo;Cost-effective active learning for deep image classification\u0026rdquo; TCSVT 2016.\n[4] Yarin Gal et al. \u0026ldquo;Dropout as a Bayesian Approximation: representing model uncertainty in deep learning.\u0026quot; ICML 2016.\n[5] Blundell et al. \u0026ldquo;Weight uncertainty in neural networks (Bayes-by-Backprop)\u0026quot; ICML 2015.\n[6] Settles et al. \u0026ldquo;Multiple-Instance Active Learning.\u0026quot; NIPS 2007.\n[7] Houlsby et al. Bayesian Active Learning for Classification and Preference Learning.\u0026quot; arXiv preprint arXiv:1112.5745 (2020).\n[8] Kirsch et al. \u0026ldquo;BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning.\u0026quot; NeurIPS 2019.\n[9] Beluch et al. \u0026ldquo;The power of ensembles for active learning in image classification.\u0026quot; CVPR 2018.\n[10] Sener \u0026amp; Savarese. \u0026ldquo;Active learning for convolutional neural networks: A core-set approach.\u0026quot; ICLR 2018.\n[11] Donggeun Yoo \u0026amp; In So Kweon. \u0026ldquo;Learning Loss for Active Learning.\u0026quot; CVPR 2019.\n[12] Margatina et al. \u0026ldquo;Active Learning by Acquiring Contrastive Examples.\u0026quot; EMNLP 2021.\n[13] Sinha et al. \u0026ldquo;Variational Adversarial Active Learning\u0026rdquo; ICCV 2019\n[14] Ebrahimiet al. \u0026ldquo;Minmax Active Learning\u0026rdquo; arXiv preprint arXiv:2012.10467 (2021).\n[15] Mariya Toneva et al. \u0026ldquo;An empirical study of example forgetting during deep neural network learning.\u0026quot; ICLR 2019.\n[16] Javad Zolfaghari Bengar et al. \u0026ldquo;When Deep Learners Change Their Mind: Learning Dynamics for Active Learning.\u0026quot; CAIP 2021.\n[17] Yang et al. \u0026ldquo;Suggestive annotation: A deep active learning framework for biomedical image segmentation.\u0026quot; MICCAI 2017.\n[18] Fedor Zhdanov. \u0026ldquo;Diverse mini-batch Active Learning\u0026rdquo; arXiv preprint arXiv:1901.05954 (2019).\n","permalink":"https://lilianweng.github.io/posts/2022-02-20-active-learning/","summary":"This is part 2 of what to do when facing a limited amount of labeled data for supervised learning tasks. This time we will get some amount of human labeling work involved, but within a budget limit, and therefore we need to be smart when selecting which samples to label.\nNotations Symbol Meaning $K$ Number of unique class labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^K$ Labeled dataset.","title":"Learning with not Enough Data Part 2: Active Learning"},{"content":"When facing a limited amount of labeled data for supervised learning tasks, four approaches are commonly discussed.\n Pre-training + fine-tuning: Pre-train a powerful task-agnostic model on a large unsupervised data corpus, e.g. pre-training LMs on free text, or pre-training vision models on unlabelled images via self-supervised learning, and then fine-tune it on the downstream task with a small set of labeled samples. Semi-supervised learning: Learn from the labelled and unlabeled samples together. A lot of research has happened on vision tasks within this approach. Active learning: Labeling is expensive, but we still want to collect more given a cost budget. Active learning learns to select most valuable unlabeled samples to be collected next and helps us act smartly with a limited budget. Pre-training + dataset auto-generation: Given a capable pre-trained model, we can utilize it to auto-generate a lot more labeled samples. This has been especially popular within the language domain driven by the success of few-shot learning. I plan to write a series of posts on the topic of “Learning with not enough data”. Part 1 is on Semi-Supervised Learning.\nWhat is semi-supervised learning? Semi-supervised learning uses both labeled and unlabeled data to train a model.\nInterestingly most existing literature on semi-supervised learning focuses on vision tasks. And instead pre-training + fine-tuning is a more common paradigm for language tasks.\nAll the methods introduced in this post have a loss combining two parts: $\\mathcal{L} = \\mathcal{L}_s + \\mu(t) \\mathcal{L}_u$. The supervised loss $\\mathcal{L}_s$ is easy to get given all the labeled examples. We will focus on how the unsupervised loss $\\mathcal{L}_u$ is designed. A common choice of the weighting term $\\mu(t)$ is a ramp function increasing the importance of $\\mathcal{L}_u$ in time, where $t$ is the training step.\n Disclaimer: The post is not gonna cover semi-supervised methods with focus on model architecture modification. Check this survey for how to use generative models and graph-based methods in semi-supervised learning.\n Notations Symbol Meaning $L$ Number of unique labels. $(\\mathbf{x}^l, y) \\sim \\mathcal{X}, y \\in \\{0, 1\\}^L$ Labeled dataset. $y$ is a one-hot representation of the true label. $\\mathbf{u} \\sim \\mathcal{U}$ Unlabeled dataset. $\\mathcal{D} = \\mathcal{X} \\cup \\mathcal{U}$ The entire dataset, including both labeled and unlabeled examples. $\\mathbf{x}$ Any sample which can be either labeled or unlabeled. $\\bar{\\mathbf{x}}$ $\\mathbf{x}$ with augmentation applied. $\\mathbf{x}_i$ The $i$-th sample. $\\mathcal{L}$, $\\mathcal{L}_s$, $\\mathcal{L}_u$ Loss, supervised loss, and unsupervised loss. $\\mu(t)$ The unsupervised loss weight, increasing in time. $p(y \\vert \\mathbf{x}), p_\\theta(y \\vert \\mathbf{x})$ The conditional probability over the label set given the input. $f_\\theta(.)$ The implemented neural network with weights $\\theta$, the model that we want to train. $\\mathbf{z} = f_\\theta(\\mathbf{x})$ A vector of logits output by $f$. $\\hat{y} = \\text{softmax}(\\mathbf{z})$ The predicted label distribution. $D[.,.]$ A distance function between two distributions, such as MSE, cross entropy, KL divergence, etc. $\\beta$ EMA weighting hyperparameter for teacher model weights. $\\alpha, \\lambda$ Parameters for MixUp, $\\lambda \\sim \\text{Beta}(\\alpha, \\alpha)$. $T$ Temperature for sharpening the predicted distribution. $\\tau$ A confidence threshold for selecting the qualified prediction. Hypotheses Several hypotheses have been discussed in literature to support certain design decisions in semi-supervised learning methods.\n H1: Smoothness Assumptions: If two data samples are close in a high-density region of the feature space, their labels should be the same or very similar.\n H2: Cluster Assumptions: The feature space has both dense regions and sparse regions. Densely grouped data points naturally form a cluster. Samples in the same cluster are expected to have the same label. This is a small extension of H1.\n H3: Low-density Separation Assumptions: The decision boundary between classes tends to be located in the sparse, low density regions, because otherwise the decision boundary would cut a high-density cluster into two classes, corresponding to two clusters, which invalidates H1 and H2.\n H4: Manifold Assumptions: The high-dimensional data tends to locate on a low-dimensional manifold. Even though real-world data might be observed in very high dimensions (e.g. such as images of real-world objects/scenes), they actually can be captured by a lower dimensional manifold where certain attributes are captured and similar points are grouped closely (e.g. images of real-world objects/scenes are not drawn from a uniform distribution over all pixel combinations). This enables us to learn a more efficient representation for us to discover and measure similarity between unlabeled data points. This is also the foundation for representation learning. [see a helpful link].\n Consistency Regularization Consistency Regularization, also known as Consistency Training, assumes that randomness within the neural network (e.g. with Dropout) or data augmentation transformations should not modify model predictions given the same input. Every method in this section has a consistency regularization loss as $\\mathcal{L}_u$.\nThis idea has been adopted in several self-supervised learning methods, such as SimCLR, BYOL, SimCSE, etc. Different augmented versions of the same sample should result in the same representation. Cross-view training in language modeling and multi-view learning in self-supervised learning all share the same motivation.\nΠ-model Fig. 1. Overview of the Π-model. Two versions of the same input with different stochastic augmentation and dropout masks pass through the network and the outputs are expected to be consistent. (Image source: Laine \u0026 Aila (2017)) Sajjadi et al. (2016) proposed an unsupervised learning loss to minimize the difference between two passes through the network with stochastic transformations (e.g. dropout, random max-pooling) for the same data point. The label is not explicitly used, so the loss can be applied to unlabeled dataset. Laine \u0026amp; Aila (2017) later coined the name, Π-Model, for such a setup.\n $$ \\mathcal{L}_u^\\Pi = \\sum_{\\mathbf{x} \\in \\mathcal{D}} \\text{MSE}(f_\\theta(\\mathbf{x}), f'_\\theta(\\mathbf{x})) $$ where $f'$ is the same neural network with different stochastic augmentation or dropout masks applied. This loss utilizes the entire dataset.\nTemporal ensembling Fig. 2. Overview of Temporal Ensembling. The per-sample EMA label prediction is the learning target. (Image source: Laine \u0026 Aila (2017)) Π-model requests the network to run two passes per sample, doubling the computation cost. To reduce the cost, Temporal Ensembling (Laine \u0026amp; Aila 2017) maintains an exponential moving average (EMA) of the model prediction in time per training sample $\\tilde{\\mathbf{z}}_i$ as the learning target, which is only evaluated and updated once per epoch. Because the ensemble output $\\tilde{\\mathbf{z}}_i$ is initialized to $\\mathbf{0}$, it is normalized by $(1-\\alpha^t)$ to correct this startup bias. Adam optimizer has such bias correction terms for the same reason.\n $$ \\tilde{\\mathbf{z}}^{(t)}_i = \\frac{\\alpha \\tilde{\\mathbf{z}}^{(t-1)}_i + (1-\\alpha) \\mathbf{z}_i}{1-\\alpha^t} $$ where $\\tilde{\\mathbf{z}}^{(t)}$ is the ensemble prediction at epoch $t$ and $\\mathbf{z}_i$ is the model prediction in the current round. Note that since $\\tilde{\\mathbf{z}}^{(0)} = \\mathbf{0}$, with correction, $\\tilde{\\mathbf{z}}^{(1)}$ is simply equivalent to $\\mathbf{z}_i$ at epoch 1.\nMean teachers Fig. 3. Overview of the Mean Teacher framework. (Image source: Tarvaninen \u0026 Valpola, 2017) Temporal Ensembling keeps track of an EMA of label predictions for each training sample as a learning target. However, this label prediction only changes every epoch, making the approach clumsy when the training dataset is large. Mean Teacher (Tarvaninen \u0026amp; Valpola, 2017) is proposed to overcome the slowness of target update by tracking the moving average of model weights instead of model outputs. Let’s call the original model with weights $\\theta$ as the student model and the model with moving averaged weights $\\theta’$ across consecutive student models as the mean teacher: $\\theta’ \\gets \\beta \\theta’ + (1-\\beta)\\theta$\nThe consistency regularization loss is the distance between predictions by the student and teacher and the student-teacher gap should be minimized. The mean teacher is expected to provide more accurate predictions than the student. It got confirmed in the empirical experiments, as shown in Fig. 4.\nFig. 4. Classification error on SVHN of Mean Teacher and the Π Model. The mean teacher (in orange) has better performance than the student model (in blue). (Image source: Tarvaninen \u0026 Valpola, 2017) According to their ablation studies,\n Input augmentation (e.g. random flips of input images, Gaussian noise) or student model dropout is necessary for good performance. Dropout is not needed on the teacher model. The performance is sensitive to the EMA decay hyperparameter $\\beta$. A good strategy is to use a small $\\beta=0.99$ during the ramp up stage and a larger $\\beta=0.999$ in the later stage when the student model improvement slows down. They found that MSE as the consistency cost function performs better than other cost functions like KL divergence. Noisy samples as learning targets Several recent consistency training methods learn to minimize prediction difference between the original unlabeled sample and its corresponding augmented version. It is quite similar to the Π-model but the consistency regularization loss is only applied to the unlabeled data.\nFig. 5. Consistency training with noisy samples. Adversarial Training (Goodfellow et al. 2014) applies adversarial noise onto the input and trains the model to be robust to such adversarial attack. The setup works in supervised learning,\n $$ \\begin{aligned} \\mathcal{L}_\\text{adv}(\\mathbf{x}^l, \\theta) \u0026= D[q(y\\mid \\mathbf{x}^l), p_\\theta(y\\mid \\mathbf{x}^l + r_\\text{adv})] \\\\ r_\\text{adv} \u0026= {\\arg\\max}_{r; \\|r\\| \\leq \\epsilon} D[q(y\\mid \\mathbf{x}^l), p_\\theta(y\\mid \\mathbf{x}^l + r_\\text{adv})] \\\\ r_\\text{adv} \u0026\\approx \\epsilon \\frac{g}{\\|g\\|_2} \\approx \\epsilon\\text{sign}(g)\\quad\\text{where }g = \\nabla_{r} D[y, p_\\theta(y\\mid \\mathbf{x}^l + r)] \\end{aligned} $$ where $q(y \\mid \\mathbf{x}^l)$ is the true distribution, approximated by one-hot encoding of the ground truth label, $y$. $p_\\theta(y \\mid \\mathbf{x}^l)$ is the model prediction. $D[.,.]$ is a distance function measuring the divergence between two distributions.\nVirtual Adversarial Training (VAT; Miyato et al. 2018) extends the idea to work in semi-supervised learning. Because $q(y \\mid \\mathbf{x}^l)$ is unknown, VAT replaces it with the current model prediction for the original input with the current weights $\\hat{\\theta}$. Note that $\\hat{\\theta}$ is a fixed copy of model weights, so there is no gradient update on $\\hat{\\theta}$.\n $$ \\begin{aligned} \\mathcal{L}_u^\\text{VAT}(\\mathbf{x}, \\theta) \u0026= D[p_{\\hat{\\theta}}(y\\mid \\mathbf{x}), p_\\theta(y\\mid \\mathbf{x} + r_\\text{vadv})] \\\\ r_\\text{vadv} \u0026= {\\arg\\max}_{r; \\|r\\| \\leq \\epsilon} D[p_{\\hat{\\theta}}(y\\mid \\mathbf{x}), p_\\theta(y\\mid \\mathbf{x} + r)] \\end{aligned} $$ The VAT loss applies to both labeled and unlabeled samples. It is a negative smoothness measure of the current model\u0026rsquo;s prediction manifold at each data point. The optimization of such loss motivates the manifold to be smoother.\nInterpolation Consistency Training (ICT; Verma et al. 2019) enhances the dataset by adding more interpolations of data points and expects the model prediction to be consistent with interpolations of the corresponding labels. MixUp (Zheng et al. 2018) operation mixes two images via a simple weighted sum and combines it with label smoothing. Following the idea of MixUp, ICT expects the prediction model to produce a label on a mixup sample to match the interpolation of predictions of corresponding inputs:\n $$ \\begin{aligned} \\text{mixup}_\\lambda (\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\lambda \\mathbf{x}_i + (1-\\lambda)\\mathbf{x}_j \\\\ p(\\text{mixup}_\\lambda (y \\mid \\mathbf{x}_i, \\mathbf{x}_j)) \u0026\\approx \\lambda p(y \\mid \\mathbf{x}_i) + (1-\\lambda) p(y \\mid \\mathbf{x}_j) \\end{aligned} $$ where $\\theta'$ is a moving average of $\\theta$, which is a mean teacher.\nFig. 6. Overview of Interpolation Consistency Training. MixUp is applied to produce more interpolated samples with interpolated labels as learning targets. (Image source: Verma et al. 2019) Because the probability of two randomly selected unlabeled samples belonging to different classes is high (e.g. There are 1000 object classes in ImageNet), the interpolation by applying a mixup between two random unlabeled samples is likely to happen around the decision boundary. According to the low-density separation assumptions, the decision boundary tends to locate in the low density regions.\n $$ \\mathcal{L}^\\text{ICT}_{u} = \\mathbb{E}_{\\mathbf{u}_i, \\mathbf{u}_j \\sim \\mathcal{U}} \\mathbb{E}_{\\lambda \\sim \\text{Beta}(\\alpha, \\alpha)} D[p_\\theta(y \\mid \\text{mixup}_\\lambda (\\mathbf{u}_i, \\mathbf{u}_j)), \\text{mixup}_\\lambda(p_{\\theta’}(y \\mid \\mathbf{u}_i), p_{\\theta'}(y \\mid \\mathbf{u}_j)] $$ where $\\theta'$ is a moving average of $\\theta$.\nSimilar to VAT, Unsupervised Data Augmentation (UDA; Xie et al. 2020) learns to predict the same output for an unlabeled example and the augmented one. UDA especially focuses on studying how the \u0026ldquo;quality\u0026rdquo; of noise can impact the semi-supervised learning performance with consistency training. It is crucial to use advanced data augmentation methods for producing meaningful and effective noisy samples. Good data augmentation should produce valid (i.e. does not change the label) and diverse noise, and carry targeted inductive biases.\nFor images, UDA adopts RandAugment (Cubuk et al. 2019) which uniformly samples augmentation operations available in PIL, no learning or optimization, so it is much cheaper than AutoAugment.\nFig. 7. Comparison of various semi-supervised learning methods on CIFAR-10 classification. Fully supervised Wide-ResNet-28-2 and PyramidNet+ShakeDrop have an error rate of **5.4** and **2.7** respectively when trained on 50,000 examples without RandAugment. (Image source: Xie et al. 2020) For language, UDA combines back-translation and TF-IDF based word replacement. Back-translation preserves the high-level meaning but may not retain certain words, while TF-IDF based word replacement drops uninformative words with low TF-IDF scores. In the experiments on language tasks, they found UDA to be complementary to transfer learning and representation learning; For example, BERT fine-tuned (i.e. $\\text{BERT}_\\text{FINETUNE}$ in Fig. 8.) on in-domain unlabeled data can further improve the performance.\nFig. 8. Comparison of UDA with different initialization configurations on various text classification tasks. (Image source: Xie et al. 2020) When calculating $\\mathcal{L}_u$, UDA found two training techniques to help improve the results.\n Low confidence masking: Mask out examples with low prediction confidence if lower than a threshold $\\tau$. Sharpening prediction distribution: Use a low temperature $T$ in softmax to sharpen the predicted probability distribution. In-domain data filtration: In order to extract more in-domain data from a large out-of-domain dataset, they trained a classifier to predict in-domain labels and then retain samples with high confidence predictions as in-domain candidates. $$ \\begin{aligned} \u0026\\mathcal{L}_u^\\text{UDA} = \\mathbb{1}[\\max_{y'} p_{\\hat{\\theta}}(y'\\mid \\mathbf{x}) \\tau ] \\cdot D[p^\\text{(sharp)}_{\\hat{\\theta}}(y \\mid \\mathbf{x}; T), p_\\theta(y \\mid \\bar{\\mathbf{x}})] \\\\ \u0026\\text{where } p_{\\hat{\\theta}}^\\text{(sharp)}(y \\mid \\mathbf{x}; T) = \\frac{\\exp(z^{(y)} / T)}{ \\sum_{y'} \\exp(z^{(y')} / T) } \\end{aligned} $$ where $\\hat{\\theta}$ is a fixed copy of model weights, same as in VAT, so no gradient update, and $\\bar{\\mathbf{x}}$ is the augmented data point. $\\tau$ is the prediction confidence threshold and $T$ is the distribution sharpening temperature.\nPseudo Labeling Pseudo Labeling (Lee 2013) assigns fake labels to unlabeled samples based on the maximum softmax probabilities predicted by the current model and then trains the model on both labeled and unlabeled samples simultaneously in a pure supervised setup.\nWhy could pseudo labels work? Pseudo label is in effect equivalent to Entropy Regularization (Grandvalet \u0026amp; Bengio 2004), which minimizes the conditional entropy of class probabilities for unlabeled data to favor low density separation between classes. In other words, the predicted class probabilities is in fact a measure of class overlap, minimizing the entropy is equivalent to reduced class overlap and thus low density separation.\nFig. 9. t-SNE visualization of outputs on MNIST test set by models training (a) without and (b) with pseudo labeling on 60000 unlabeled samples, in addition to 600 labeled data. Pseudo labeling leads to better segregation in the learned embedding space. (Image source: Lee 2013) Training with pseudo labeling naturally comes as an iterative process. We refer to the model that produces pseudo labels as teacher and the model that learns with pseudo labels as student.\nLabel propagation Label Propagation (Iscen et al. 2019) is an idea to construct a similarity graph among samples based on feature embedding. Then the pseudo labels are \u0026ldquo;diffused\u0026rdquo; from known samples to unlabeled ones where the propagation weights are proportional to pairwise similarity scores in the graph. Conceptually it is similar to a k-NN classifier and both suffer from the problem of not scaling up well with a large dataset.\nFig. 10. Illustration of how Label Propagation works. (Image source: Iscen et al. 2019) Self-Training Self-Training is not a new concept (Scudder 1965, Nigram \u0026amp; Ghani CIKM 2000). It is an iterative algorithm, alternating between the following two steps until every unlabeled sample has a label assigned:\n Initially it builds a classifier on labeled data. Then it uses this classifier to predict labels for the unlabeled data and converts the most confident ones into labeled samples. Xie et al. (2020) applied self-training in deep learning and achieved great results. On the ImageNet classification task, they first trained an EfficientNet (Tan \u0026amp; Le 2019) model as teacher to generate pseudo labels for 300M unlabeled images and then trained a larger EfficientNet as student to learn with both true labeled and pseudo labeled images. One critical element in their setup is to have noise during student model training but have no noise for the teacher to produce pseudo labels. Thus their method is called Noisy Student. They applied stochastic depth (Huang et al. 2016), dropout and RandAugment to noise the student. Noise is important for the student to perform better than the teacher. The added noise has a compound effect to encourage the model\u0026rsquo;s decision making frontier to be smooth, on both labeled and unlabeled data.\nA few other important technical configs in noisy student self-training are:\n The student model should be sufficiently large (i.e. larger than the teacher) to fit more data. Noisy student should be paired with data balancing, especially important to balance the number of pseudo labeled images in each class. Soft pseudo labels work better than hard ones. Noisy student also improves adversarial robustness against an FGSM (Fast Gradient Sign Attack = The attack uses the gradient of the loss w.r.t the input data and adjusts the input data to maximize the loss) attack though the model is not optimized for adversarial robustness.\nSentAugment, proposed by Du et al. (2020), aims to solve the problem when there is not enough in-domain unlabeled data for self-training in the language domain. It relies on sentence embedding to find unlabeled in-domain samples from a large corpus and uses the retrieved sentences for self-training.\nReducing confirmation bias Confirmation bias is a problem with incorrect pseudo labels provided by an imperfect teacher model. Overfitting to wrong labels may not give us a better student model.\nTo reduce confirmation bias, Arazo et al. (2019) proposed two techniques. One is to adopt MixUp with soft labels. Given two samples, $(\\mathbf{x}_i, \\mathbf{x}_j)$ and their corresponding true or pseudo labels $(y_i, y_j)$, the interpolated label equation can be translated to a cross entropy loss with softmax outputs:\n $$ \\begin{aligned} \u0026\\bar{\\mathbf{x}} = \\lambda \\mathbf{x}_i + (1-\\lambda) \\mathbf{x}_j \\\\ \u0026\\bar{y} = \\lambda y_i + (1-\\lambda) y_j \\Leftrightarrow \\mathcal{L} = \\lambda [y_i^\\top \\log f_\\theta(\\bar{\\mathbf{x}})] + (1-\\lambda) [y_j^\\top \\log f_\\theta(\\bar{\\mathbf{x}})] \\end{aligned} $$ Mixup is insufficient if there are too few labeled samples. They further set a minimum number of labeled samples in every mini batch by oversampling the labeled samples. This works better than upweighting labeled samples, because it leads to more frequent updates rather than few updates of larger magnitude which could be less stable. Like consistency regularization, data augmentation and dropout are also important for pseudo labeling to work well.\nMeta Pseudo Labels (Pham et al. 2021) adapts the teacher model constantly with the feedback of how well the student performs on the labeled dataset. The teacher and the student are trained in parallel, where the teacher learns to generate better pseudo labels and the student learns from the pseudo labels.\nLet the teacher and student model weights be $\\theta_T$ and $\\theta_S$, respectively. The student model\u0026rsquo;s loss on the labeled samples is defined as a function $\\theta^\\text{PL}_S(.)$ of $\\theta_T$ and we would like to minimize this loss by optimizing the teacher model accordingly.\n $$ \\begin{aligned} \\min_{\\theta_T} \u0026\\mathcal{L}_s(\\theta^\\text{PL}_S(\\theta_T)) = \\min_{\\theta_T} \\mathbb{E}_{(\\mathbf{x}^l, y) \\in \\mathcal{X}} \\text{CE}[y, f_{\\theta_S}(\\mathbf{x}^l)] \\\\ \\text{where } \u0026\\theta^\\text{PL}_S(\\theta_T) = \\arg\\min_{\\theta_S} \\mathcal{L}_u (\\theta_T, \\theta_S) = \\arg\\min_{\\theta_S} \\mathbb{E}_{\\mathbf{u} \\sim \\mathcal{U}} \\text{CE}[(f_{\\theta_T}(\\mathbf{u}), f_{\\theta_S}(\\mathbf{u}))] \\end{aligned} $$ However, it is not trivial to optimize the above equation. Borrowing the idea of MAML, it approximates the multi-step $\\arg\\min_{\\theta_S}$ with the one-step gradient update of $\\theta_S$,\n $$ \\begin{aligned} \\theta^\\text{PL}_S(\\theta_T) \u0026\\approx \\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S) \\\\ \\min_{\\theta_T} \\mathcal{L}_s (\\theta^\\text{PL}_S(\\theta_T)) \u0026\\approx \\min_{\\theta_T} \\mathcal{L}_s \\big( \\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S) \\big) \\end{aligned} $$ With soft pseudo labels, the above objective is differentiable. But if using hard pseudo labels, it is not differentiable and thus we need to use RL, e.g. REINFORCE.\nThe optimization procedure is alternative between training two models:\n Student model update: Given a batch of unlabeled samples $\\{ \\mathbf{u} \\}$, we generate pseudo labels by $f_{\\theta_T}(\\mathbf{u})$ and optimize $\\theta_S$ with one step SGD: $\\theta’_S = \\color{green}{\\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S)}$. Teacher model update: Given a batch of labeled samples $\\{(\\mathbf{x}^l, y)\\}$, we reuse the student’s update to optimize $\\theta_T$: $\\theta’_T = \\theta_T - \\eta_T \\cdot \\nabla_{\\theta_T} \\mathcal{L}_s ( \\color{green}{\\theta_S - \\eta_S \\cdot \\nabla_{\\theta_S} \\mathcal{L}_u(\\theta_T, \\theta_S)} )$. In addition, the UDA objective is applied to the teacher model to incorporate consistency regularization. Fig. 11. Comparison of Meta Pseudo Labels with other semi- or self-supervised learning methods on image classification tasks. (Image source: Pham et al. 2021) Pseudo Labeling with Consistency Regularization It is possible to combine the above two approaches together, running semi-supervised learning with both pseudo labeling and consistency training.\nMixMatch MixMatch (Berthelot et al. 2019), as a holistic approach to semi-supervised learning, utilizes unlabeled data by merging the following techniques:\n Consistency regularization: Encourage the model to output the same predictions on perturbed unlabeled samples. Entropy minimization: Encourage the model to output confident predictions on unlabeled data. MixUp augmentation: Encourage the model to have linear behaviour between samples. Given a batch of labeled data $\\mathcal{X}$ and unlabeled data $\\mathcal{U}$, we create augmented versions of them via $\\text{MixMatch}(.)$, $\\bar{\\mathcal{X}}$ and $\\bar{\\mathcal{U}}$, containing augmented samples and guessed labels for unlabeled examples.\n $$ \\begin{aligned} \\bar{\\mathcal{X}}, \\bar{\\mathcal{U}} \u0026= \\text{MixMatch}(\\mathcal{X}, \\mathcal{U}, T, K, \\alpha) \\\\ \\mathcal{L}^\\text{MM}_s \u0026= \\frac{1}{\\vert \\bar{\\mathcal{X}} \\vert} \\sum_{(\\bar{\\mathbf{x}}^l, y)\\in \\bar{\\mathcal{X}}} D[y, p_\\theta(y \\mid \\bar{\\mathbf{x}}^l)] \\\\ \\mathcal{L}^\\text{MM}_u \u0026= \\frac{1}{L\\vert \\bar{\\mathcal{U}} \\vert} \\sum_{(\\bar{\\mathbf{u}}, \\hat{y})\\in \\bar{\\mathcal{U}}} \\| \\hat{y} - p_\\theta(y \\mid \\bar{\\mathbf{u}}) \\|^2_2 \\\\ \\end{aligned} $$ where $T$ is the sharpening temperature to reduce the guessed label overlap; $K$ is the number of augmentations generated per unlabeled example; $\\alpha$ is the parameter in MixUp.\nFor each $\\mathbf{u}$, MixMatch generates $K$ augmentations, $\\bar{\\mathbf{u}}^{(k)} = \\text{Augment}(\\mathbf{u})$ for $k=1, \\dots, K$ and the pseudo label is guessed based on the average: $\\hat{y} = \\frac{1}{K} \\sum_{k=1}^K p_\\theta(y \\mid \\bar{\\mathbf{u}}^{(k)})$.\nFig. 12. The process of \"label guessing\" in MixMatch: averaging $K$ augmentations, correcting the predicted marginal distribution and finally sharpening the distribution. (Image source: Berthelot et al. 2019) According to their ablation studies, it is critical to have MixUp especially on the unlabeled data. Removing temperature sharpening on the pseudo label distribution hurts the performance quite a lot. Average over multiple augmentations for label guessing is also necessary.\nReMixMatch (Berthelot et al. 2020) improves MixMatch by introducing two new mechanisms:\nFig. 13. Illustration of two improvements introduced in ReMixMatch over MixMatch. (Image source: Berthelot et al. 2020) Distribution alignment. It encourages the marginal distribution $p(y)$ to be close to the marginal distribution of the ground truth labels. Let $p(y)$ be the class distribution in the true labels and $\\tilde{p}(\\hat{y})$ be a running average of the predicted class distribution among the unlabeled data. The model prediction on an unlabeled sample $p_\\theta(y \\vert \\mathbf{u})$ is normalized to be $\\text{Normalize}\\big( \\frac{p_\\theta(y \\vert \\mathbf{u}) p(y)}{\\tilde{p}(\\hat{y})} \\big)$ to match the true marginal distribution. Note that entropy minimization is not a useful objective if the marginal distribution is not uniform. I do feel the assumption that the class distributions on the labeled and unlabeled data should match is too strong and not necessarily to be true in the real-world setting. Augmentation anchoring. Given an unlabeled sample, it first generates an \u0026ldquo;anchor\u0026rdquo; version with weak augmentation and then averages $K$ strongly augmented versions using CTAugment (Control Theory Augment). CTAugment only samples augmentations that keep the model predictions within the network tolerance. The ReMixMatch loss is a combination of several terms,\n a supervised loss with data augmentation and MixUp applied; an unsupervised loss with data augmentation and MixUp applied, using pseudo labels as targets; a CE loss on a single heavily-augmented unlabeled image without MixUp; a rotation loss as in self-supervised learning. DivideMix DivideMix (Junnan Li et al. 2020) combines semi-supervised learning with Learning with noisy labels (LNL). It models the per-sample loss distribution via a GMM to dynamically divide the training data into a labeled set with clean examples and an unlabeled set with noisy ones. Following the idea in Arazo et al. 2019, they fit a two-component GMM on the per-sample cross entropy loss $\\ell_i = y_i^\\top \\log f_\\theta(\\mathbf{x}_i)$. Clean samples are expected to get lower loss faster than noisy samples. The component with smaller mean is the cluster corresponding to clean labels and let’s denote it as $c$. If the GMM posterior probability $w_i = p_\\text{GMM}(c \\mid \\ell_i)$ (i.e. the probability of the sampling belonging to the clean sample set) is larger than the threshold $\\tau$, this sample is considered as a clean sample and otherwise a noisy one.\nThe data clustering step is named co-divide. To avoid confirmation bias, DivideMix simultaneously trains two diverged networks where each network uses the dataset division from the other network; e.g. thinking about how Double Q Learning works.\nFig. 14. DivideMix trains two networks independently to reduce confirmation bias. They run co-divide, co-refinement, and co-guessing together. (Image source: Junnan Li et al. 2020) Compared to MixMatch, DivideMix has an additional co-divide stage for handling noisy samples, as well as the following improvements during training:\n Label co-refinement: It linearly combines the ground-truth label $y_i$ with the network’s prediction $\\hat{y}_i$, which is averaged across multiple augmentations of $\\mathbf{x}_i$, guided by the clean set probability $w_i$ produced by the other network. Label co-guessing: It averages the predictions from two models for unlabelled data samples. Fig. 15. The algorithm of DivideMix. (Image source: Junnan Li et al. 2020) FixMatch FixMatch (Sohn et al. 2020) generates pseudo labels on unlabeled samples with weak augmentation and only keeps predictions with high confidence. Here both weak augmentation and high confidence filtering help produce high-quality trustworthy pseudo label targets. Then FixMatch learns to predict these pseudo labels given a heavily-augmented sample.\nFig. 16. Illustration of how FixMatch works. (Image source: Sohn et al. 2020) $$ \\begin{aligned} \\mathcal{L}_s \u0026= \\frac{1}{B} \\sum^B_{b=1} \\text{CE}[y_b, p_\\theta(y \\mid \\mathcal{A}_\\text{weak}(\\mathbf{x}_b))] \\\\ \\mathcal{L}_u \u0026= \\frac{1}{\\mu B} \\sum_{b=1}^{\\mu B} \\mathbb{1}[\\max(\\hat{y}_b) \\geq \\tau]\\;\\text{CE}(\\hat{y}_b, p_\\theta(y \\mid \\mathcal{A}_\\text{strong}(\\mathbf{u}_b))) \\end{aligned} $$ where $\\hat{y}_b$ is the pseudo label for an unlabeled example; $\\mu$ is a hyperparameter that determines the relative sizes of $\\mathcal{X}$ and $\\mathcal{U}$.\n Weak augmentation $\\mathcal{A}_\\text{weak}(.)$: A standard flip-and-shift augmentation Strong augmentation $\\mathcal{A}_\\text{strong}(.)$ : AutoAugment, Cutout, RandAugment, CTAugment Fig. 17. Performance of FixMatch and several other semi-supervised learning methods on image classification tasks. (Image source: Sohn et al. 2020) According to the ablation studies of FixMatch,\n Sharpening the predicted distribution with a temperature parameter $T$ does not have a significant impact when the threshold $\\tau$ is used. Cutout and CTAugment as part of strong augmentations are necessary for good performance. When the weak augmentation for label guessing is replaced with strong augmentation, the model diverges early in training. If discarding weak augmentation completely, the model overfit the guessed labels. Using weak instead of strong augmentation for pseudo label prediction leads to unstable performance. Strong data augmentation is critical. Combined with Powerful Pre-Training It is a common paradigm, especially in language tasks, to first pre-train a task-agnostic model on a large unsupervised data corpus via self-supervised learning and then fine-tune it on the downstream task with a small labeled dataset. Research has shown that we can obtain extra gain if combining semi-supervised learning with pretraining.\nZoph et al. (2020) studied to what degree self-training can work better than pre-training. Their experiment setup was to use ImageNet for pre-training or self-training to improve COCO. Note that when using ImageNet for self-training, it discards labels and only uses ImageNet samples as unlabeled data points. He et al. (2018) has demonstrated that ImageNet classification pre-training does not work well if the downstream task is very different, such as object detection.\nFig. 18. The effect of (a) data augment (from weak to strong) and (b) the labeled dataset size on the object detection performance. In the legend: `Rand Init` refers to a model initialized w/ random weights; `ImageNet` is initialized with a pre-trained checkpoint at 84.5% top-1 ImageNet accuracy; `ImageNet++` is initialized with a checkpoint with a higher accuracy 86.9%. (Image source: Zoph et al. 2020) Their experiments demonstrated a series of interesting findings:\n The effectiveness of pre-training diminishes with more labeled samples available for the downstream task. Pre-training is helpful in the low-data regimes (20%) but neutral or harmful in the high-data regime. Self-training helps in high data/strong augmentation regimes, even when pre-training hurts. Self-training can bring in additive improvement on top of pre-training, even using the same data source. Self-supervised pre-training (e.g. via SimCLR) hurts the performance in a high data regime, similar to how supervised pre-training does. Joint-training supervised and self-supervised objectives help resolve the mismatch between the pre-training and downstream tasks. Pre-training, joint-training and self-training are all additive. Noisy labels or un-targeted labeling (i.e. pre-training labels are not aligned with downstream task labels) is worse than targeted pseudo labeling. Self-training is computationally more expensive than fine-tuning on a pre-trained model. Chen et al. (2020) proposed a three-step procedure to merge the benefits of self-supervised pretraining, supervised fine-tuning and self-training together:\n Unsupervised or self-supervised pretrain a big model. Supervised fine-tune it on a few labeled examples. It is important to use a big (deep and wide) neural network. Bigger models yield better performance with fewer labeled samples. Distillation with unlabeled examples by adopting pseudo labels in self-training. It is possible to distill the knowledge from a large model into a small one because the task-specific use does not require extra capacity of the learned representation. The distillation loss is formatted as the following, where the teacher network is fixed with weights $\\hat{\\theta}_T$. $$ \\mathcal{L}_\\text{distill} = - (1-\\alpha) \\underbrace{\\sum_{(\\mathbf{x}^l_i, y_i) \\in \\mathcal{X}} \\big[ \\log p_{\\theta_S}(y_i \\mid \\mathbf{x}^l_i) \\big]}_\\text{Supervised loss} - \\alpha \\underbrace{\\sum_{\\mathbf{u}_i \\in \\mathcal{U}} \\Big[ \\sum_{i=1}^L p_{\\hat{\\theta}_T}(y^{(i)} \\mid \\mathbf{u}_i; T) \\log p_{\\theta_S}(y^{(i)} \\mid \\mathbf{u}_i; T) \\Big]}_\\text{Distillation loss using unlabeled data} $$ Fig. 19. A semi-supervised learning framework leverages unlabeled data corpus by (Left) task-agnostic unsupervised pretraining and (Right) task-specific self-training and distillation. (Image source: Chen et al. 2020) They experimented on the ImageNet classification task. The self-supervised pre-training uses SimCLRv2, a directly improved version of SimCLR. Observations in their empirical studies confirmed several learnings, aligned with Zoph et al. 2020:\n Bigger models are more label-efficient; Bigger/deeper project heads in SimCLR improve representation learning; Distillation using unlabeled data improves semi-supervised learning. Fig. 20. Comparison of performance by SimCLRv2 + semi-supervised distillation on ImageNet classification. (Image source: Chen et al. 2020) 💡 Quick summary of common themes among recent semi-supervised learning methods, many aiming to reduce confirmation bias:\n Apply valid and diverse noise to samples by advanced data augmentation methods. When dealing with images, MixUp is an effective augmentation. Mixup could work on language too, resulting in a small incremental improvement (Guo et al. 2019). Set a threshold and discard pseudo labels with low confidence. Set a minimum number of labeled samples per mini-batch. Sharpen the pseudo label distribution to reduce the class overlap. Citation Cited as:\n Weng, Lilian. (Dec 2021). Learning with not enough data part 1: semi-supervised learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-12-05-semi-supervised/.\n Or\n@article{weng2021semi, title = \u0026quot;Learning with not Enough Data Part 1: Semi-Supervised Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Dec\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-12-05-semi-supervised/\u0026quot; } References [1] Ouali, Hudelot \u0026amp; Tami. “An Overview of Deep Semi-Supervised Learning” arXiv preprint arXiv:2006.05278 (2020).\n[2] Sajjadi, Javanmardi \u0026amp; Tasdizen “Regularization With Stochastic Transformations and Perturbations for Deep Semi-Supervised Learning.” arXiv preprint arXiv:1606.04586 (2016).\n[3] Pham et al. “Meta Pseudo Labels.” CVPR 2021.\n[4] Laine \u0026amp; Aila. “Temporal Ensembling for Semi-Supervised Learning” ICLR 2017.\n[5] Tarvaninen \u0026amp; Valpola. “Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.” NeuriPS 2017\n[6] Xie et al. “Unsupervised Data Augmentation for Consistency Training.” NeuriPS 2020.\n[7] Miyato et al. “Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning.” IEEE transactions on pattern analysis and machine intelligence 41.8 (2018).\n[8] Verma et al. “Interpolation consistency training for semi-supervised learning.” IJCAI 2019\n[9] Lee. “Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks.” ICML 2013 Workshop: Challenges in Representation Learning.\n[10] Iscen et al. “Label propagation for deep semi-supervised learning.” CVPR 2019.\n[11] Xie et al. “Self-training with Noisy Student improves ImageNet classification” CVPR 2020.\n[12] Jingfei Du et al. “Self-training Improves Pre-training for Natural Language Understanding.” 2020\n[13] Iscen et al. “Label propagation for deep semi-supervised learning.” CVPR 2019\n[14] Arazo et al. “Pseudo-labeling and confirmation bias in deep semi-supervised learning.” IJCNN 2020.\n[15] Berthelot et al. “MixMatch: A holistic approach to semi-supervised learning.” NeuriPS 2019\n[16] Berthelot et al. “ReMixMatch: Semi-supervised learning with distribution alignment and augmentation anchoring.” ICLR 2020\n[17] Sohn et al. “FixMatch: Simplifying semi-supervised learning with consistency and confidence.” CVPR 2020\n[18] Junnan Li et al. “DivideMix: Learning with Noisy Labels as Semi-supervised Learning.” 2020 [code]\n[19] Zoph et al. “Rethinking pre-training and self-training.” 2020.\n[20] Chen et al. “Big Self-Supervised Models are Strong Semi-Supervised Learners” 2020\n","permalink":"https://lilianweng.github.io/posts/2021-12-05-semi-supervised/","summary":"When facing a limited amount of labeled data for supervised learning tasks, four approaches are commonly discussed.\n Pre-training + fine-tuning: Pre-train a powerful task-agnostic model on a large unsupervised data corpus, e.g. pre-training LMs on free text, or pre-training vision models on unlabelled images via self-supervised learning, and then fine-tune it on the downstream task with a small set of labeled samples. Semi-supervised learning: Learn from the labelled and unlabeled samples together.","title":"Learning with not Enough Data Part 1: Semi-Supervised Learning"},{"content":"[Updated on 2022-03-13: add expert choice routing.] [Updated on 2022-06-10]: Greg and I wrote a shorted and upgraded version of this post, published on OpenAI Blog: \u0026ldquo;Techniques for Training Large Neural Networks\u0026rdquo;\nIn recent years, we are seeing better results on many NLP benchmark tasks with larger pre-trained language models. How to train large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time.\nHowever an individual GPU worker has limited memory and the sizes of many large models have grown beyond a single GPU. There are several parallelism paradigms to enable model training across multiple GPUs, as well as a variety of model architecture and memory saving designs to help make it possible to train very large neural networks.\nTraining Parallelism The main bottleneck for training very large neural network models is the intense demand for a large amount of GPU memory, way above what can be hosted on an individual GPU machine. Besides the model weights (e.g. tens of billions of floating point numbers), it is usually even more expensive to store intermediate computation outputs such as gradients and optimizer states (e.g. momentums \u0026amp; variations in Adam). Additionally training a large model often pairs with a large training corpus and thus a single process may just take forever.\nAs a result, parallelism is necessary. Parallelism can happen at different dimensions, including data, model architecture, and tensor operation.\nData Parallelism The most naive way for Data parallelism (DP) is to copy the same model weights into multiple workers and assign a fraction of data to each worker to be processed at the same time.\nNaive DP cannot work well if the model size is larger than a single GPU node’s memory. Methods like GeePS (Cui et al. 2016) offload temporarily unused parameters back to CPU to work with limited GPU memory when the model is too big to fit into one machine. The data swapping transfer should happen at the backend and not interfere with training computation.\nAt the end of each minibatch, workers need to synchronize gradients or weights to avoid staleness. There are two main synchronization approaches and both have clear pros \u0026amp; cons.\n Bulk synchronous parallels (BSP): Workers sync data at the end of every minibatch. It prevents model weights staleness and good learning efficiency but each machine has to halt and wait for others to send gradients. Asynchronous parallel (ASP): Every GPU worker processes the data asynchronously, no waiting or stalling. However, it can easily lead to stale weights being used and thus lower the statistical learning efficiency. Even though it increases the computation time, it may not speed up training time to convergence. Somewhere in the middle is to synchronize gradients globally once every $x$ iterations ($x \u0026gt; 1$). This feature is called “gradient accumulation” in Distribution Data Parallel (DDP) since Pytorch v1.5 (Li et al. 2021). Bucketing gradients avoid immediate AllReduce operations but instead buckets multiple gradients into one AllReduce to improve throughput. Computation and communication scheduling optimization can be made based on the computation graph.\nFig. 1. Pseudo code for Pytorch DDP. (Image source: Li et al. 2021) Model Parallelism Model parallelism (MP) aims to solve the case when the model weights cannot fit into a single node. The computation and model parameters are partitioned across multiple machines. Different from data parallelism where each worker hosts a full copy of the entire model, MP only allocates a fraction of model parameters on one worker and thus both the memory usage and the computation are reduced.\nSince deep neural networks usually contain a stack of vertical layers, it feels straightforward to split a large model by layer, where a small consecutive set of layers are grouped into one partition on one worker. However, a naive implementation for running every data batch through multiple such workers with sequential dependency leads to big bubbles of waiting time and severe under-utilization of computation resources.\nFig. 2. A naive model parallelism setup where the model is vertically split into 4 partitions. Data is processed by one worker at a time due to sequential dependency, leading to large “bubbles” of idle time. (Image source: Huang et al. 2019) Pipeline Parallelism Pipeline parallelism (PP) combines model parallelism with data parallelism to reduce inefficient time “bubbles''. The main idea is to split one minibatch into multiple microbatches and enable each stage worker to process one microbatch simultaneously. Note that every microbatch needs two passes, one forward and one backward. Inter-worker communication only transfers activations (forward) and gradients (backward). How these passes are scheduled and how the gradients are aggregated vary in different approaches. The number of partitions (workers) is also known as pipeline depth.\nIn GPipe (Huang et al. 2019) gradients from multiple microbatches are aggregated and applied synchronously at the end. The synchronous gradient descent guarantees learning consistency and efficiency irrespective of the number of workers. As shown in Fig. 3, bubbles still exist but are much smaller than what’s in Fig. 2. Given $m$ evenly split microbatches and $d$ partitions, assuming both forward and backward per microbatch take one unit of time, the fraction of bubble is:\n $$ 1 - \\frac{2md}{(2m + 2(d-1))d} = \\frac{d-1}{m+d-1} $$ The GPipe paper observed that the bubble overhead is almost negligible if the number of microbatches is more than 4x the number of partitions $m \u0026gt; 4d$ (when activation recomputation is applied).\nFig. 3. Illustration of pipeline parallelism in GPipe with 4 microbatches and 4 partitions. GPipe aggregates and updates gradients across devices synchronously at the end of every batch. (Image source: Huang et al. 2019) GPipe achieves almost linear speedup in throughput with the number of devices, although it is not always guaranteed if the model parameters are not evenly distributed across workers.\nPipeDream (Narayanan et al. 2019) schedules each worker to alternatively process the forward and backward passes (1F1B). PipeDream names each model partition “stage” and each stage worker can have multiple replicas to run data parallelism. In this process, PipeDream uses a deterministic round-robin load balancing strategy to assign work among multiple replicas of stages to ensure that the forward and backward passes for the same minibatch happen on the same replica.\nFig. 4. Illustration of `1F1B` microbatch scheduling in PipeDream. (Image source: Harlap et al. 2018) Since PipeDream does not have an end-of-batch global gradient sync across all the workers, an native implementation of 1F1B can easily lead to the forward and backward passes of one microbatch using different versions of model weights, thus lowering the learning efficiency. PipeDream proposed a few designs to tackle this issue:\n Weight stashing: Each worker keeps track of several model versions and makes sure that the same version of weights are used in the forward and backward passes given one data batch. Vertical sync (Optional): The version of model weights flows between stage workers together with activations and gradients. Then the computation adopts the corresponding stashed version propagated from the previous worker. This process keeps version consistency across workers. Note that it is asynchronous, different from GPipe. At the beginning of a training run, PipeDream first profiles the computation memory cost and time of each layer in the model and then optimizes a solution for partitioning layers into stages, which is a dynamic programming problem.\nFig. 5. Results for VGG16 on ILSVRC12. (Top) Accuracy vs time. The integer marks the number of stage workers. ASP = Asynchronous parallel \u0026 BSP = Bulk synchronous parallels. (Bottom) Training time speedup for different parallelism configurations. Straight pipeline refers to pipeline parallelism without data parallelism. (Image source: Harlap et al. 2018) Two variations of PipeDream were later proposed to reduce the memory footprint by stashed model versions (Narayanan et al. 2021).\nPipeDream-flush adds a globally synchronized pipeline flush periodically, just like GPipe. In this way, it greatly reduces the memory footprint (i.e. only maintain a single version of model weights) by sacrificing a little throughput.\nFig. 6. Illustration of pipeline scheduling in PipeDream-flush. (Image source: (Narayanan et al. 2021) PipeDream-2BW maintains only two versions of model weights, where “2BW” is short for “double-buffered weights”. It generates a new model version every $k$ microbatches and $k$ should be larger than the pipeline depth $d$, $k \u0026gt; d$. A newly updated model version cannot fully replace the old version immediately since some leftover backward passes still depend on the old version. In total only two versions need to be saved so the memory cost is much reduced.\nFig. 7. Illustration of pipeline scheduling in PipeDream-2BW. (Image source: (Narayanan et al. 2021) Tensor Parallelism Both model and pipeline parallelisms split a model vertically. OTOH we can horizontally partition the computation for one tensor operation across multiple devices, named Tensor parallelism (TP).\nLet\u0026rsquo;s take the transformer as an example given its popularity. The transformer model mainly consists of layers of MLP and self-attention blocks. Megatron-LM (Shoeybi et al. 2020) adopts a simple way to parallelize intra-layer computation for MLP and self-attention.\nA MLP layer in a transformer contains a GEMM (General matrix multiply) followed by an non-linear GeLU transfer. Let’s split weight matrix $A$ by column:\n $$ \\begin{aligned} \\text{Split }A \u0026= [A_1, A_2] \\\\ Y \u0026=\\text{GeLU}(XA) \\\\ [Y_1, Y_2] \u0026= [\\text{GeLU}(XA_1), \\text{GeLU}(XA_2)] \\end{aligned} $$ The attention block runs GEMM with query ($Q$), key ($K$), and value weights ($V$) according to the above partitioning in parallel and then combines them with another GEMM to produce the attention head results.\n $$ \\text{Attention}(X, Q, K, V) = \\text{softmax}(\\frac{(XQ) (XK)^\\top}{\\sqrt{d_k}}) XV $$ Fig. 8. Illustration of tensor parallelism for key transformer components proposed in Megatron-LM. (Image source: Shoeybi et al. 2020) Narayanan et al. (2021) combined pipeline, tensor and data parallelism with a new pipeline scheduling strategy and named their approach PTD-P. Instead of only positioning a continuous set of layers (“model chunk”) on a device, each worker can be assigned with multiple chunks of smaller continuous subsets of layers (e.g. device 1 has layers 1, 2, 9, 10; device 2 has layers 3, 4, 11, 12; each has two model chunks). The number of microbatches in one batch should be exactly divided by the number of workers ($m % d = 0$). If there are $v$ model chunks per worker, the pipeline bubble time can be reduced by a multiplier of $v$ compared to a GPipe scheduling.\nFig. 9. (Top) Default `1F1B` pipeline schedule as in PipeDream-flush. (Bottom) Interleaved 1F1B pipeline schedule. First model chunks are in dark colors and second chunks are in light colors. (Image source: Narayanan et al. 202)) Mixture-of-Experts (MoE) The Mixture-of-Experts (MoE) approach attracts a lot of attention recently as researchers (mainly from Google) try to push the limit of model size. The core of the idea is ensembling learning: Combination of multiple weak learners gives you a strong learner!\nWithin one deep neural network, ensembling can be implemented with a gating mechanism connecting multiple experts (Shazeer et al., 2017). The gating mechanism controls which subset of the network (e.g. which experts) should be activated to produce outputs. The paper named it \u0026ldquo;sparsely gated mixture-of-experts\u0026rdquo; (MoE) layer.\nPrecisely one MoE layer contains\n $n$ feed-forward networks as experts $\\{E_i\\}^n_{i=1}$ A trainable gating network $G$ to learn a probability distribution over $n$ experts so as to route the traffic to a few selected experts. Depending on the gating outputs, not every expert has to be evaluated. When the number of experts is too large, we can consider using a two-level hierarchical MoE.\nFig. 10. Illustration of a mixture-of-experts (MoE) layer. Only 2 out of $n$ experts are selected and activated by the gating network. (Image source: Shazeer et al., 2017) A simple choice of $G$ is to multiply the input with a trainable weight matrix $G_g$ and then do softmax: $G_\\sigma (x) = \\text{softmax}(x W_g)$. However, this produces a dense control vector for gating and does not help save computation resources because we don\u0026rsquo;t need to evaluate an expert only when $G^{(i)}(x)=0$. Thus the MoE layer only keeps the top $k$ values. It also adds tunable Gaussian noise into $G$ to improve load balancing. This mechanism is called noisy top-k gating.\n $$ \\begin{aligned} G(x) \u0026= \\text{softmax}( \\text{topk}(H(x), k)) \\\\ H^{(i)}(x) \u0026= (xW_g)^{(i)} + \\epsilon \\cdot \\text{softplus}((xW_\\text{noise})^{(i)} ); \\quad \\epsilon \\sim \\mathcal{N}(0, \\mathbf{1}) \\\\ \\text{topk}^{(i)}(v, k) \u0026= \\begin{cases} v^{(i)} \u0026 \\text{if }v^{(i)}\\text{ is in the top }k\\text{ elements of }v \\\\ -\\infty \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ where the superscript $v^{(i)}$ denotes the i-th dimension of the vector $v$. The function $\\text{topk}(., k)$ selected the top $k$ dimensions with highest values by setting other dimensions to $-\\infty$.\nTo avoid the self-reinforcing effect that the gating network may favor a few strong experts all the time, Shazeer et al. (2017) proposed a soft constraint via an additional importance loss to encourage all the experts to have the same weights. It is equivalent to the square of the coefficient of variation of batchwise average value per expert.\n $$ L_\\text{aux} = w_\\text{aux} \\cdot \\text{CV}(\\sum_{x \\in X} G(x))^2 $$ where $ \\text{CV}$ is the coefficient of variation and the loss weight $w_\\text{aux}$ is a hyperparameter to tune.\nBecause every expert network only gets a fraction of training samples (\u0026ldquo;The shrinking batch problem\u0026rdquo;), we should try to use a batch size as large as possible in MoE. However, it is restricted by GPU memory. Data parallelism and model parallelism can be applied to improve the throughput.\nFig. 11. Test perplexity on 1-Billion-Word language modeling benchmark. (Left) The model capacity increases from left to right, containing 4, 32, 256, 256, 1024 and 4096 experts. (Right) Performance of the 4 billion parameters MoE model, the largest one in the left figure, under different computation budgets. (Image source: Shazeer et al., 2017) GShard (Lepikhin et al., 2020) scales the MoE transformer model up to 600 billion parameters with sharding. The MoE transformer replaces every other feed forward layer with a MoE layer. The sharded MoE transformer only has the MoE layers sharded across multiple machines, while other layers are simply duplicated.\nThere are several improved designs for the gating function $G$ in GShard:\n Expert capacity: The amount of tokens going through one expert should not go above a threshold, named “expert capacity”. If a token is routed to experts that have reached their capacity, the token would be marked “overflowed” and the gating output is changed to a zero vector. Local group dispatching: Tokens are evenly partitioned into multiple local groups and the expert capacity is enforced on the group level. Auxiliary loss: The motivation is similar to the original MoE aux loss. They add an auxiliary loss to minimize the mean square of the fraction of data routed to each expert. Random routing: The 2nd-best expert is selected with a probability proportional to its weight; otherwise, GShard follows a random routing, so as to add some randomness. Fig. 12. Pseudo code of the group-level top-2 gating mechanism with auxiliary loss in GShard. (Image source: Lepikhin et al., 2020) Switch Transformer (Fedus et al. 2021) scales the model size up to trillions of parameters (!!) by replacing the dense feed forward layer with a sparse switch FFN layer in which each input is only routed to one expert network. The auxiliary loss for load balancing is $\\text{loss}_\\text{aux} = w_\\text{aux} \\sum_{i=1}^n f_i p_i$ given $n$ experts, where $f_i$ is the fraction of tokens routed to the $i$-th expert and $p_i$ is the routing probability for expert $i$ predicted by the gating network.\nFig. 13. Switch transformer. The sparse switch FFN layer is in the blue boxes. (Image source: Fedus et al. 2021) To improve training stability, switch transformer incorporates the following designs:\n Selective precision. They showed that selectively casting only a local part of the model to FP32 precision improves stability, while avoiding the expensive communication cost of FP32 tensors. The FP32 precision is only used within the body of the router function and the results are recast to FP16. Smaller initialization. The initialization of weight matrices is sampled from a truncated normal distribution with mean $\\mu=0$ and stdev $\\sigma = \\sqrt{s/n}$. They also recommended reducing the transformer initialization scale parameter $s=1$ to $s=0.1$. Use higher expert dropout. Fine-tuning often works with a small dataset. To avoid overfitting, the dropout rate within each expert is increased by a significant amount. Interestingly they found that increasing dropout across all layers lead to poor performance. In the paper, they used a dropout rate 0.1 at non-expert layers but 0.4 within expert FF layers. The switch transformer paper summarized different data and model parallelism strategies for training large models with a nice illustration:\nFig. 14. An illustration of various parallelism strategies on how (Top) model weights and (Bottom) data are split over multiple GPU cores. In the top row, each color denotes a unique weight matrix. In the bottom row, different colors indicate different sets of tokens. (Image source: Fedus et al. 2021) Both GShard top-2 and Switch Transformer top-1 depend on token choice, where each token picks the best one or two experts to route through. They both adopt an auxiliary loss to encourage more balanced load allocation but it does not guarantee the best performance. Furthermore, the expert capacity limit may lead to wasted tokens as they would be discarded if an expert reaches its capacity limit.\nExport Choice (EC) (Zhou et al. 2022) routing instead enables each expert to select the top-$k$ tokens. In this way, each expert naturally guarantees a fixed capacity and each token may be routed to multiple experts. EC can achieve perfect load balancing and is shown to improve training convergence by 2x.\nGiven $e$ experts and an input matrix $X \\in \\mathbb{R}^{n \\times d}$, the token-to-expert affinity scores are computed by: $$ S = \\text{softmax}(X \\cdot W_g), \\text{where } W_g \\in \\mathbb{R}^{d \\times e}, S \\in \\mathbb{R}^{n \\times e} $$\nA token-to-expert assignment is represented by three matrices, $I, G \\in \\mathbb{R}^{e\\times k}$ and $P \\in \\mathbb{R}^{e \\times k \\times n}$. $I[i,j]$ annotates which token is the $j$-th selection by the $i$-th expert. The gating matrix $G$ stores the routing weights of selected tokens. $P$ is the one-hot version of $I$, used to produce the input matrix ($P \\cdot X \\in \\mathbb{R}^{e \\times k \\times d}$) for the gated FFN layer. $$ G, I = \\text{top-k}(S^\\top, k)\\quad P = \\text{one-hot}(I) $$\nOne regularization that export choice routing explored is to limit the maximum number of experts per token.\n $$ \\begin{aligned} \u0026 \\max_A \\langle S^\\top, A\\rangle + \\lambda H(A) \\\\ \\text{s.t.} \u0026 \\forall i: \\sum_{j'} A[i, j'] = k,\\quad \\forall j: \\sum_{i'} A[i', j] \\leq b,\\quad \\forall i,j: 0 \\leq A[i,j] \\leq 1 \\end{aligned} $$ where each entry $A[i,j]$ in $A \\in \\mathbb{R}^{e \\times n}$ marks whether the $i$-the expert selects the $j$-th token. Solving this is non-trivial. The paper used Dykstra\u0026rsquo;s algorithm that runs a sequence of multiple iterative computation steps. Capped expert choice results in a slight decrease in the fine-tuning performance in the experiments.\nThe parameter $k$ is determined by $k=nc/e$, where $n$ is the total number of tokens in one batch and $c$ is a capacity factor indicating the average number of experts used by one token. The paper used $c=2$ in most experiments, but EC with $c=1$ still outperforms the top-1 token choice gating. Interestingly, $c=0.5$ only marginally hurts the training performance.\nOne big drawback of EC is that it does not work when the batch size is too small, neither for auto-regressive text generation, because it needs to know the future tokens to do the top-$k$ selection.\nOther Memory Saving Designs CPU Offloading When the GPU memory is full, one option is to offload temporarily unused data to CPU and read them back when needed later (Rhu et al. 2016). The idea of CPU offloading is straightforward but is less popular in recent years due to the slowdown it brings into the training time.\nActivation Recomputation Activation recomputation (also known as “activation checkpointing” or “gradient checkpointing”; Chen et al. 2016) is a smart yet simple idea to reduce memory footprint at the cost of computation time. It reduces the memory cost of training a $\\ell$ layer deep neural net to $O(\\sqrt{\\ell})$, which only additionally consumes an extra forward pass computation per batch.\nLet\u0026rsquo;s say, we evenly divide an $\\ell$-layer network into $d$ partitions. Only activations at partition boundaries are saved and communicated between workers. Intermediate activations at intra-partition layers are still needed for computing gradients so they are recomputed during backward passes. With activation recomputation, the memory cost for training $M(\\ell)$ is:\n $$ M(\\ell) =\\max_{i=1,\\dots,k} \\underbrace{\\text{cost-of-one-partition}(i)}_\\text{cost of back-propagation on the i-th partition} + \\underbrace{O(d)}_\\text{store intermediate outputs} = O(\\frac{\\ell}{d}) + O(d) $$ The minimum cost is $O(\\sqrt{\\ell})$ at $d=\\sqrt{\\ell}$.\nActivation recompuation trick can give sublinear memory cost with respect to the model size.\nFig. 15. The memory cost of different memory saving algorithms. Sharing: Memory used by intermediate results is recycled when no longer needed. Inplace: Save the output directly into memory of an input value. (Image source: Chen et al. 2016) Mixed Precision Training Narang \u0026amp; Micikevicius et al. (2018) introduced a method to train models using half-precision floating point (FP16) numbers without losing model accuracy.\nFig. 16. The procedure of mixed precision training at one layer. (Image source: Narang \u0026 Micikevicius, et al. 2018) Three techniques to avoid losing critical information at half-precision:\n Full-precision master copy of weights. Maintain a full precision (FP32) copy of model weights that accumulates gradients. The numbers are rounded up to half-precision for forward \u0026amp; backward passes. The motivation is that each gradient update (i.e. gradient times the learning rate) might be too small to be fully contained within the FP16 range (i.e. $2^{-24}$ becomes zero in FP16). Loss scaling. Scale up the loss to better handle gradients with small magnitudes (See Fig. 16). Scaling up the gradients helps shift them to occupy a larger section towards the right section (containing larger values) of the representable range, preserving values that are otherwise lost. Arithmetic precision. For common network arithmetic (e.g. vector dot-product, reduction by summing up vector elements), we can accumulate the partial results in FP32 and then save the final output as FP16 before saving into memory. Point-wise operations can be executed in either FP16 or FP32. Fig. 17. The histogram of gradients in full precision. The left part up to $2^{-24}$ will be zero-ed off once the model switches to FP16. (Image source: Narang \u0026 Micikevicius, et al. 2018) In their experiments, loss scaling is not needed for some networks (e.g. image classification, Faster R-CNN), but necessary for others (e.g. Multibox SSD, big LSTM language model).\nCompression Intermediate results often consume a lot of memory, although they are only needed in one forward pass and one backward pass. There is a noticeable temporal gap between these two uses. Thus Jain et al. (2018) proposed a data encoding strategy to compress the intermediate results after the first use in the first pass and then decode it back for back-propagation later.\nTheir system Gist incorporates two encoding schemes: Layer-specific lossless encoding; focus on ReLU-Pool (“Binarize”) and ReLU-Conv (“Sparse storage and dense computation”) patterns. Aggressive lossy encoding; use delayed precision reduction (DPR). They observed that the first immediate use of feature maps should be kept at high precision but the second use can tolerate lower precision.\nThe experiments showed that Gist can reduce the memory cost by 2x across 5 SOTA image classification DNNs, with an average of 1.8x with only 4% performance overhead.\nMemory Efficient Optimizer Optimizers are eager for memory consumption. Take the popular Adam optimizer as an example, it internally needs to maintain momentums and variances, both at the same scale as gradients and model parameters. All out of a sudden, we need to save 4x the memory of model weights.\nSeveral optimizers have been proposed to reduce the memory footprint. For example, instead of storing the full momentums and variations as in Adam, Adafactor (Shazeer et al. 2018) only tracks the per-row and per-column sums of the moving averages and then estimates the second moments based on these sums. SM3 (Anil et al. 2019) describes a different adaptive optimization method, leading to largely reduced memory as well.\nZeRO (Zero Redundancy Optimizer; Rajbhandari et al. 2019) optimizes the memory used for training large models based on the observation about two major memory consumption of large model training:\n The majority is occupied by model states, including optimizer states (e.g. Adam momentums and variances), gradients and parameters. Mixed-precision training demands a lot of memory since the optimizer needs to keep a copy of FP32 parameters and other optimizer states, besides the FP16 version. The remaining is consumed by activations, temporary buffers and unusable fragmented memory (named residual states in the paper). ZeRO combines two approaches, ZeRO-DP and ZeRO-R. ZeRO-DP is an enhanced data parallelism to avoid simple redundancy over model states. It partitions optimizer state, gradients and parameters across multiple data parallel processes via a dynamic communication schedule to minimize the communication volume. ZeRO-R optimizes the memory consumption of residual states, using partitioned activation recomputation, constant buffer size and on-the-fly memory defragmentation.\nCitation Cited as:\n Weng, Lilian. (Sep 2021). How to train really large models on many GPUs? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-09-25-train-large/.\n Or\n@article{weng2021large, title = \u0026quot;How to Train Really Large Models on Many GPUs?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Sep\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-09-25-train-large/\u0026quot; } References [1] Li et al. “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” VLDB 2020.\n[2] Cui et al. “GeePS: Scalable deep learning on distributed GPUs with a GPU-specialized parameter server” EuroSys 2016\n[3] Shoeybi et al. “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.” arXiv preprint arXiv:1909.08053 (2019).\n[4] Narayanan et al. “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.” arXiv preprint arXiv:2104.04473 (2021).\n[5] Huang et al. “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism.” arXiv preprint arXiv:1811.06965 (2018).\n[6] Narayanan et al. \u0026ldquo;PipeDream: Generalized Pipeline Parallelism for DNN Training.\u0026quot; SOSP 2019.\n[7] Narayanan et al. “Memory-Efficient Pipeline-Parallel DNN Training.” ICML 2021.\n[8] Shazeer et al. “The Sparsely-Gated Mixture-of-Experts Layer Noam.” arXiv preprint arXiv:1701.06538 (2017).\n[9] Lepikhin et al. “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.” arXiv preprint arXiv:2006.16668 (2020).\n[10] Fedus et al. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.” arXiv preprint arXiv:2101.03961 (2021).\n[11] Narang \u0026amp; Micikevicius, et al. “Mixed precision training.” ICLR 2018.\n[12] Chen et al. 2016 “Training Deep Nets with Sublinear Memory Cost.” arXiv preprint arXiv:1604.06174 (2016).\n[13] Jain et al. “Gist: Efficient data encoding for deep neural network training.” ISCA 2018.\n[14] Shazeer \u0026amp; Stern. “Adafactor: Adaptive learning rates with sublinear memory cost.” arXiv preprint arXiv:1804.04235 (2018).\n[15] Anil et al. “Memory-Efficient Adaptive Optimization.” arXiv preprint arXiv:1901.11150 (2019).\n[16] Rajbhandari et al. “ZeRO: Memory Optimization Towards Training A Trillion Parameter Models Samyam.” arXiv preprint arXiv:1910.02054 (2019).\n[17] Zhou et al. “Mixture-of-Experts with Expert Choice Routing” arXiv preprint arXiv:2202.09368 (2022).\n","permalink":"https://lilianweng.github.io/posts/2021-09-25-train-large/","summary":"[Updated on 2022-03-13: add expert choice routing.] [Updated on 2022-06-10]: Greg and I wrote a shorted and upgraded version of this post, published on OpenAI Blog: \u0026ldquo;Techniques for Training Large Neural Networks\u0026rdquo;\nIn recent years, we are seeing better results on many NLP benchmark tasks with larger pre-trained language models. How to train large and deep neural networks is challenging, as it demands a large amount of GPU memory and a long horizon of training time.","title":"How to Train Really Large Models on Many GPUs?"},{"content":"[Updated on 2021-09-19: Highly recommend this blog post on score-based generative modeling by Yang Song (author of several key papers in the references)]. [Updated on 2022-08-27: Added classifier-free guidance, GLIDE, unCLIP and Imagen. [Updated on 2022-08-31: Added latent diffusion model.\nSo far, I\u0026rsquo;ve written about three types of generative models, GAN, VAE, and Flow-based models. They have shown great success in generating high-quality samples, but each has some limitations of its own. GAN models are known for potentially unstable training and less diversity in generation due to their adversarial training nature. VAE relies on a surrogate loss. Flow models have to use specialized architectures to construct reversible transform.\nDiffusion models are inspired by non-equilibrium thermodynamics. They define a Markov chain of diffusion steps to slowly add random noise to data and then learn to reverse the diffusion process to construct desired data samples from the noise. Unlike VAE or flow models, diffusion models are learned with a fixed procedure and the latent variable has high dimensionality (same as the original data).\nFig. 1. Overview of different types of generative models. What are Diffusion Models? Several diffusion-based generative models have been proposed with similar ideas underneath, including diffusion probabilistic models (Sohl-Dickstein et al., 2015), noise-conditioned score network (NCSN; Yang \u0026amp; Ermon, 2019), and denoising diffusion probabilistic models (DDPM; Ho et al. 2020).\nForward diffusion process Given a data point sampled from a real data distribution $\\mathbf{x}_0 \\sim q(\\mathbf{x})$, let us define a forward diffusion process in which we add small amount of Gaussian noise to the sample in $T$ steps, producing a sequence of noisy samples $\\mathbf{x}_1, \\dots, \\mathbf{x}_T$. The step sizes are controlled by a variance schedule $\\{\\beta_t \\in (0, 1)\\}_{t=1}^T$.\n $$ q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}) = \\mathcal{N}(\\mathbf{x}_t; \\sqrt{1 - \\beta_t} \\mathbf{x}_{t-1}, \\beta_t\\mathbf{I}) \\quad q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0) = \\prod^T_{t=1} q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}) $$ The data sample $\\mathbf{x}_0$ gradually loses its distinguishable features as the step $t$ becomes larger. Eventually when $T \\to \\infty$, $\\mathbf{x}_T$ is equivalent to an isotropic Gaussian distribution.\nFig. 2. The Markov chain of forward (reverse) diffusion process of generating a sample by slowly adding (removing) noise. (Image source: Ho et al. 2020 with a few additional annotations) A nice property of the above process is that we can sample $\\mathbf{x}_t$ at any arbitrary time step $t$ in a closed form using reparameterization trick. Let $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{i=1}^t \\alpha_i$:\n $$ \\begin{aligned} \\mathbf{x}_t \u0026= \\sqrt{\\alpha_t}\\mathbf{x}_{t-1} + \\sqrt{1 - \\alpha_t}\\boldsymbol{\\epsilon}_{t-1} \u0026 \\text{ ;where } \\boldsymbol{\\epsilon}_{t-1}, \\boldsymbol{\\epsilon}_{t-2}, \\dots \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I}) \\\\ \u0026= \\sqrt{\\alpha_t \\alpha_{t-1}} \\mathbf{x}_{t-2} + \\sqrt{1 - \\alpha_t \\alpha_{t-1}} \\bar{\\boldsymbol{\\epsilon}}_{t-2} \u0026 \\text{ ;where } \\bar{\\boldsymbol{\\epsilon}}_{t-2} \\text{ merges two Gaussians (*).} \\\\ \u0026= \\dots \\\\ \u0026= \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon} \\\\ q(\\mathbf{x}_t \\vert \\mathbf{x}_0) \u0026= \\mathcal{N}(\\mathbf{x}_t; \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0, (1 - \\bar{\\alpha}_t)\\mathbf{I}) \\end{aligned} $$ (*) Recall that when we merge two Gaussians with different variance, $\\mathcal{N}(\\mathbf{0}, \\sigma_1^2\\mathbf{I})$ and $\\mathcal{N}(\\mathbf{0}, \\sigma_2^2\\mathbf{I})$, the new distribution is $\\mathcal{N}(\\mathbf{0}, (\\sigma_1^2 + \\sigma_2^2)\\mathbf{I})$. Here the merged standard deviation is $\\sqrt{(1 - \\alpha_t) + \\alpha_t (1-\\alpha_{t-1})} = \\sqrt{1 - \\alpha_t\\alpha_{t-1}}$.\nUsually, we can afford a larger update step when the sample gets noisier, so $\\beta_1 \u0026lt; \\beta_2 \u0026lt; \\dots \u0026lt; \\beta_T$ and therefore $\\bar{\\alpha}_1 \u0026gt; \\dots \u0026gt; \\bar{\\alpha}_T$.\nConnection with stochastic gradient Langevin dynamics Langevin dynamics is a concept from physics, developed for statistically modeling molecular systems. Combined with stochastic gradient descent, stochastic gradient Langevin dynamics (Welling \u0026amp; Teh 2011) can produce samples from a probability density $p(\\mathbf{x})$ using only the gradients $\\nabla_\\mathbf{x} \\log p(\\mathbf{x})$ in a Markov chain of updates:\n $$ \\mathbf{x}_t = \\mathbf{x}_{t-1} + \\frac{\\delta}{2} \\nabla_\\mathbf{x} \\log p(\\mathbf{x}_{t-1}) + \\sqrt{\\delta} \\boldsymbol{\\epsilon}_t ,\\quad\\text{where } \\boldsymbol{\\epsilon}_t \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I}) $$ where $\\delta$ is the step size. When $T \\to \\infty, \\epsilon \\to 0$, $\\mathbf{x}_T$ equals to the true probability density $p(\\mathbf{x})$.\nCompared to standard SGD, stochastic gradient Langevin dynamics injects Gaussian noise into the parameter updates to avoid collapses into local minima.\nReverse diffusion process If we can reverse the above process and sample from $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$, we will be able to recreate the true sample from a Gaussian noise input, $\\mathbf{x}_T \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I})$. Note that if $\\beta_t$ is small enough, $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$ will also be Gaussian. Unfortunately, we cannot easily estimate $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t)$ because it needs to use the entire dataset and therefore we need to learn a model $p_\\theta$ to approximate these conditional probabilities in order to run the reverse diffusion process.\n $$ p_\\theta(\\mathbf{x}_{0:T}) = p(\\mathbf{x}_T) \\prod^T_{t=1} p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) \\quad p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)) $$ Fig. 3. An example of training a diffusion model for modeling a 2D swiss roll data. (Image source: Sohl-Dickstein et al., 2015) It is noteworthy that the reverse conditional probability is tractable when conditioned on $\\mathbf{x}_0$:\n $$ q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\color{blue}{\\tilde{\\boldsymbol{\\mu}}}(\\mathbf{x}_t, \\mathbf{x}_0), \\color{red}{\\tilde{\\beta}_t} \\mathbf{I}) $$ Using Bayes' rule, we have:\n $$ \\begin{aligned} q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \u0026= q(\\mathbf{x}_t \\vert \\mathbf{x}_{t-1}, \\mathbf{x}_0) \\frac{ q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_0) }{ q(\\mathbf{x}_t \\vert \\mathbf{x}_0) } \\\\ \u0026\\propto \\exp \\Big(-\\frac{1}{2} \\big(\\frac{(\\mathbf{x}_t - \\sqrt{\\alpha_t} \\mathbf{x}_{t-1})^2}{\\beta_t} + \\frac{(\\mathbf{x}_{t-1} - \\sqrt{\\bar{\\alpha}_{t-1}} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_{t-1}} - \\frac{(\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_t} \\big) \\Big) \\\\ \u0026= \\exp \\Big(-\\frac{1}{2} \\big(\\frac{\\mathbf{x}_t^2 - 2\\sqrt{\\alpha_t} \\mathbf{x}_t \\color{blue}{\\mathbf{x}_{t-1}} \\color{black}{+ \\alpha_t} \\color{red}{\\mathbf{x}_{t-1}^2} }{\\beta_t} + \\frac{ \\color{red}{\\mathbf{x}_{t-1}^2} \\color{black}{- 2 \\sqrt{\\bar{\\alpha}_{t-1}} \\mathbf{x}_0} \\color{blue}{\\mathbf{x}_{t-1}} \\color{black}{+ \\bar{\\alpha}_{t-1} \\mathbf{x}_0^2} }{1-\\bar{\\alpha}_{t-1}} - \\frac{(\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0)^2}{1-\\bar{\\alpha}_t} \\big) \\Big) \\\\ \u0026= \\exp\\Big( -\\frac{1}{2} \\big( \\color{red}{(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}})} \\mathbf{x}_{t-1}^2 - \\color{blue}{(\\frac{2\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{2\\sqrt{\\bar{\\alpha}_{t-1}}}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0)} \\mathbf{x}_{t-1} \\color{black}{ + C(\\mathbf{x}_t, \\mathbf{x}_0) \\big) \\Big)} \\end{aligned} $$ where $C(\\mathbf{x}_t, \\mathbf{x}_0)$ is some function not involving $\\mathbf{x}_{t-1}$ and details are omitted. Following the standard Gaussian density function, the mean and variance can be parameterized as follows (recall that $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{i=1}^T \\alpha_i$):\n $$ \\begin{aligned} \\tilde{\\beta}_t \u0026= 1/(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}}) = 1/(\\frac{\\alpha_t - \\bar{\\alpha}_t + \\beta_t}{\\beta_t(1 - \\bar{\\alpha}_{t-1})}) = \\color{green}{\\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t} \\\\ \\tilde{\\boldsymbol{\\mu}}_t (\\mathbf{x}_t, \\mathbf{x}_0) \u0026= (\\frac{\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1} }}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0)/(\\frac{\\alpha_t}{\\beta_t} + \\frac{1}{1 - \\bar{\\alpha}_{t-1}}) \\\\ \u0026= (\\frac{\\sqrt{\\alpha_t}}{\\beta_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1} }}{1 - \\bar{\\alpha}_{t-1}} \\mathbf{x}_0) \\color{green}{\\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t} \\\\ \u0026= \\frac{\\sqrt{\\alpha_t}(1 - \\bar{\\alpha}_{t-1})}{1 - \\bar{\\alpha}_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1}}\\beta_t}{1 - \\bar{\\alpha}_t} \\mathbf{x}_0\\\\ \\end{aligned} $$ Thanks to the nice property, we can represent $\\mathbf{x}_0 = \\frac{1}{\\sqrt{\\bar{\\alpha}_t}}(\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t)$ and plug it into the above equation and obtain:\n $$ \\begin{aligned} \\tilde{\\boldsymbol{\\mu}}_t \u0026= \\frac{\\sqrt{\\alpha_t}(1 - \\bar{\\alpha}_{t-1})}{1 - \\bar{\\alpha}_t} \\mathbf{x}_t + \\frac{\\sqrt{\\bar{\\alpha}_{t-1}}\\beta_t}{1 - \\bar{\\alpha}_t} \\frac{1}{\\sqrt{\\bar{\\alpha}_t}}(\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t) \\\\ \u0026= \\color{cyan}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)} \\end{aligned} $$ As demonstrated in Fig. 2., such a setup is very similar to VAE and thus we can use the variational lower bound to optimize the negative log-likelihood.\n $$ \\begin{aligned} - \\log p_\\theta(\\mathbf{x}_0) \u0026\\leq - \\log p_\\theta(\\mathbf{x}_0) + D_\\text{KL}(q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0) \\| p_\\theta(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0) ) \\\\ \u0026= -\\log p_\\theta(\\mathbf{x}_0) + \\mathbb{E}_{\\mathbf{x}_{1:T}\\sim q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0)} \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T}) / p_\\theta(\\mathbf{x}_0)} \\Big] \\\\ \u0026= -\\log p_\\theta(\\mathbf{x}_0) + \\mathbb{E}_q \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} + \\log p_\\theta(\\mathbf{x}_0) \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ \\log \\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\\\ \\text{Let }L_\\text{VLB} \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\Big[ \\log \\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\geq - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log p_\\theta(\\mathbf{x}_0) \\end{aligned} $$ It is also straightforward to get the same result using Jensen\u0026rsquo;s inequality. Say we want to minimize the cross entropy as the learning objective,\n $$ \\begin{aligned} L_\\text{CE} \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log p_\\theta(\\mathbf{x}_0) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\int p_\\theta(\\mathbf{x}_{0:T}) d\\mathbf{x}_{1:T} \\Big) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\int q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0) \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} d\\mathbf{x}_{1:T} \\Big) \\\\ \u0026= - \\mathbb{E}_{q(\\mathbf{x}_0)} \\log \\Big( \\mathbb{E}_{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_0)} \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} \\Big) \\\\ \u0026\\leq - \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\log \\frac{p_\\theta(\\mathbf{x}_{0:T})}{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})} \\\\ \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})}\\Big[\\log \\frac{q(\\mathbf{x}_{1:T} \\vert \\mathbf{x}_{0})}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] = L_\\text{VLB} \\end{aligned} $$ To convert each term in the equation to be analytically computable, the objective can be further rewritten to be a combination of several KL-divergence and entropy terms (See the detailed step-by-step process in Appendix B in Sohl-Dickstein et al., 2015):\n $$ \\begin{aligned} L_\\text{VLB} \u0026= \\mathbb{E}_{q(\\mathbf{x}_{0:T})} \\Big[ \\log\\frac{q(\\mathbf{x}_{1:T}\\vert\\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{0:T})} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ \\log\\frac{\\prod_{t=1}^T q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{ p_\\theta(\\mathbf{x}_T) \\prod_{t=1}^T p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t) } \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=1}^T \\log \\frac{q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_t\\vert\\mathbf{x}_{t-1})}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\log\\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\Big( \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)}\\cdot \\frac{q(\\mathbf{x}_t \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_{t-1}\\vert\\mathbf{x}_0)} \\Big) + \\log \\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_t \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_0)} + \\log\\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big] \\\\ \u0026= \\mathbb{E}_q \\Big[ -\\log p_\\theta(\\mathbf{x}_T) + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} + \\log\\frac{q(\\mathbf{x}_T \\vert \\mathbf{x}_0)}{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)} + \\log \\frac{q(\\mathbf{x}_1 \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)} \\Big]\\\\ \u0026= \\mathbb{E}_q \\Big[ \\log\\frac{q(\\mathbf{x}_T \\vert \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_T)} + \\sum_{t=2}^T \\log \\frac{q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)}{p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t)} - \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1) \\Big] \\\\ \u0026= \\mathbb{E}_q [\\underbrace{D_\\text{KL}(q(\\mathbf{x}_T \\vert \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_T))}_{L_T} + \\sum_{t=2}^T \\underbrace{D_\\text{KL}(q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_{t-1} \\vert\\mathbf{x}_t))}_{L_{t-1}} \\underbrace{- \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1)}_{L_0} ] \\end{aligned} $$ Let\u0026rsquo;s label each component in the variational lower bound loss separately:\n $$ \\begin{aligned} L_\\text{VLB} \u0026= L_T + L_{T-1} + \\dots + L_0 \\\\ \\text{where } L_T \u0026= D_\\text{KL}(q(\\mathbf{x}_T \\vert \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_T)) \\\\ L_t \u0026= D_\\text{KL}(q(\\mathbf{x}_t \\vert \\mathbf{x}_{t+1}, \\mathbf{x}_0) \\parallel p_\\theta(\\mathbf{x}_t \\vert\\mathbf{x}_{t+1})) \\text{ for }1 \\leq t \\leq T-1 \\\\ L_0 \u0026= - \\log p_\\theta(\\mathbf{x}_0 \\vert \\mathbf{x}_1) \\end{aligned} $$ Every KL term in $L_\\text{VLB}$ (except for $L_0$) compares two Gaussian distributions and therefore they can be computed in closed form. $L_T$ is constant and can be ignored during training because $q$ has no learnable parameters and $\\mathbf{x}_T$ is a Gaussian noise. Ho et al. 2020 models $L_0$ using a separate discrete decoder derived from $\\mathcal{N}(\\mathbf{x}_0; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_1, 1), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_1, 1))$.\nParameterization of $L_t$ for Training Loss Recall that we need to learn a neural network to approximate the conditioned probability distributions in the reverse diffusion process, $p_\\theta(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t))$. We would like to train $\\boldsymbol{\\mu}_\\theta$ to predict $\\tilde{\\boldsymbol{\\mu}}_t = \\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)$. Because $\\mathbf{x}_t$ is available as input at training time, we can reparameterize the Gaussian noise term instead to make it predict $\\boldsymbol{\\epsilon}_t$ from the input $\\mathbf{x}_t$ at time step $t$:\n $$ \\begin{aligned} \\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t) \u0026= \\color{cyan}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big)} \\\\ \\text{Thus }\\mathbf{x}_{t-1} \u0026= \\mathcal{N}(\\mathbf{x}_{t-1}; \\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big), \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)) \\end{aligned} $$ The loss term $L_t$ is parameterized to minimize the difference from $\\tilde{\\boldsymbol{\\mu}}$ :\n $$ \\begin{aligned} L_t \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{1}{2 \\| \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) \\|^2_2} \\| \\color{blue}{\\tilde{\\boldsymbol{\\mu}}_t(\\mathbf{x}_t, \\mathbf{x}_0)} - \\color{green}{\\boldsymbol{\\mu}_\\theta(\\mathbf{x}_t, t)} \\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{1}{2 \\|\\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\| \\color{blue}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_t \\Big)} - \\color{green}{\\frac{1}{\\sqrt{\\alpha_t}} \\Big( \\mathbf{x}_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) \\Big)} \\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{ (1 - \\alpha_t)^2 }{2 \\alpha_t (1 - \\bar{\\alpha}_t) \\| \\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)\\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{\\mathbf{x}_0, \\boldsymbol{\\epsilon}} \\Big[\\frac{ (1 - \\alpha_t)^2 }{2 \\alpha_t (1 - \\bar{\\alpha}_t) \\| \\boldsymbol{\\Sigma}_\\theta \\|^2_2} \\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t, t)\\|^2 \\Big] \\end{aligned} $$ Simplification Empirically, Ho et al. (2020) found that training the diffusion model works better with a simplified objective that ignores the weighting term:\n $$ \\begin{aligned} L_t^\\text{simple} \u0026= \\mathbb{E}_{t \\sim [1, T], \\mathbf{x}_0, \\boldsymbol{\\epsilon}_t} \\Big[\\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)\\|^2 \\Big] \\\\ \u0026= \\mathbb{E}_{t \\sim [1, T], \\mathbf{x}_0, \\boldsymbol{\\epsilon}_t} \\Big[\\|\\boldsymbol{\\epsilon}_t - \\boldsymbol{\\epsilon}_\\theta(\\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\boldsymbol{\\epsilon}_t, t)\\|^2 \\Big] \\end{aligned} $$ The final simple objective is:\n $$ L_\\text{simple} = L_t^\\text{simple} + C $$ where $C$ is a constant not depending on $\\theta$.\nFig. 4. The training and sampling algorithms in DDPM (Image source: Ho et al. 2020) Connection with noise-conditioned score networks (NCSN) Song \u0026amp; Ermon (2019) proposed a score-based generative modeling method where samples are produced via Langevin dynamics using gradients of the data distribution estimated with score matching. The score of each sample $\\mathbf{x}$\u0026rsquo;s density probability is defined as its gradient $\\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$. A score network $\\mathbf{s}_\\theta: \\mathbb{R}^D \\to \\mathbb{R}^D$ is trained to estimate it, $\\mathbf{s}_\\theta(\\mathbf{x}) \\approx \\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$.\nTo make it scalable with high-dimensional data in the deep learning setting, they proposed to use either denoising score matching (Vincent, 2011) or sliced score matching (use random projections; Song et al., 2019). Denosing score matching adds a pre-specified small noise to the data $q(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$ and estimates $q(\\tilde{\\mathbf{x}})$ with score matching.\nRecall that Langevin dynamics can sample data points from a probability density distribution using only the score $\\nabla_{\\mathbf{x}} \\log q(\\mathbf{x})$ in an iterative process.\nHowever, according to the manifold hypothesis, most of the data is expected to concentrate in a low dimensional manifold, even though the observed data might look only arbitrarily high-dimensional. It brings a negative effect on score estimation since the data points cannot cover the whole space. In regions where data density is low, the score estimation is less reliable. After adding a small Gaussian noise to make the perturbed data distribution cover the full space $\\mathbb{R}^D$, the training of the score estimator network becomes more stable. Song \u0026amp; Ermon (2019) improved it by perturbing the data with the noise of different levels and train a noise-conditioned score network to jointly estimate the scores of all the perturbed data at different noise levels.\nThe schedule of increasing noise levels resembles the forward diffusion process. If we use the diffusion process annotation, the score approximates $\\mathbf{s}_\\theta(\\mathbf{x}_t, t) \\approx \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t)$. Given a Gaussian distribution $\\mathbf{x} \\sim \\mathcal{N}(\\mathbf{\\mu}, \\sigma^2 \\mathbf{I})$, we can write the derivative of the logarithm of its density function as $\\nabla_{\\mathbf{x}}\\log p(\\mathbf{x}) = \\nabla_{\\mathbf{x}} \\Big(-\\frac{1}{2\\sigma^2}(\\mathbf{x} - \\boldsymbol{\\mu})^2 \\Big) = - \\frac{\\mathbf{x} - \\boldsymbol{\\mu}}{\\sigma^2} = - \\frac{\\boldsymbol{\\epsilon}}{\\sigma}$ where $\\boldsymbol{\\epsilon} \\sim \\mathcal{N}(\\boldsymbol{0}, \\mathbf{I})$. Recall that $q(\\mathbf{x}_t \\vert \\mathbf{x}_0) \\sim \\mathcal{N}(\\sqrt{\\bar{\\alpha}_t} \\mathbf{x}_0, (1 - \\bar{\\alpha}_t)\\mathbf{I})$ and therefore,\n $$ \\mathbf{s}_\\theta(\\mathbf{x}_t, t) \\approx \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) = \\mathbb{E}_{q(\\mathbf{x}_0)} [\\nabla_{\\mathbf{x}_t} q(\\mathbf{x}_t \\vert \\mathbf{x}_0)] = \\mathbb{E}_{q(\\mathbf{x}_0)} \\Big[ - \\frac{\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)}{\\sqrt{1 - \\bar{\\alpha}_t}} \\Big] = - \\frac{\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)}{\\sqrt{1 - \\bar{\\alpha}_t}} $$ Parameterization of $\\beta_t$ The forward variances are set to be a sequence of linearly increasing constants in Ho et al. (2020), from $\\beta_1=10^{-4}$ to $\\beta_T=0.02$. They are relatively small compared to the normalized image pixel values between $[-1, 1]$. Diffusion models in their experiments showed high-quality samples but still could not achieve competitive model log-likelihood as other generative models.\nNichol \u0026amp; Dhariwal (2021) proposed several improvement techniques to help diffusion models to obtain lower NLL. One of the improvements is to use a cosine-based variance schedule. The choice of the scheduling function can be arbitrary, as long as it provides a near-linear drop in the middle of the training process and subtle changes around $t=0$ and $t=T$.\n $$ \\beta_t = \\text{clip}(1-\\frac{\\bar{\\alpha}_t}{\\bar{\\alpha}_{t-1}}, 0.999) \\quad\\bar{\\alpha}_t = \\frac{f(t)}{f(0)}\\quad\\text{where }f(t)=\\cos\\Big(\\frac{t/T+s}{1+s}\\cdot\\frac{\\pi}{2}\\Big)^2 $$ where the small offset $s$ is to prevent $\\beta_t$ from being too small when close to $t=0$.\nFig. 5. Comparison of linear and cosine-based scheduling of $\\beta\\_t$ during training. (Image source: Nichol \u0026 Dhariwal, 2021) Parameterization of reverse process variance $\\boldsymbol{\\Sigma}_\\theta$ Ho et al. (2020) chose to fix $\\beta_t$ as constants instead of making them learnable and set $\\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) = \\sigma^2_t \\mathbf{I}$ , where $\\sigma_t$ is not learned but set to $\\beta_t$ or $\\tilde{\\beta}_t = \\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t$. Because they found that learning a diagonal variance $\\boldsymbol{\\Sigma}_\\theta$ leads to unstable training and poorer sample quality.\nNichol \u0026amp; Dhariwal (2021) proposed to learn $\\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t)$ as an interpolation between $\\beta_t$ and $\\tilde{\\beta}_t$ by model predicting a mixing vector $\\mathbf{v}$ :\n $$ \\boldsymbol{\\Sigma}_\\theta(\\mathbf{x}_t, t) = \\exp(\\mathbf{v} \\log \\beta_t + (1-\\mathbf{v}) \\log \\tilde{\\beta}_t) $$ However, the simple objective $L_\\text{simple}$ does not depend on $\\boldsymbol{\\Sigma}_\\theta$ . To add the dependency, they constructed a hybrid objective $L_\\text{hybrid} = L_\\text{simple} + \\lambda L_\\text{VLB}$ where $\\lambda=0.001$ is small and stop gradient on $\\boldsymbol{\\mu}_\\theta$ in the $L_\\text{VLB}$ term such that $L_\\text{VLB}$ only guides the learning of $\\boldsymbol{\\Sigma}_\\theta$. Empirically they observed that $L_\\text{VLB}$ is pretty challenging to optimize likely due to noisy gradients, so they proposed to use a time-averaging smoothed version of $L_\\text{VLB}$ with importance sampling.\nFig. 6. Comparison of negative log-likelihood of improved DDPM with other likelihood-based generative models. NLL is reported in the unit of bits/dim. (Image source: Nichol \u0026 Dhariwal, 2021) Speed up Diffusion Model Sampling It is very slow to generate a sample from DDPM by following the Markov chain of the reverse diffusion process, as $T$ can be up to one or a few thousand steps. One data point from Song et al. 2020: \u0026ldquo;For example, it takes around 20 hours to sample 50k images of size 32 × 32 from a DDPM, but less than a minute to do so from a GAN on an Nvidia 2080 Ti GPU.\u0026rdquo;\nOne simple way is to run a strided sampling schedule (Nichol \u0026amp; Dhariwal, 2021) by taking the sampling update every $\\lceil T/S \\rceil$ steps to reduce the process from $T$ to $S$ steps. The new sampling schedule for generation is $\\{\\tau_1, \\dots, \\tau_S\\}$ where $\\tau_1 \u0026lt; \\tau_2 \u0026lt; \\dots \u0026lt;\\tau_S \\in [1, T]$ and $S \u0026lt; T$.\nFor another approach, let\u0026rsquo;s rewrite $q_\\sigma(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0)$ to be parameterized by a desired standard deviation $\\sigma_t$ according to the nice property:\n $$ \\begin{aligned} \\mathbf{x}_{t-1} \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1}}\\boldsymbol{\\epsilon}_{t-1} \\\\ \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\boldsymbol{\\epsilon}_t + \\sigma_t\\boldsymbol{\\epsilon} \\\\ \u0026= \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}} + \\sigma_t\\boldsymbol{\\epsilon} \\\\ q_\\sigma(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) \u0026= \\mathcal{N}(\\mathbf{x}_{t-1}; \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_t - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}}, \\sigma_t^2 \\mathbf{I}) \\end{aligned} $$ Recall that in $q(\\mathbf{x}_{t-1} \\vert \\mathbf{x}_t, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{t-1}; \\tilde{\\boldsymbol{\\mu}}(\\mathbf{x}_t, \\mathbf{x}_0), \\tilde{\\beta}_t \\mathbf{I})$, therefore we have:\n $$ \\tilde{\\beta}_t = \\sigma_t^2 = \\frac{1 - \\bar{\\alpha}_{t-1}}{1 - \\bar{\\alpha}_t} \\cdot \\beta_t $$ Let $\\sigma_t^2 = \\eta \\cdot \\tilde{\\beta}_t$ such that we can adjust $\\eta \\in \\mathbb{R}^+$ as a hyperparameter to control the sampling stochasticity. The special case of $\\eta = 0$ makes the sampling process deterministic. Such a model is named the denoising diffusion implicit model (DDIM; Song et al., 2020). DDIM has the same marginal noise distribution but deterministically maps noise back to the original data samples.\nDuring generation, we only sample a subset of $S$ diffusion steps $\\{\\tau_1, \\dots, \\tau_S\\}$ and the inference process becomes:\n $$ q_{\\sigma, \\tau}(\\mathbf{x}_{\\tau_{i-1}} \\vert \\mathbf{x}_{\\tau_t}, \\mathbf{x}_0) = \\mathcal{N}(\\mathbf{x}_{\\tau_{i-1}}; \\sqrt{\\bar{\\alpha}_{t-1}}\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_{t-1} - \\sigma_t^2} \\frac{\\mathbf{x}_{\\tau_i} - \\sqrt{\\bar{\\alpha}_t}\\mathbf{x}_0}{\\sqrt{1 - \\bar{\\alpha}_t}}, \\sigma_t^2 \\mathbf{I}) $$ While all the models are trained with $T=1000$ diffusion steps in the experiments, they observed that DDIM ($\\eta=0$) can produce the best quality samples when $S$ is small, while DDPM ($\\eta=1$) performs much worse on small $S$. DDPM does perform better when we can afford to run the full reverse Markov diffusion steps ($S=T=1000$). With DDIM, it is possible to train the diffusion model up to any arbitrary number of forward steps but only sample from a subset of steps in the generative process.\nFig. 7. FID scores on CIFAR10 and CelebA datasets by diffusion models of different settings, including $\\color{cyan}{\\text{DDIM}}$ ($\\eta=0$) and $\\color{orange}{\\text{DDPM}}$ ($\\hat{\\sigma}$). (Image source: Song et al., 2020) Compared to DDPM, DDIM is able to:\n Generate higher-quality samples using a much fewer number of steps. Have \u0026ldquo;consistency\u0026rdquo; property since the generative process is deterministic, meaning that multiple samples conditioned on the same latent variable should have similar high-level features. Because of the consistency, DDIM can do semantically meaningful interpolation in the latent variable. Latent diffusion model (LDM; Rombach \u0026amp; Blattmann, et al. 2022) runs the diffusion process in the latent space instead of pixel space, making training cost lower and inference speed faster. It is motivated by the observation that most bits of an image contribute to perceptual details and the semantic and conceptual composition still remains after aggressive compression. LDM loosely decomposes the perceptual compression and semantic compression with generative modeling learning by first trimming off pixel-level redundancy with autoencoder and then manipulate/generate semantic concepts with diffusion process on learned latent.\nFig. 8. The plot for tradeoff between compression rate and distortion, illustrating two-stage compressions - perceptural and semantic comparession. (Image source: Rombach \u0026 Blattmann, et al. 2022) The perceptual compression process relies on an autoencoder model. An encoder $\\mathcal{E}$ is used to compress the input image $\\mathbf{x} \\in \\mathbb{R}^{H \\times W \\times 3}$ to a smaller 2D latent vector $\\mathbf{z} = \\mathcal{E}(\\mathbf{x}) \\in \\mathbb{R}^{h \\times w \\times c}$ , where the downsampling rate $f=H/h=W/w=2^m, m \\in \\mathbb{N}$. Then an decoder $\\mathcal{D}$ reconstructs the images from the latent vector, $\\tilde{\\mathbf{x}} = \\mathcal{D}(\\mathbf{z})$. The paper explored two types of regularization in autoencoder training to avoid arbitrarily high-variance in the latent spaces.\n KL-reg: A small KL penalty towards a standard normal distribution over the learned latent, similar to VAE. VQ-reg: Uses a vector quantization layer within the decoder, like VQVAE but the quantization layer is absorbed by the decoder. The diffusion and denoising processes happen on the latent vector $\\mathbf{z}$. The denoising model is a time-conditioned U-Net, augmented with the cross-attention mechanism to handle flexible conditioning information for image generation (e.g. class labels, semantic maps, blurred variants of an image). The design is equivalent to fuse representation of different modality into the model with cross-attention mechanism. Each type of conditioning information is paired with a domain-specific encoder $\\tau_\\theta$ to project the conditioning input $y$ to an intermediate representation that can be mapped into cross-attention component, $\\tau_\\theta(y) \\in \\mathbb{R}^{M \\times d_\\tau}$:\n $$ \\begin{aligned} \u0026\\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}\\Big(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d}}\\Big) \\cdot \\mathbf{V} \\\\ \u0026\\text{where }\\mathbf{Q} = \\mathbf{W}^{(i)}_Q \\cdot \\varphi_i(\\mathbf{z}_i),\\; \\mathbf{K} = \\mathbf{W}^{(i)}_K \\cdot \\tau_\\theta(y),\\; \\mathbf{V} = \\mathbf{W}^{(i)}_V \\cdot \\tau_\\theta(y) \\\\ \u0026\\text{and } \\mathbf{W}^{(i)}_Q \\in \\mathbb{R}^{d \\times d^i_\\epsilon},\\; \\mathbf{W}^{(i)}_K, \\mathbf{W}^{(i)}_V \\in \\mathbb{R}^{d \\times d_\\tau},\\; \\varphi_i(\\mathbf{z}_i) \\in \\mathbb{R}^{N \\times d^i_\\epsilon},\\; \\tau_\\theta(y) \\in \\mathbb{R}^{M \\times d_\\tau} \\end{aligned} $$ Fig. 9. The architecture of latent diffusion model. (Image source: Rombach \u0026 Blattmann, et al. 2022) Conditioned Generation While training generative models on images with conditioning information such as ImageNet dataset, it is common to generate samples conditioned on class labels or a piece of descriptive text.\nClassifier Guided Diffusion To explicit incorporate class information into the diffusion process, Dhariwal \u0026amp; Nichol (2021) trained a classifier $f_\\phi(y \\vert \\mathbf{x}_t, t)$ on noisy image $\\mathbf{x}_t$ and use gradients $\\nabla_\\mathbf{x} \\log f_\\phi(y \\vert \\mathbf{x}_t)$ to guide the diffusion sampling process toward the conditioning information $y$ (e.g. a target class label) by altering the noise prediction. Recall that $\\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) = - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)$ and we can write the score function for the joint distribution $q(\\mathbf{x}_t, y)$ as following,\n $$ \\begin{aligned} \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t, y) \u0026= \\nabla_{\\mathbf{x}_t} \\log q(\\mathbf{x}_t) + \\nabla_{\\mathbf{x}_t} \\log q(y \\vert \\mathbf{x}_t) \\\\ \u0026\\approx - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) + \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) \\\\ \u0026= - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}} (\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t)) \\end{aligned} $$ Thus, a new classifier-guided predictor $\\bar{\\boldsymbol{\\epsilon}}_\\theta$ would take the form as following,\n $$ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(x_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) $$ To control the strength of the classifier guidance, we can add a weight $w$ to the delta part,\n $$ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(x_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} \\; w \\nabla_{\\mathbf{x}_t} \\log f_\\phi(y \\vert \\mathbf{x}_t) $$ The resulting ablated diffusion model (ADM) and the one with additional classifier guidance (ADM-G) are able to achieve better results than SOTA generative models (e.g. BigGAN).\nFig. 10. The algorithms use guidance from a classifier to run conditioned generation with DDPM and DDIM. (Image source: Dhariwal \u0026 Nichol, 2021]) Additionally with some modifications on the U-Net architecture, Dhariwal \u0026amp; Nichol (2021) showed performance better than GAN with diffusion models. The architecture modifications include larger model depth/width, more attention heads, multi-resolution attention, BigGAN residual blocks for up/downsampling, residual connection rescale by $1/\\sqrt{2}$ and adaptive group normalization (AdaGN).\nClassifier-Free Guidance Without an independent classifier $f_\\phi$, it is still possible to run conditional diffusion steps by incorporating the scores from a conditional and an unconditional diffusion model (Ho \u0026amp; Salimans, 2021). Let unconditional denoising diffusion model $p_\\theta(\\mathbf{x})$ parameterized through a score estimator $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t)$ and the conditional model $p_\\theta(\\mathbf{x} \\vert y)$ parameterized through $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y)$. These two models can be learned via a single neural network. Precisely, a conditional diffusion model $p_\\theta(\\mathbf{x} \\vert y)$ is trained on paired data $(\\mathbf{x}, y)$, where the conditioning information $y$ gets discarded periodically at random such that the model knows how to generate images unconditionally as well, i.e. $\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) = \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y=\\varnothing)$.\nThe gradient of an implicit classifier can be represented with conditional and unconditional score estimators. Once plugged into the classifier-guided modified score, the score contains no dependency on a separate classifier.\n $$ \\begin{aligned} \\nabla_{\\mathbf{x}_t} \\log p(y \\vert \\mathbf{x}_t) \u0026= \\nabla_{\\mathbf{x}_t} \\log p(\\mathbf{x}_t \\vert y) - \\nabla_{\\mathbf{x}_t} \\log p(\\mathbf{x}_t) \\\\ \u0026= - \\frac{1}{\\sqrt{1 - \\bar{\\alpha}_t}}\\Big( \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\Big) \\\\ \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t, y) \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\sqrt{1 - \\bar{\\alpha}_t} \\; w \\nabla_{\\mathbf{x}_t} \\log p(y \\vert \\mathbf{x}_t) \\\\ \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) + w \\big(\\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\big) \\\\ \u0026= (w+1) \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t, y) - w \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\end{aligned} $$ Their experiments showed that classifier-free guidance can achieve a good balance between FID (distinguish between synthetic and generated images) and IS (quality and diversity).\n$$ q(\\mathbf{x}_t \\vert y) q(y \\vert \\mathbf{x}_t)^w \\propto \\frac{q(y\\vert \\mathbf{x}_t) q(\\mathbf{x}_t)}{q(y)} q(y \\vert \\mathbf{x}_t)^w \\propto q(\\mathbf{x}_t) q(y \\vert \\mathbf{x}_t)^{w+1} $$ Therefore, the classifier-guided noise prediction can be rewritten as $$ \\begin{aligned} \\bar{\\boldsymbol{\\epsilon}}_\\theta(\\mathbf{x}_t, t) \u0026= \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) - \\sqrt{1 - \\bar{\\alpha}_t} (w+1) \\nabla_{x_t} \\log f_\\phi(y\\vert \\mathbf{x}_t) \\\\ \u0026 \\approx - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} [\\log p(\\mathbf{x}_t) + (w+1) \\log f_\\phi (y \\vert \\mathbf{x}_t)] \\\\ \u0026 = - \\sqrt{1 - \\bar{\\alpha}_t} \\nabla_{\\mathbf{x}_t} [\\log p(\\mathbf{x}_t \\vert y) + w \\log p_\\phi (y \\vert \\mathbf{x}_t)] \\end{aligned} $$ -- The guided diffusion model, GLIDE (Nichol, Dhariwal \u0026amp; Ramesh, et al. 2022), explored both guiding strategies, CLIP guidance and classifier-free guidance, and found that the latter is more preferred. They hypothesized that it is because CLIP guidance exploits the model with adversarial examples towards the CLIP model, rather than optimize the better matched images generation.\nScale up Generation Resolution and Quality To generate high-quality images at high resolution, Ho et al. (2021) proposed to use a pipeline of multiple diffusion models at increasing resolutions. Noise conditioning augmentation between pipeline models is crucial to the final image quality, which is to apply strong data augmentation to the conditioning input $\\mathbf{z}$ of each super-resolution model $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$. The conditioning noise helps reduce compounding error in the pipeline setup. U-net is a common choice of model architecture in diffusion modeling for high-resolution image generation.\nFig. 11. A cascaded pipeline of multiple diffusion models at increasing resolutions. (Image source: Ho et al. 2021]) They found the most effective noise is to apply Gaussian noise at low resolution and Gaussian blur at high resolution. In addition, they also explored two forms of conditioning augmentation that require small modification to the training process. Note that conditioning noise is only applied to training but not at inference.\n Truncated conditioning augmentation stops the diffusion process early at step $t \u0026gt; 0$ for low resolution. Non-truncated conditioning augmentation runs the full low resolution reverse process until step 0 but then corrupt it by $\\mathbf{z}_t \\sim q(\\mathbf{x}_t \\vert \\mathbf{x}_0)$ and then feeds the corrupted $\\mathbf{z}_t$ s into the super-resolution model. The two-stage diffusion model unCLIP (Ramesh et al. 2022) heavily utilizes the CLIP text encoder to produce text-guided images at high quality. Given a pretrained CLIP model $\\mathbf{c}$ and paired training data for the diffusion model, $(\\mathbf{x}, y)$, where $x$ is an image and $y$ is the corresponding caption, we can compute the CLIP text and image embedding, $\\mathbf{c}^t(y)$ and $\\mathbf{c}^i(\\mathbf{x})$, respectively. The unCLIP learns two models in parallel:\n A prior model $P(\\mathbf{c}^i \\vert y)$: outputs CLIP image embedding $\\mathbf{c}^i$ given the text $y$. A decoder $P(\\mathbf{x} \\vert \\mathbf{c}^i, [y])$: generates the image $\\mathbf{x}$ given CLIP image embedding $\\mathbf{c}^i$ and optionally the original text $y$. These two models enable conditional generation, because\n $$ \\underbrace{P(\\mathbf{x} \\vert y) = P(\\mathbf{x}, \\mathbf{c}^i \\vert y)}_{\\mathbf{c}^i\\text{ is deterministic given }\\mathbf{x}} = P(\\mathbf{x} \\vert \\mathbf{c}^i, y)P(\\mathbf{c}^i \\vert y) $$ Fig. 12. The architecture of unCLIP. (Image source: Ramesh et al. 2022]) unCLIP follows a two-stage image generation process:\n Given a text $y$, a CLIP model is first used to generate a text embedding $\\mathbf{c}^t(y)$. Using CLIP latent space enables zero-shot image manipulation via text. A diffusion or autoregressive prior $P(\\mathbf{c}^i \\vert y)$ processes this CLIP text embedding to construct an image prior and then a diffusion decoder $P(\\mathbf{x} \\vert \\mathbf{c}^i, [y])$ generates an image, conditioned on the prior. This decoder can also generate image variations conditioned on an image input, preserving its style and semantics. Instead of CLIP model, Imagen (Saharia et al. 2022) uses a pre-trained large LM (i.e. a frozen T5-XXL text encoder) to encode text for image generation. There is a general trend that larger model size can lead to better image quality and text-image alignment. They found that T5-XXL and CLIP text encoder achieve similar performance on MS-COCO, but human evaluation prefers T5-XXL on DrawBench (a collection of prompts covering 11 categories).\nWhen applying classifier-free guidance, increasing $w$ may lead to better image-text alignment but worse image fidelity. They found that it is due to train-test mismatch, that is saying, because training data $\\mathbf{x}$ stays within the range $[-1, 1]$, the test data should be so too. Two thresholding strategies are introduced:\n Static thresholding: clip $\\mathbf{x}$ prediction to $[-1, 1]$ Dynamic thresholding: at each sampling step, compute $s$ as a certain percentile absolute pixel value; if $s \u0026gt; 1$, clip the prediction to $[-s, s]$ and divide by $s$. Imagen modifies several designs in U-net to make it efficient U-Net.\n Shift model parameters from high resolution blocks to low resolution by adding more residual locks for the lower resolutions; Scale the skip connections by $1/\\sqrt{2}$ Reverse the order of downsampling (move it before convolutions) and upsampling operations (move it after convolution) in order to improve the speed of forward pass. They found that noise conditioning augmentation, dynamic thresholding and efficient U-Net are critical for image quality, but scaling text encoder size is more important than U-Net size.\nQuick Summary Pros: Tractability and flexibility are two conflicting objectives in generative modeling. Tractable models can be analytically evaluated and cheaply fit data (e.g. via a Gaussian or Laplace), but they cannot easily describe the structure in rich datasets. Flexible models can fit arbitrary structures in data, but evaluating, training, or sampling from these models is usually expensive. Diffusion models are both analytically tractable and flexible\n Cons: Diffusion models rely on a long Markov chain of diffusion steps to generate samples, so it can be quite expensive in terms of time and compute. New methods have been proposed to make the process much faster, but the sampling is still slower than GAN.\n Citation Cited as:\n Weng, Lilian. (Jul 2021). What are diffusion models? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/.\n Or\n@article{weng2021diffusion, title = \u0026quot;What are diffusion models?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Jul\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-07-11-diffusion-models/\u0026quot; } References [1] Jascha Sohl-Dickstein et al. “Deep Unsupervised Learning using Nonequilibrium Thermodynamics.” ICML 2015.\n[2] Max Welling \u0026amp; Yee Whye Teh. “Bayesian learning via stochastic gradient langevin dynamics.” ICML 2011.\n[3] Yang Song \u0026amp; Stefano Ermon. “Generative modeling by estimating gradients of the data distribution.” NeurIPS 2019.\n[4] Yang Song \u0026amp; Stefano Ermon. “Improved techniques for training score-based generative models.” NeuriPS 2020.\n[5] Jonathan Ho et al. “Denoising diffusion probabilistic models.” arxiv Preprint arxiv:2006.11239 (2020). [code]\n[6] Jiaming Song et al. “Denoising diffusion implicit models.” arxiv Preprint arxiv:2010.02502 (2020). [code]\n[7] Alex Nichol \u0026amp; Prafulla Dhariwal. “Improved denoising diffusion probabilistic models” arxiv Preprint arxiv:2102.09672 (2021). [code]\n[8] Prafula Dhariwal \u0026amp; Alex Nichol. \u0026ldquo;Diffusion Models Beat GANs on Image Synthesis.\u0026quot; arxiv Preprint arxiv:2105.05233 (2021). [code]\n[9] Jonathan Ho \u0026amp; Tim Salimans. \u0026ldquo;Classifier-Free Diffusion Guidance.\u0026quot; NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications.\n[10] Yang Song, et al. \u0026ldquo;Score-Based Generative Modeling through Stochastic Differential Equations.\u0026quot; ICLR 2021.\n[11] Alex Nichol, Prafulla Dhariwal \u0026amp; Aditya Ramesh, et al. \u0026ldquo;GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models.\u0026quot; ICML 2022.\n[12] Jonathan Ho, et al. \u0026ldquo;Cascaded diffusion models for high fidelity image generation.\u0026quot; J. Mach. Learn. Res. 23 (2022): 47-1.\n[13] Aditya Ramesh et al. \u0026ldquo;Hierarchical Text-Conditional Image Generation with CLIP Latents.\u0026quot; arxiv Preprint arxiv:2204.06125 (2022).\n[14] Chitwan Saharia \u0026amp; William Chan, et al. \u0026ldquo;Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding.\u0026quot; arxiv Preprint arxiv:2205.11487 (2022).\n[15] Rombach \u0026amp; Blattmann, et al. \u0026ldquo;High-Resolution Image Synthesis with Latent Diffusion Models.\u0026quot; CVPR 2022.code\n","permalink":"https://lilianweng.github.io/posts/2021-07-11-diffusion-models/","summary":"[Updated on 2021-09-19: Highly recommend this blog post on score-based generative modeling by Yang Song (author of several key papers in the references)]. [Updated on 2022-08-27: Added classifier-free guidance, GLIDE, unCLIP and Imagen. [Updated on 2022-08-31: Added latent diffusion model.\nSo far, I\u0026rsquo;ve written about three types of generative models, GAN, VAE, and Flow-based models. They have shown great success in generating high-quality samples, but each has some limitations of its own.","title":"What are Diffusion Models?"},{"content":"The goal of contrastive representation learning is to learn such an embedding space in which similar sample pairs stay close to each other while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised settings. When working with unsupervised data, contrastive learning is one of the most powerful approaches in self-supervised learning.\nContrastive Training Objectives In early versions of loss functions for contrastive learning, only one positive and one negative sample are involved. The trend in recent training objectives is to include multiple positive and negative pairs in one batch.\nContrastive Loss Contrastive loss (Chopra et al. 2005) is one of the earliest training objectives used for deep metric learning in a contrastive fashion.\nGiven a list of input samples $\\{ \\mathbf{x}_i \\}$, each has a corresponding label $y_i \\in \\{1, \\dots, L\\}$ among $L$ classes. We would like to learn a function $f_\\theta(.): \\mathcal{X}\\to\\mathbb{R}^d$ that encodes $x_i$ into an embedding vector such that examples from the same class have similar embeddings and samples from different classes have very different ones. Thus, contrastive loss takes a pair of inputs $(x_i, x_j)$ and minimizes the embedding distance when they are from the same class but maximizes the distance otherwise.\n $$ \\mathcal{L}_\\text{cont}(\\mathbf{x}_i, \\mathbf{x}_j, \\theta) = \\mathbb{1}[y_i=y_j] \\| f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\|^2_2 + \\mathbb{1}[y_i\\neq y_j]\\max(0, \\epsilon - \\|f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j)\\|_2)^2 $$ where $\\epsilon$ is a hyperparameter, defining the lower bound distance between samples of different classes.\nTriplet Loss Triplet loss was originally proposed in the FaceNet (Schroff et al. 2015) paper and was used to learn face recognition of the same person at different poses and angles.\nFig. 1. Illustration of triplet loss given one positive and one negative per anchor. (Image source: Schroff et al. 2015) Given one anchor input $\\mathbf{x}$, we select one positive sample $\\mathbf{x}^+$ and one negative $\\mathbf{x}^-$, meaning that $\\mathbf{x}^+$ and $\\mathbf{x}$ belong to the same class and $\\mathbf{x}^-$ is sampled from another different class. Triplet loss learns to minimize the distance between the anchor $\\mathbf{x}$ and positive $\\mathbf{x}^+$ and maximize the distance between the anchor $\\mathbf{x}$ and negative $\\mathbf{x}^-$ at the same time with the following equation:\n $$ \\mathcal{L}_\\text{triplet}(\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-) = \\sum_{\\mathbf{x} \\in \\mathcal{X}} \\max\\big( 0, \\|f(\\mathbf{x}) - f(\\mathbf{x}^+)\\|^2_2 - \\|f(\\mathbf{x}) - f(\\mathbf{x}^-)\\|^2_2 + \\epsilon \\big) $$ where the margin parameter $\\epsilon$ is configured as the minimum offset between distances of similar vs dissimilar pairs.\nIt is crucial to select challenging $\\mathbf{x}^-$ to truly improve the model.\nLifted Structured Loss Lifted Structured Loss (Song et al. 2015) utilizes all the pairwise edges within one training batch for better computational efficiency.\nFig. 2. Illustration compares contrastive loss, triplet loss and lifted structured loss. Red and blue edges connect similar and dissimilar sample pairs respectively. (Image source: Song et al. 2015) Let $D_{ij} = | f(\\mathbf{x}_i) - f(\\mathbf{x}_j) |_2$, a structured loss function is defined as\n $$ \\begin{aligned} \\mathcal{L}_\\text{struct} \u0026= \\frac{1}{2\\vert \\mathcal{P} \\vert} \\sum_{(i,j) \\in \\mathcal{P}} \\max(0, \\mathcal{L}_\\text{struct}^{(ij)})^2 \\\\ \\text{where } \\mathcal{L}_\\text{struct}^{(ij)} \u0026= D_{ij} + \\color{red}{\\max \\big( \\max_{(i,k)\\in \\mathcal{N}} \\epsilon - D_{ik}, \\max_{(j,l)\\in \\mathcal{N}} \\epsilon - D_{jl} \\big)} \\end{aligned} $$ where $\\mathcal{P}$ contains the set of positive pairs and $\\mathcal{N}$ is the set of negative pairs. Note that the dense pairwise squared distance matrix can be easily computed per training batch.\nThe red part in $\\mathcal{L}_\\text{struct}^{(ij)}$ is used for mining hard negatives. However, it is not smooth and may cause the convergence to a bad local optimum in practice. Thus, it is relaxed to be:\n $$ \\mathcal{L}_\\text{struct}^{(ij)} = D_{ij} + \\log \\Big( \\sum_{(i,k)\\in\\mathcal{N}} \\exp(\\epsilon - D_{ik}) + \\sum_{(j,l)\\in\\mathcal{N}} \\exp(\\epsilon - D_{jl}) \\Big) $$ In the paper, they also proposed to enhance the quality of negative samples in each batch by actively incorporating difficult negative samples given a few random positive pairs.\nN-pair Loss Multi-Class N-pair loss (Sohn 2016) generalizes triplet loss to include comparison with multiple negative samples.\nGiven a $(N + 1)$-tuplet of training samples, $\\{ \\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-_1, \\dots, \\mathbf{x}^-_{N-1} \\}$, including one positive and $N-1$ negative ones, N-pair loss is defined as:\n $$ \\begin{aligned} \\mathcal{L}_\\text{N-pair}(\\mathbf{x}, \\mathbf{x}^+, \\{\\mathbf{x}^-_i\\}^{N-1}_{i=1}) \u0026= \\log\\big(1 + \\sum_{i=1}^{N-1} \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-_i) - f(\\mathbf{x})^\\top f(\\mathbf{x}^+))\\big) \\\\ \u0026= -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+))}{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+)) + \\sum_{i=1}^{N-1} \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-_i))} \\end{aligned} $$ If we only sample one negative sample per class, it is equivalent to the softmax loss for multi-class classification.\nNCE Noise Contrastive Estimation, short for NCE, is a method for estimating parameters of a statistical model, proposed by Gutmann \u0026amp; Hyvarinen in 2010. The idea is to run logistic regression to tell apart the target data from noise. Read more on how NCE is used for learning word embedding here.\nLet $\\mathbf{x}$ be the target sample $\\sim P(\\mathbf{x} \\vert C=1; \\theta) = p_\\theta(\\mathbf{x})$ and $\\tilde{\\mathbf{x}}$ be the noise sample $\\sim P(\\tilde{\\mathbf{x}} \\vert C=0) = q(\\tilde{\\mathbf{x}})$. Note that the logistic regression models the logit (i.e. log-odds) and in this case we would like to model the logit of a sample $u$ from the target data distribution instead of the noise distribution:\n $$ \\ell_\\theta(\\mathbf{u}) = \\log \\frac{p_\\theta(\\mathbf{u})}{q(\\mathbf{u})} = \\log p_\\theta(\\mathbf{u}) - \\log q(\\mathbf{u}) $$ After converting logits into probabilities with sigmoid $\\sigma(.)$, we can apply cross entropy loss:\n $$ \\begin{aligned} \\mathcal{L}_\\text{NCE} \u0026= - \\frac{1}{N} \\sum_{i=1}^N \\big[ \\log \\sigma (\\ell_\\theta(\\mathbf{x}_i)) + \\log (1 - \\sigma (\\ell_\\theta(\\tilde{\\mathbf{x}}_i))) \\big] \\\\ \\text{ where }\\sigma(\\ell) \u0026= \\frac{1}{1 + \\exp(-\\ell)} = \\frac{p_\\theta}{p_\\theta + q} \\end{aligned} $$ Here I listed the original form of NCE loss which works with only one positive and one noise sample. In many follow-up works, contrastive loss incorporating multiple negative samples is also broadly referred to as NCE.\nInfoNCE The InfoNCE loss in CPC (Contrastive Predictive Coding; van den Oord, et al. 2018), inspired by NCE, uses categorical cross-entropy loss to identify the positive sample amongst a set of unrelated noise samples.\nGiven a context vector $\\mathbf{c}$, the positive sample should be drawn from the conditional distribution $p(\\mathbf{x} \\vert \\mathbf{c})$, while $N-1$ negative samples are drawn from the proposal distribution $p(\\mathbf{x})$, independent from the context $\\mathbf{c}$. For brevity, let us label all the samples as $X=\\{ \\mathbf{x}_i \\}^N_{i=1}$ among which only one of them $\\mathbf{x}_\\texttt{pos}$ is a positive sample. The probability of we detecting the positive sample correctly is:\n $$ p(C=\\texttt{pos} \\vert X, \\mathbf{c}) = \\frac{p(x_\\texttt{pos} \\vert \\mathbf{c}) \\prod_{i=1,\\dots,N; i \\neq \\texttt{pos}} p(\\mathbf{x}_i)}{\\sum_{j=1}^N \\big[ p(\\mathbf{x}_j \\vert \\mathbf{c}) \\prod_{i=1,\\dots,N; i \\neq j} p(\\mathbf{x}_i) \\big]} = \\frac{ \\frac{p(\\mathbf{x}_\\texttt{pos}\\vert c)}{p(\\mathbf{x}_\\texttt{pos})} }{ \\sum_{j=1}^N \\frac{p(\\mathbf{x}_j\\vert \\mathbf{c})}{p(\\mathbf{x}_j)} } = \\frac{f(\\mathbf{x}_\\texttt{pos}, \\mathbf{c})}{ \\sum_{j=1}^N f(\\mathbf{x}_j, \\mathbf{c}) } $$ where the scoring function is $f(\\mathbf{x}, \\mathbf{c}) \\propto \\frac{p(\\mathbf{x}\\vert\\mathbf{c})}{p(\\mathbf{x})}$.\nThe InfoNCE loss optimizes the negative log probability of classifying the positive sample correctly:\n $$ \\mathcal{L}_\\text{InfoNCE} = - \\mathbb{E} \\Big[\\log \\frac{f(\\mathbf{x}, \\mathbf{c})}{\\sum_{\\mathbf{x}' \\in X} f(\\mathbf{x}', \\mathbf{c})} \\Big] $$ The fact that $f(x, c)$ estimates the density ratio $\\frac{p(x\\vert c)}{p(x)}$ has a connection with mutual information optimization. To maximize the the mutual information between input $x$ and context vector $c$, we have:\n $$ I(\\mathbf{x}; \\mathbf{c}) = \\sum_{\\mathbf{x}, \\mathbf{c}} p(\\mathbf{x}, \\mathbf{c}) \\log\\frac{p(\\mathbf{x}, \\mathbf{c})}{p(\\mathbf{x})p(\\mathbf{c})} = \\sum_{\\mathbf{x}, \\mathbf{c}} p(\\mathbf{x}, \\mathbf{c})\\log\\color{blue}{\\frac{p(\\mathbf{x}|\\mathbf{c})}{p(\\mathbf{x})}} $$ where the logarithmic term in blue is estimated by $f$.\nFor sequence prediction tasks, rather than modeling the future observations $p_k(\\mathbf{x}_{t+k} \\vert \\mathbf{c}_t)$ directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between $\\mathbf{x}_{t+k}$ and $\\mathbf{c}_t$:\n $$ f_k(\\mathbf{x}_{t+k}, \\mathbf{c}_t) = \\exp(\\mathbf{z}_{t+k}^\\top \\mathbf{W}_k \\mathbf{c}_t) \\propto \\frac{p(\\mathbf{x}_{t+k}\\vert\\mathbf{c}_t)}{p(\\mathbf{x}_{t+k})} $$ where $\\mathbf{z}_{t+k}$ is the encoded input and $\\mathbf{W}_k$ is a trainable weight matrix.\nSoft-Nearest Neighbors Loss Soft-Nearest Neighbors Loss (Salakhutdinov \u0026amp; Hinton 2007, Frosst et al. 2019) extends it to include multiple positive samples.\nGiven a batch of samples, $\\{\\mathbf{x}_i, y_i)\\}^B_{i=1}$ where $y_i$ is the class label of $\\mathbf{x}_i$ and a function $f(.,.)$ for measuring similarity between two inputs, the soft nearest neighbor loss at temperature $\\tau$ is defined as:\n $$ \\mathcal{L}_\\text{snn} = -\\frac{1}{B}\\sum_{i=1}^B \\log \\frac{\\sum_{i\\neq j, y_i = y_j, j=1,\\dots,B} \\exp(- f(\\mathbf{x}_i, \\mathbf{x}_j) / \\tau)}{\\sum_{i\\neq k, k=1,\\dots,B} \\exp(- f(\\mathbf{x}_i, \\mathbf{x}_k) /\\tau)} $$ The temperature $\\tau$ is used for tuning how concentrated the features are in the representation space. For example, when at low temperature, the loss is dominated by the small distances and widely separated representations cannot contribute much and become irrelevant.\nCommon Setup We can loosen the definition of \u0026ldquo;classes\u0026rdquo; and \u0026ldquo;labels\u0026rdquo; in soft nearest-neighbor loss to create positive and negative sample pairs out of unsupervised data by, for example, applying data augmentation to create noise versions of original samples.\nMost recent studies follow the following definition of contrastive learning objective to incorporate multiple positive and negative samples. According to the setup in (Wang \u0026amp; Isola 2020), let $p_\\texttt{data}(.)$ be the data distribution over $\\mathbb{R}^n$ and $p_\\texttt{pos}(., .)$ be the distribution of positive pairs over $\\mathbb{R}^{n \\times n}$. These two distributions should satisfy:\n Symmetry: $\\forall \\mathbf{x}, \\mathbf{x}^+, p_\\texttt{pos}(\\mathbf{x}, \\mathbf{x}^+) = p_\\texttt{pos}(\\mathbf{x}^+, \\mathbf{x})$ Matching marginal: $\\forall \\mathbf{x}, \\int p_\\texttt{pos}(\\mathbf{x}, \\mathbf{x}^+) d\\mathbf{x}^+ = p_\\texttt{data}(\\mathbf{x})$ To learn an encoder $f(\\mathbf{x})$ to learn a L2-normalized feature vector, the contrastive learning objective is:\n $$ \\begin{aligned} \\mathcal{L}_\\text{contrastive} \u0026= \\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}, \\{\\mathbf{x}^-_i\\}^M_{i=1} \\overset{\\text{i.i.d}}{\\sim} p_\\texttt{data} } \\Big[ -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau)}{ \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau) + \\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)} \\Big] \u0026 \\\\ \u0026\\approx \\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}, \\{\\mathbf{x}^-_i\\}^M_{i=1} \\overset{\\text{i.i.d}}{\\sim} p_\\texttt{data} }\\Big[ - f(\\mathbf{x})^\\top f(\\mathbf{x}^+) / \\tau + \\log\\big(\\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)\\big) \\Big] \u0026 \\scriptstyle{\\text{; Assuming infinite negatives}} \\\\ \u0026= -\\frac{1}{\\tau}\\mathbb{E}_{(\\mathbf{x},\\mathbf{x}^+)\\sim p_\\texttt{pos}}f(\\mathbf{x})^\\top f(\\mathbf{x}^+) + \\mathbb{E}_{ \\mathbf{x} \\sim p_\\texttt{data}} \\Big[ \\log \\mathbb{E}_{\\mathbf{x}^- \\sim p_\\texttt{data}} \\big[ \\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}_i^-) / \\tau)\\big] \\Big] \u0026 \\end{aligned} $$ Key Ingredients Heavy Data Augmentation Given a training sample, data augmentation techniques are needed for creating noise versions of itself to feed into the loss as positive samples. Proper data augmentation setup is critical for learning good and generalizable embedding features. It introduces the non-essential variations into examples without modifying semantic meanings and thus encourages the model to learn the essential part of the representation. For example, experiments in SimCLR showed that the composition of random cropping and random color distortion is crucial for good performance on learning visual representation of images.\nLarge Batch Size Using a large batch size during training is another key ingredient in the success of many contrastive learning methods (e.g. SimCLR, CLIP), especially when it relies on in-batch negatives. Only when the batch size is big enough, the loss function can cover a diverse enough collection of negative samples, challenging enough for the model to learn meaningful representation to distinguish different examples.\nHard Negative Mining Hard negative samples should have different labels from the anchor sample, but have embedding features very close to the anchor embedding. With access to ground truth labels in supervised datasets, it is easy to identify task-specific hard negatives. For example when learning sentence embedding, we can treat sentence pairs labelled as \u0026ldquo;contradiction\u0026rdquo; in NLI datasets as hard negative pairs (e.g. SimCSE, or use top incorrect candidates returned by BM25 with most keywords matched as hard negative samples (DPR; Karpukhin et al., 2020).\nHowever, it becomes tricky to do hard negative mining when we want to remain unsupervised. Increasing training batch size or memory bank size implicitly introduces more hard negative samples, but it leads to a heavy burden of large memory usage as a side effect.\nChuang et al. (2020) studied the sampling bias in contrastive learning and proposed debiased loss. In the unsupervised setting, since we do not know the ground truth labels, we may accidentally sample false negative samples. Sampling bias can lead to significant performance drop.\nFig. 3. Sampling bias which refers to false negative samples in contrastive learning can lead to a big performance drop. (Image source: Chuang et al., 2020) Let us assume the probability of anchor class $c$ is uniform $\\rho(c)=\\eta^+$ and the probability of observing a different class is $\\eta^- = 1-\\eta^+$.\n The probability of observing a positive example for $\\mathbf{x}$ is $p^+_x(\\mathbf{x}')=p(\\mathbf{x}'\\vert \\mathbf{h}_{x'}=\\mathbf{h}_x)$; The probability of getting a negative sample for $\\mathbf{x}$ is $p^-_x(\\mathbf{x}')=p(\\mathbf{x}'\\vert \\mathbf{h}_{x'}\\neq\\mathbf{h}_x)$. When we are sampling $\\mathbf{x}^-$ , we cannot access the true $p^-_x(\\mathbf{x}^-)$ and thus $\\mathbf{x}^-$ may be sampled from the (undesired) anchor class $c$ with probability $\\eta^+$. The actual sampling data distribution becomes:\n $$ p(\\mathbf{x}') = \\eta^+ p^+_x(\\mathbf{x}') + \\eta^- p_x^-(\\mathbf{x}') $$ Thus we can use $p^-_x(\\mathbf{x}') = (p(\\mathbf{x}') - \\eta^+ p^+_x(\\mathbf{x}'))/\\eta^-$ for sampling $\\mathbf{x}^-$ to debias the loss. With $N$ samples $\\{\\mathbf{u}_i\\}^N_{i=1}$ from $p$ and $M$ samples $\\{ \\mathbf{v}_i \\}_{i=1}^M$ from $p^+_x$ , we can estimate the expectation of the second term $\\mathbb{E}_{\\mathbf{x}^-\\sim p^-_x}[\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$ in the denominator of contrastive learning loss:\n $$ g(\\mathbf{x}, \\{\\mathbf{u}_i\\}^N_{i=1}, \\{\\mathbf{v}_i\\}_{i=1}^M) = \\max\\Big\\{ \\frac{1}{\\eta^-}\\Big( \\frac{1}{N}\\sum_{i=1}^N \\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}_i)) - \\frac{\\eta^+}{M}\\sum_{i=1}^M \\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}_i)) \\Big), \\exp(-1/\\tau) \\Big\\} $$ where $\\tau$ is the temperature and $\\exp(-1/\\tau)$ is the theoretical lower bound of $\\mathbb{E}_{\\mathbf{x}^-\\sim p^-_x}[\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$.\nThe final debiased contrastive loss looks like:\n $$ \\mathcal{L}^{N,M}_\\text{debias}(f) = \\mathbb{E}_{\\mathbf{x},\\{\\mathbf{u}_i\\}^N_{i=1}\\sim p;\\;\\mathbf{x}^+, \\{\\mathbf{v}_i\\}_{i=1}^M\\sim p^+} \\Big[ -\\log\\frac{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+)}{\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^+) + N g(x,\\{\\mathbf{u}_i\\}^N_{i=1}, \\{\\mathbf{v}_i\\}_{i=1}^M)} \\Big] $$ Fig. 4. t-SNE visualization of learned representation with debiased contrastive learning. (Image source: Chuang et al., 2020) Following the above annotation, Robinson et al. (2021) modified the sampling probabilities to target at hard negatives by up-weighting the probability $p^-_x(x')$ to be proportional to its similarity to the anchor sample. The new sampling probability $q_\\beta(x^-)$ is:\n $$ q_\\beta(\\mathbf{x}^-) \\propto \\exp(\\beta f(\\mathbf{x})^\\top f(\\mathbf{x}^-)) \\cdot p(\\mathbf{x}^-) $$ where $\\beta$ is a hyperparameter to tune.\nWe can estimate the second term in the denominator $\\mathbb{E}_{\\mathbf{x}^- \\sim q_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{x}^-))]$ using importance sampling where both the partition functions $Z_\\beta, Z^+_\\beta$ can be estimated empirically.\n $$ \\begin{aligned} \\mathbb{E}_{\\mathbf{u} \\sim q_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}))] \u0026= \\mathbb{E}_{\\mathbf{u} \\sim p} [\\frac{q_\\beta}{p}\\exp(f(\\mathbf{x})^\\top f(\\mathbf{u}))] = \\mathbb{E}_{\\mathbf{u} \\sim p} [\\frac{1}{Z_\\beta}\\exp((\\beta + 1)f(\\mathbf{x})^\\top f(\\mathbf{u}))] \\\\ \\mathbb{E}_{\\mathbf{v} \\sim q^+_\\beta} [\\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}))] \u0026= \\mathbb{E}_{\\mathbf{v} \\sim p^+} [\\frac{q^+_\\beta}{p}\\exp(f(\\mathbf{x})^\\top f(\\mathbf{v}))] = \\mathbb{E}_{\\mathbf{v} \\sim p} [\\frac{1}{Z^+_\\beta}\\exp((\\beta + 1)f(\\mathbf{x})^\\top f(\\mathbf{v}))] \\end{aligned} $$ Fig. 5. Pseudo code for computing NCE loss, debiased contrastive loss, and hard negative sample objective when setting $M=1$. (Image source: Robinson et al., 2021 ) Vision: Image Embedding Image Augmentations Most approaches for contrastive representation learning in the vision domain rely on creating a noise version of a sample by applying a sequence of data augmentation techniques. The augmentation should significantly change its visual appearance but keep the semantic meaning unchanged.\nBasic Image Augmentation There are many ways to modify an image while retaining its semantic meaning. We can use any one of the following augmentation or a composition of multiple operations.\n Random cropping and then resize back to the original size. Random color distortions Random Gaussian blur Random color jittering Random horizontal flip Random grayscale conversion Multi-crop augmentation: Use two standard resolution crops and sample a set of additional low resolution crops that cover only small parts of the image. Using low resolution crops reduces the compute cost. (SwAV) And many more \u0026hellip; Augmentation Strategies Many frameworks are designed for learning good data augmentation strategies (i.e. a composition of multiple transforms). Here are a few common ones.\n AutoAugment (Cubuk, et al. 2018): Inspired by NAS, AutoAugment frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem and looks for the combination that leads to the highest accuracy on the evaluation set. RandAugment (Cubuk et al., 2019): RandAugment greatly reduces the search space of AutoAugment by controlling the magnitudes of different transformation operations with a single magnitude parameter. PBA (Population based augmentation; Ho et al., 2019): PBA combined PBT (Jaderberg et al, 2017) with AutoAugment, using the evolutionary algorithm to train a population of children models in parallel to evolve the best augmentation strategies. UDA (Unsupervised Data Augmentation; Xie et al., 2019): Among a set of possible augmentation strategies, UDA selects those to minimize the KL divergence between the predicted distribution over an unlabelled example and its unlabelled augmented version. Image Mixture Image mixture methods can construct new training examples from existing data points.\n Mixup (Zhang et al., 2018): It runs global-level mixture by creating a weighted pixel-wise combination of two existing images $I_1$ and $I_2$: $I_\\text{mixup} \\gets \\alpha I_1 + (1-\\alpha) I_2$ and $\\alpha \\in [0, 1]$. Cutmix (Yun et al., 2019): Cutmix does region-level mixture by generating a new example by combining a local region of one image with the rest of the other image. $I_\\text{cutmix} \\gets \\mathbf{M}_b \\odot I_1 + (1-\\mathbf{M}_b) \\odot I_2$, where $\\mathbf{M}_b \\in \\{0, 1\\}^I$ is a binary mask and $\\odot$ is element-wise multiplication. It is equivalent to filling the cutout (DeVries \u0026amp; Taylor 2017) region with the same region from another image. MoCHi (\u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo;; Kalantidis et al. 2020): Given a query $\\mathbf{q}$, MoCHi maintains a queue of $K$ negative features $Q=\\{\\mathbf{n}_1, \\dots, \\mathbf{n}_K \\}$ and sorts these negative features by similarity to the query, $\\mathbf{q}^\\top \\mathbf{n}$, in descending order. The first $N$ items in the queue are considered as the hardest negatives, $Q^N$. Then synthetic hard examples can be generated by $\\mathbf{h} = \\tilde{\\mathbf{h}} / |\\tilde{\\mathbf{h}}|$ where $\\tilde{\\mathbf{h}} = \\alpha\\mathbf{n}_i + (1-\\alpha) \\mathbf{n}_j$ and $\\alpha \\in (0, 1)$. Even harder examples can be created by mixing with the query feature, $\\mathbf{h}' = \\tilde{\\mathbf{h}'} / |\\tilde{\\mathbf{h}'}|_2$ where $\\tilde{\\mathbf{h}'} = \\beta\\mathbf{q} + (1-\\beta) \\mathbf{n}_j$ and $\\beta \\in (0, 0.5)$. Parallel Augmentation This category of approaches produce two noise versions of one anchor image and aim to learn representation such that these two augmented samples share the same embedding.\nSimCLR SimCLR (Chen et al, 2020) proposed a simple framework for contrastive learning of visual representations. It learns representations for visual inputs by maximizing agreement between differently augmented views of the same sample via a contrastive loss in the latent space.\nFig. 6. A simple framework for contrastive learning of visual representations. (Image source: Chen et al, 2020) Randomly sample a minibatch of $N$ samples and each sample is applied with two different data augmentation operations, resulting in $2N$ augmented samples in total. $$ \\tilde{\\mathbf{x}}_i = t(\\mathbf{x}),\\quad\\tilde{\\mathbf{x}}_j = t'(\\mathbf{x}),\\quad t, t' \\sim \\mathcal{T} $$ where two separate data augmentation operators, $t$ and $t'$, are sampled from the same family of augmentations $\\mathcal{T}$. Data augmentation includes random crop, resize with random flip, color distortions, and Gaussian blur.\nGiven one positive pair, other $2(N-1)$ data points are treated as negative samples. The representation is produced by a base encoder $f(.)$: $$ \\mathbf{h}_i = f(\\tilde{\\mathbf{x}}_i),\\quad \\mathbf{h}_j = f(\\tilde{\\mathbf{x}}_j) $$ The contrastive learning loss is defined using cosine similarity $\\text{sim}(.,.)$. Note that the loss operates on an extra projection layer of the representation $g(.)$ rather than on the representation space directly. But only the representation $\\mathbf{h}$ is used for downstream tasks. $$ \\begin{aligned} \\mathbf{z}_i \u0026= g(\\mathbf{h}_i),\\quad \\mathbf{z}_j = g(\\mathbf{h}_j) \\\\ \\mathcal{L}_\\text{SimCLR}^{(i,j)} \u0026= - \\log\\frac{\\exp(\\text{sim}(\\mathbf{z}_i, \\mathbf{z}_j) / \\tau)}{\\sum_{k=1}^{2N} \\mathbb{1}_{[k \\neq i]} \\exp(\\text{sim}(\\mathbf{z}_i, \\mathbf{z}_k) / \\tau)} \\end{aligned} $$ where $\\mathbb{1}_{[k \\neq i]}$ is an indicator function: 1 if $k\\neq i$ 0 otherwise.\nSimCLR needs a large batch size to incorporate enough negative samples to achieve good performance.\nFig. 7. The algorithm for SimCLR. (Image source: Chen et al, 2020). Barlow Twins Barlow Twins (Zbontar et al. 2021) feeds two distorted versions of samples into the same network to extract features and learns to make the cross-correlation matrix between these two groups of output features close to the identity. The goal is to keep the representation vectors of different distorted versions of one sample similar, while minimizing the redundancy between these vectors.\nFig. 8. Illustration of Barlow Twins learning pipeline. (Image source: Zbontar et al. 2021). Let $\\mathcal{C}$ be a cross-correlation matrix computed between outputs from two identical networks along the batch dimension. $\\mathcal{C}$ is a square matrix with the size same as the feature network\u0026rsquo;s output dimensionality. Each entry in the matrix $\\mathcal{C}_{ij}$ is the cosine similarity between network output vector dimension at index $i, j$ and batch index $b$, $\\mathbf{z}_{b,i}^A$ and $\\mathbf{z}_{b,j}^B$, with a value between -1 (i.e. perfect anti-correlation) and 1 (i.e. perfect correlation).\n $$ \\begin{aligned} \\mathcal{L}_\\text{BT} \u0026= \\underbrace{\\sum_i (1-\\mathcal{C}_{ii})^2}_\\text{invariance term} + \\lambda \\underbrace{\\sum_i\\sum_{i\\neq j} \\mathcal{C}_{ij}^2}_\\text{redundancy reduction term} \\\\ \\text{where } \\mathcal{C}_{ij} \u0026= \\frac{\\sum_b \\mathbf{z}^A_{b,i} \\mathbf{z}^B_{b,j}}{\\sqrt{\\sum_b (\\mathbf{z}^A_{b,i})^2}\\sqrt{\\sum_b (\\mathbf{z}^B_{b,j})^2}} \\end{aligned} $$ Barlow Twins is competitive with SOTA methods for self-supervised learning. It naturally avoids trivial constants (i.e. collapsed representations), and is robust to different training batch sizes.\nFig. 9. Algorithm of Barlow Twins in Pytorch style pseudo code. (Image source: Zbontar et al. 2021). BYOL Different from the above approaches, interestingly, BYOL (Bootstrap Your Own Latent; Grill, et al 2020) claims to achieve a new state-of-the-art results without using egative samples. It relies on two neural networks, referred to as online and target networks that interact and learn from each other. The target network (parameterized by $\\xi$) has the same architecture as the online one (parameterized by $\\theta$), but with polyak averaged weights, $\\xi \\leftarrow \\tau \\xi + (1-\\tau) \\theta$.\nThe goal is to learn a presentation $y$ that can be used in downstream tasks. The online network parameterized by $\\theta$ contains:\n An encoder $f_\\theta$; A projector $g_\\theta$; A predictor $q_\\theta$. The target network has the same network architecture, but with different parameter $\\xi$, updated by polyak averaging $\\theta$: $\\xi \\leftarrow \\tau \\xi + (1-\\tau) \\theta$.\nFig. 10. The model architecture of BYOL. After training, we only care about $f\\_\\theta$ for producing representation, $y=f\\_\\theta(x)$, and everything else is discarded. $\\text{sg}$ means stop gradient. (Image source: Grill, et al 2020) Given an image $\\mathbf{x}$, the BYOL loss is constructed as follows:\n Create two augmented views: $\\mathbf{v}=t(\\mathbf{x}); \\mathbf{v}'=t'(\\mathbf{x})$ with augmentations sampled $t \\sim \\mathcal{T}, t' \\sim \\mathcal{T}'$; Then they are encoded into representations, $\\mathbf{y}_\\theta=f_\\theta(\\mathbf{v}), \\mathbf{y}'=f_\\xi(\\mathbf{v}')$; Then they are projected into latent variables, $\\mathbf{z}_\\theta=g_\\theta(\\mathbf{y}_\\theta), \\mathbf{z}'=g_\\xi(\\mathbf{y}')$; The online network outputs a prediction $q_\\theta(\\mathbf{z}_\\theta)$; Both $q_\\theta(\\mathbf{z}_\\theta)$ and $\\mathbf{z}'$ are L2-normalized, giving us $\\bar{q}_\\theta(\\mathbf{z}_\\theta) = q_\\theta(\\mathbf{z}_\\theta) / | q_\\theta(\\mathbf{z}_\\theta) |$ and $\\bar{\\mathbf{z}'} = \\mathbf{z}' / |\\mathbf{z}'|$; The loss $\\mathcal{L}^\\text{BYOL}_\\theta$ is MSE between L2-normalized prediction $\\bar{q}_\\theta(\\mathbf{z})$ and $\\bar{\\mathbf{z}'}$; The other symmetric loss $\\tilde{\\mathcal{L}}^\\text{BYOL}_\\theta$ can be generated by switching $\\mathbf{v}'$ and $\\mathbf{v}$; that is, feeding $\\mathbf{v}'$ to online network and $\\mathbf{v}$ to target network. The final loss is $\\mathcal{L}^\\text{BYOL}_\\theta + \\tilde{\\mathcal{L}}^\\text{BYOL}_\\theta$ and only parameters $\\theta$ are optimized. Unlike most popular contrastive learning based approaches, BYOL does not use negative pairs. Most bootstrapping approaches rely on pseudo-labels or cluster indices, but BYOL directly boostrapps the latent representation.\nIt is quite interesting and surprising that without negative samples, BYOL still works well. Later I ran into this post by Abe Fetterman \u0026amp; Josh Albrecht, they highlighted two surprising findings while they were trying to reproduce BYOL:\n BYOL generally performs no better than random when batch normalization is removed. The presence of batch normalization implicitly causes a form of contrastive learning. They believe that using negative samples is important for avoiding model collapse (i.e. what if you use all-zeros representation for every data point?). Batch normalization injects dependency on negative samples inexplicitly because no matter how similar a batch of inputs are, the values are re-distributed (spread out $\\sim \\mathcal{N}(0, 1$) and therefore batch normalization prevents model collapse. Strongly recommend you to read the full article if you are working in this area. Memory Bank Computing embeddings for a large number of negative samples in every batch is extremely expensive. One common approach is to store the representation in memory to trade off data staleness for cheaper compute.\nInstance Discrimination with Memoy Bank Instance contrastive learning (Wu et al, 2018) pushes the class-wise supervision to the extreme by considering each instance as a distinct class of its own. It implies that the number of \u0026ldquo;classes\u0026rdquo; will be the same as the number of samples in the training dataset. Hence, it is unfeasible to train a softmax layer with these many heads, but instead it can be approximated by NCE.\nFig. 11. The training pipeline of instance-level contrastive learning. The learned embedding is L2-normalized. (Image source: Wu et al, 2018) Let $\\mathbf{v} = f_\\theta(x)$ be an embedding function to learn and the vector is normalized to have $|\\mathbf{v}|=1$. A non-parametric classifier predicts the probability of a sample $\\mathbf{v}$ belonging to class $i$ with a temperature parameter $\\tau$:\n $$ P(C=i\\vert \\mathbf{v}) = \\frac{\\exp(\\mathbf{v}_i^\\top \\mathbf{v} / \\tau)}{\\sum_{j=1}^n \\exp(\\mathbf{v}_j^\\top \\mathbf{v} / \\tau)} $$ Instead of computing the representations for all the samples every time, they implement an Memory Bank for storing sample representation in the database from past iterations. Let $V=\\{ \\mathbf{v}_i \\}$ be the memory bank and $\\mathbf{f}_i = f_\\theta(\\mathbf{x}_i)$ be the feature generated by forwarding the network. We can use the representation from the memory bank $\\mathbf{v}_i$ instead of the feature forwarded from the network $\\mathbf{f}_i$ when comparing pairwise similarity.\nThe denominator theoretically requires access to the representations of all the samples, but that is too expensive in practice. Instead we can estimate it via Monte Carlo approximation using a random subset of $M$ indices $\\{j_k\\}_{k=1}^M$.\n $$ P(i\\vert \\mathbf{v}) = \\frac{\\exp(\\mathbf{v}^\\top \\mathbf{f}_i / \\tau)}{\\sum_{j=1}^N \\exp(\\mathbf{v}_j^\\top \\mathbf{f}_i / \\tau)} \\simeq \\frac{\\exp(\\mathbf{v}^\\top \\mathbf{f}_i / \\tau)}{\\frac{N}{M} \\sum_{k=1}^M \\exp(\\mathbf{v}_{j_k}^\\top \\mathbf{f}_i / \\tau)} $$ Because there is only one instance per class, the training is unstable and fluctuates a lot. To improve the training smoothness, they introduced an extra term for positive samples in the loss function based on the proximal optimization method. The final NCE loss objective looks like:\n $$ \\begin{aligned} \\mathcal{L}_\\text{instance} \u0026= - \\mathbb{E}_{P_d}\\big[\\log h(i, \\mathbf{v}^{(t-1)}_i) - \\lambda \\|\\mathbf{v}^{(t)}_i - \\mathbf{v}^{(t-1)}_i\\|^2_2\\big] - M\\mathbb{E}_{P_n}\\big[\\log(1 - h(i, \\mathbf{v}'^{(t-1)})\\big] \\\\ h(i, \\mathbf{v}) \u0026= \\frac{P(i\\vert\\mathbf{v})}{P(i\\vert\\mathbf{v}) + MP_n(i)} \\text{ where the noise distribution is uniform }P_n = 1/N \\end{aligned} $$ where $\\{ \\mathbf{v}^{(t-1)} \\}$ are embeddings stored in the memory bank from the previous iteration. The difference between iterations $|\\mathbf{v}^{(t)}_i - \\mathbf{v}^{(t-1)}_i|^2_2$ will gradually vanish as the learned embedding converges.\nMoCo \u0026amp; MoCo-V2 Momentum Contrast (MoCo; He et al, 2019) provides a framework of unsupervised learning visual representation as a dynamic dictionary look-up. The dictionary is structured as a large FIFO queue of encoded representations of data samples.\nGiven a query sample $\\mathbf{x}_q$, we get a query representation through an encoder $\\mathbf{q} = f_q(\\mathbf{x}_q)$. A list of key representations $\\{\\mathbf{k}_1, \\mathbf{k}_2, \\dots \\}$ in the dictionary are encoded by a momentum encoder $\\mathbf{k}_i = f_k (\\mathbf{x}^k_i)$. Let\u0026rsquo;s assume among them there is a single positive key $\\mathbf{k}^+$ in the dictionary that matches $\\mathbf{q}$. In the paper, they create $\\mathbf{k}^+$ using a noise copy of $\\mathbf{x}_q$ with different augmentation. Then the InfoNCE contrastive loss with temperature $\\tau$ is used over one positive and $N-1$ negative samples:\n $$ \\mathcal{L}_\\text{MoCo} = - \\log \\frac{\\exp(\\mathbf{q} \\cdot \\mathbf{k}^+ / \\tau)}{\\sum_{i=1}^N \\exp(\\mathbf{q} \\cdot \\mathbf{k}_i / \\tau)} $$ Compared to the memory bank, a queue-based dictionary in MoCo enables us to reuse representations of immediately preceding mini-batches of data.\nThe MoCo dictionary is not differentiable as a queue, so we cannot rely on back-propagation to update the key encoder $f_k$. One naive way might be to use the same encoder for both $f_q$ and $f_k$. Differently, MoCo proposed to use a momentum-based update with a momentum coefficient $m \\in [0, 1)$. Say, the parameters of $f_q$ and $f_k$ are labeled as $\\theta_q$ and $\\theta_k$, respectively.\n $$ \\theta_k \\leftarrow m \\theta_k + (1-m) \\theta_q $$ Fig. 12. Illustration of how Momentum Contrast (MoCo) learns visual representations. (Image source: He et al, 2019) The advantage of MoCo compared to SimCLR is that MoCo decouples the batch size from the number of negatives, but SimCLR requires a large batch size in order to have enough negative samples and suffers performance drops when their batch size is reduced.\nTwo designs in SimCLR, namely, (1) an MLP projection head and (2) stronger data augmentation, are proved to be very efficient. MoCo V2 (Chen et al, 2020) combined these two designs, achieving even better transfer performance with no dependency on a very large batch size.\nCURL CURL (Srinivas, et al. 2020) applies the above ideas in Reinforcement Learning. It learns a visual representation for RL tasks by matching embeddings of two data-augmented versions, $o_q$ and $o_k$, of the raw observation $o$ via contrastive loss. CURL primarily relies on random crop data augmentation. The key encoder is implemented as a momentum encoder with weights as EMA of the query encoder weights, same as in MoCo.\nOne significant difference between RL and supervised visual tasks is that RL depends on temporal consistency between consecutive frames. Therefore, CURL applies augmentation consistently on each stack of frames to retain information about the temporal structure of the observation.\nFig. 13. The architecture of CURL. (Image source: Srinivas, et al. 2020) Feature Clustering DeepCluster DeepCluster (Caron et al. 2018) iteratively clusters features via k-means and uses cluster assignments as pseudo labels to provide supervised signals.\nFig. 14. Illustration of DeepCluster method which iteratively clusters deep features and uses the cluster assignments as pseudo-labels. (Image source: Caron et al. 2018) In each iteration, DeepCluster clusters data points using the prior representation and then produces the new cluster assignments as the classification targets for the new representation. However this iterative process is prone to trivial solutions. While avoiding the use of negative pairs, it requires a costly clustering phase and specific precautions to avoid collapsing to trivial solutions.\nSwAV SwAV (Swapping Assignments between multiple Views; Caron et al. 2020) is an online contrastive learning algorithm. It computes a code from an augmented version of the image and tries to predict this code using another augmented version of the same image.\nFig. 15. Comparison of SwAV and [contrastive instance learning](#instance-discrimination-with-memoy-bank). (Image source: Caron et al. 2020) Given features of images with two different augmentations, $\\mathbf{z}_t$ and $\\mathbf{z}_s$, SwAV computes corresponding codes $\\mathbf{q}_t$ and $\\mathbf{q}_s$ and the loss quantifies the fit by swapping two codes using $\\ell(.)$ to measure the fit between a feature and a code.\n $$ \\mathcal{L}_\\text{SwAV}(\\mathbf{z}_t, \\mathbf{z}_s) = \\ell(\\mathbf{z}_t, \\mathbf{q}_s) + \\ell(\\mathbf{z}_s, \\mathbf{q}_t) $$ The swapped fit prediction depends on the cross entropy between the predicted code and a set of $K$ trainable prototype vectors $\\mathbf{C} = \\{\\mathbf{c}_1, \\dots, \\mathbf{c}_K\\}$. The prototype vector matrix is shared across different batches and represents anchor clusters that each instance should be clustered to.\n $$ \\ell(\\mathbf{z}_t, \\mathbf{q}_s) = - \\sum_k \\mathbf{q}^{(k)}_s\\log\\mathbf{p}^{(k)}_t \\text{ where } \\mathbf{p}^{(k)}_t = \\frac{\\exp(\\mathbf{z}_t^\\top\\mathbf{c}_k / \\tau)}{\\sum_{k'}\\exp(\\mathbf{z}_t^\\top \\mathbf{c}_{k'} / \\tau)} $$ In a mini-batch containing $B$ feature vectors $\\mathbf{Z} = [\\mathbf{z}_1, \\dots, \\mathbf{z}_B]$, the mapping matrix between features and prototype vectors is defined as $\\mathbf{Q} = [\\mathbf{q}_1, \\dots, \\mathbf{q}_B] \\in \\mathbb{R}_+^{K\\times B}$. We would like to maximize the similarity between the features and the prototypes:\n $$ \\begin{aligned} \\max_{\\mathbf{Q}\\in\\mathcal{Q}} \u0026\\text{Tr}(\\mathbf{Q}^\\top \\mathbf{C}^\\top \\mathbf{Z}) + \\varepsilon \\mathcal{H}(\\mathbf{Q}) \\\\ \\text{where }\\mathcal{Q} \u0026= \\big\\{ \\mathbf{Q} \\in \\mathbb{R}_{+}^{K \\times B} \\mid \\mathbf{Q}\\mathbf{1}_B = \\frac{1}{K}\\mathbf{1}_K, \\mathbf{Q}^\\top\\mathbf{1}_K = \\frac{1}{B}\\mathbf{1}_B \\big\\} \\end{aligned} $$ where $\\mathcal{H}$ is the entropy, $\\mathcal{H}(\\mathbf{Q}) = - \\sum_{ij} \\mathbf{Q}_{ij} \\log \\mathbf{Q}_{ij}$, controlling the smoothness of the code. The coefficient $\\epsilon$ should not be too large; otherwise, all the samples will be assigned uniformly to all the clusters. The candidate set of solutions for $\\mathbf{Q}$ requires every mapping matrix to have each row sum up to $1/K$ and each column to sum up to $1/B$, enforcing that each prototype gets selected at least $B/K$ times on average.\nSwAV relies on the iterative Sinkhorn-Knopp algorithm (Cuturi 2013) to find the solution for $\\mathbf{Q}$.\nWorking with Supervised Datasets CLIP CLIP (Contrastive Language-Image Pre-training; Radford et al. 2021) jointly trains a text encoder and an image feature extractor over the pretraining task that predicts which caption goes with which image.\nFig. 16. Illustration of CLIP contrastive pre-training over text-image pairs. (Image source: Radford et al. 2021) Given a batch of $N$ (image, text) pairs, CLIP computes the dense cosine similarity matrix between all $N\\times N$ possible (image, text) candidates within this batch. The text and image encoders are jointly trained to maximize the similarity between $N$ correct pairs of (image, text) associations while minimizing the similarity for $N(N-1)$ incorrect pairs via a symmetric cross entropy loss over the dense matrix.\nSee the numy-like pseudo code for CLIP in Fig. 17.\nFig. 17. CLIP algorithm in Numpy style pseudo code. (Image source: Radford et al. 2021) Compared to other methods above for learning good visual representation, what makes CLIP really special is \u0026ldquo;the appreciation of using natural language as a training signal\u0026rdquo;. It does demand access to supervised dataset in which we know which text matches which image. It is trained on 400 million (text, image) pairs, collected from the Internet. The query list contains all the words occurring at least 100 times in the English version of Wikipedia. Interestingly, they found that Transformer-based language models are 3x slower than a bag-of-words (BoW) text encoder at zero-shot ImageNet classification. Using contrastive objective instead of trying to predict the exact words associated with images (i.e. a method commonly adopted by image caption prediction tasks) can further improve the data efficiency another 4x.\nFig. 18. Using bag-of-words text encoding and contrastive training objectives can bring in multiple folds of data efficiency improvement. (Image source: Radford et al. 2021) CLIP produces good visual representation that can non-trivially transfer to many CV benchmark datasets, achieving results competitive with supervised baseline. Among tested transfer tasks, CLIP struggles with very fine-grained classification, as well as abstract or systematic tasks such as counting the number of objects. The transfer performance of CLIP models is smoothly correlated with the amount of model compute.\nSupervised Contrastive Learning There are several known issues with cross entropy loss, such as the lack of robustness to noisy labels and the possibility of poor margins. Existing improvement for cross entropy loss involves the curation of better training data, such as label smoothing and data augmentation. Supervised Contrastive Loss (Khosla et al. 2021) aims to leverage label information more effectively than cross entropy, imposing that normalized embeddings from the same class are closer together than embeddings from different classes.\nFig. 19. Supervised vs self-supervised contrastive losses. Supervised contrastive learning considers different samples from the same class as positive examples, in addition to augmented versions. (Image source: Khosla et al. 2021) Given a set of randomly sampled $n$ (image, label) pairs, $\\{\\mathbf{x}_i, y_i\\}_{i=1}^n$, $2n$ training pairs can be created by applying two random augmentations of every sample, $\\{\\tilde{\\mathbf{x}}_i, \\tilde{y}_i\\}_{i=1}^{2n}$.\nSupervised contrastive loss $\\mathcal{L}_\\text{supcon}$ utilizes multiple positive and negative samples, very similar to soft nearest-neighbor loss:\n $$ \\mathcal{L}_\\text{supcon} = - \\sum_{i=1}^{2n} \\frac{1}{2 \\vert N_i \\vert - 1} \\sum_{j \\in N(y_i), j \\neq i} \\log \\frac{\\exp(\\mathbf{z}_i \\cdot \\mathbf{z}_j / \\tau)}{\\sum_{k \\in I, k \\neq i}\\exp({\\mathbf{z}_i \\cdot \\mathbf{z}_k / \\tau})} $$ where $\\mathbf{z}_k=P(E(\\tilde{\\mathbf{x}_k}))$, in which $E(.)$ is an encoder network (augmented image mapped to vector) $P(.)$ is a projection network (one vector mapped to another). $N_i= \\{j \\in I: \\tilde{y}_j = \\tilde{y}_i \\}$ contains a set of indices of samples with label $y_i$. Including more positive samples into the set $N_i$ leads to improved results.\nAccording to their experiments, supervised contrastive loss:\n does outperform the base cross entropy, but only by a small amount. outperforms the cross entropy on robustness benchmark (ImageNet-C, which applies common naturally occuring perturbations such as noise, blur and contrast changes to the ImageNet dataset). is less sensitive to hyperparameter changes. Language: Sentence Embedding In this section, we focus on how to learn sentence embedding.\nText Augmentation Most contrastive methods in vision applications depend on creating an augmented version of each image. However, it is more challenging to construct text augmentation which does not alter the semantics of a sentence. In this section we look into three approaches for augmenting text sequences, including lexical edits, back-translation and applying cutoff or dropout.\nLexical Edits EDA (Easy Data Augmentation; Wei \u0026amp; Zou 2019) defines a set of simple but powerful operations for text augmentation. Given a sentence, EDA randomly chooses and applies one of four simple operations:\n Synonym replacement (SR): Replace $n$ random non-stop words with their synonyms. Random insertion (RI): Place a random synonym of a randomly selected non-stop word in the sentence at a random position. Random swap (RS): Randomly swap two words and repeat $n$ times. Random deletion (RD): Randomly delete each word in the sentence with probability $p$. where $p=\\alpha$ and $n=\\alpha \\times \\text{sentence_length}$, with the intuition that longer sentences can absorb more noise while maintaining the original label. The hyperparameter $\\alpha$ roughly indicates the percent of words in one sentence that may be changed by one augmentation.\nEDA is shown to improve the classification accuracy on several classification benchmark datasets compared to baseline without EDA. The performance lift is more significant on a smaller training set. All the four operations in EDA help improve the classification accuracy, but get to optimal at different $\\alpha$\u0026rsquo;s.\nFig. 20. EDA leads to performance improvement on several classification benchmarks. (Image source: Wei \u0026 Zou 2019) In Contextual Augmentation (Sosuke Kobayashi, 2018), new substitutes for word $w_i$ at position $i$ can be smoothly sampled from a given probability distribution, $p(.\\mid S\\setminus\\{w_i\\})$, which is predicted by a bidirectional LM like BERT.\nBack-translation CERT (Contrastive self-supervised Encoder Representations from Transformers; Fang et al. (2020); code) generates augmented sentences via back-translation. Various translation models for different languages can be employed for creating different versions of augmentations. Once we have a noise version of text samples, many contrastive learning frameworks introduced above, such as MoCo, can be used to learn sentence embedding.\nDropout and Cutoff Shen et al. (2020) proposed to apply Cutoff to text augmentation, inspired by cross-view training. They proposed three cutoff augmentation strategies:\n Token cutoff removes the information of a few selected tokens. To make sure there is no data leakage, corresponding tokens in the input, positional and other relevant embedding matrices should all be zeroed out., Feature cutoff removes a few feature columns. Span cutoff removes a continuous chunk of texts. Fig. 21. Schematic illustration of token, feature and span cutoff augmentation strategies. (Image source: Shen et al. 2020) Multiple augmented versions of one sample can be created. When training, Shen et al. (2020) applied an additional KL-divergence term to measure the consensus between predictions from different augmented samples.\nSimCSE (Gao et al. 2021; code) learns from unsupervised data by predicting a sentence from itself with only dropout noise. In other words, they treat dropout as data augmentation for text sequences. A sample is simply fed into the encoder twice with different dropout masks and these two versions are the positive pair where the other in-batch samples are considered as negative pairs. It feels quite similar to the cutoff augmentation, but dropout is more flexible with less well-defined semantic meaning of what content can be masked off.\nFig. 22. SimCSE creates augmented samples by applying different dropout masks. The supervised version leverages NLI datasets to predict positive (entailment) or negative (contradiction) given a pair of sentences. (Image source: Gao et al. 2021) They ran experiments on 7 STS (Semantic Text Similarity) datasets and computed cosine similarity between sentence embeddings. They also tried out an optional MLM auxiliary objective loss to help avoid catastrophic forgetting of token-level knowledge. This aux loss was found to help improve performance on transfer tasks, but a consistent drop on the main STS tasks.\nFig. 23. Experiment numbers on a collection of STS benchmarks with SimCES. (Image source: Gao et al. 2021) Supervision from NLI The pre-trained BERT sentence embedding without any fine-tuning has been found to have poor performance for semantic similarity tasks. Instead of using the raw embeddings directly, we need to refine the embedding with further fine-tuning.\nNatural Language Inference (NLI) tasks are the main data sources to provide supervised signals for learning sentence embedding; such as SNLI, MNLI, and QQP.\nSentence-BERT SBERT (Sentence-BERT) (Reimers \u0026amp; Gurevych, 2019) relies on siamese and triplet network architectures to learn sentence embeddings such that the sentence similarity can be estimated by cosine similarity between pairs of embeddings. Note that learning SBERT depends on supervised data, as it is fine-tuned on several NLI datasets.\nThey experimented with a few different prediction heads on top of BERT model:\n Softmax classification objective: The classification head of the siamese network is built on the concatenation of two embeddings $f(\\mathbf{x}), f(\\mathbf{x}')$ and $\\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert$. The predicted output is $\\hat{y}=\\text{softmax}(\\mathbf{W}_t [f(\\mathbf{x}); f(\\mathbf{x}'); \\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert])$. They showed that the most important component is the element-wise difference $\\vert f(\\mathbf{x}) - f(\\mathbf{x}') \\vert$. Regression objective: This is the regression loss on $\\cos(f(\\mathbf{x}), f(\\mathbf{x}'))$, in which the pooling strategy has a big impact. In the experiments, they observed that max performs much worse than mean and CLS-token. Triplet objective: $\\max(0, |f(\\mathbf{x}) - f(\\mathbf{x}^+)|- |f(\\mathbf{x}) - f(\\mathbf{x}^-)| + \\epsilon)$, where $\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-$ are embeddings of the anchor, positive and negative sentences. In the experiments, which objective function works the best depends on the datasets, so there is no universal winner.\nFig. 24. Illustration of Sentence-BERT training framework with softmax classification head and regression head. (Image source: Reimers \u0026 Gurevych, 2019) The SentEval library (Conneau and Kiela, 2018) is commonly used for evaluating the quality of learned sentence embedding. SBERT outperformed other baselines at that time (Aug 2019) on 5 out of 7 tasks.\nFig. 25. The performance of Sentence-BERT on the SentEval benchmark. (Image source: Reimers \u0026 Gurevych, 2019) BERT-flow The embedding representation space is deemed isotropic if embeddings are uniformly distributed on each dimension; otherwise, it is anisotropic. Li et al, (2020) showed that a pre-trained BERT learns a non-smooth anisotropic semantic space of sentence embeddings and thus leads to poor performance for text similarity tasks without fine-tuning. Empirically, they observed two issues with BERT sentence embedding: Word frequency biases the embedding space. High-frequency words are close to the origin, but low-frequency ones are far away from the origin. Low-frequency words scatter sparsely. The embeddings of low-frequency words tend to be farther to their $k$-NN neighbors, while the embeddings of high-frequency words concentrate more densely.\nBERT-flow (Li et al, 2020; code) was proposed to transform the embedding to a smooth and isotropic Gaussian distribution via normalizing flows.\nFig. 26. Illustration of the flow-based calibration over the original sentence embedding space in BERT-flow. (Image source: Li et al, 2020) Let $\\mathcal{U}$ be the observed BERT sentence embedding space and $\\mathcal{Z}$ be the desired latent space which is a standard Gaussian. Thus, $p_\\mathcal{Z}$ is a Gaussian density function and $f_\\phi: \\mathcal{Z}\\to\\mathcal{U}$ is an invertible transformation:\n $$ \\mathbf{z}\\sim p_\\mathcal{Z}(\\mathbf{z}) \\quad \\mathbf{u}=f_\\phi(\\mathbf{z}) \\quad \\mathbf{z}=f^{-1}_\\phi(\\mathbf{u}) $$ A flow-based generative model learns the invertible mapping function by maximizing the likelihood of $\\mathcal{U}$\u0026rsquo;s marginal:\n $$ \\max_\\phi\\mathbb{E}_{\\mathbf{u}=\\text{BERT}(s), s\\sim\\mathcal{D}} \\Big[ \\log p_\\mathcal{Z}(f^{-1}_\\phi(\\mathbf{u})) + \\log\\big\\vert\\det\\frac{\\partial f^{-1}_\\phi(\\mathbf{u})}{\\partial\\mathbf{u}}\\big\\vert \\Big] $$ where $s$ is a sentence sampled from the text corpus $\\mathcal{D}$. Only the flow parameters $\\phi$ are optimized while parameters in the pretrained BERT stay unchanged.\nBERT-flow was shown to improve the performance on most STS tasks either with or without supervision from NLI datasets. Because learning normalizing flows for calibration does not require labels, it can utilize the entire dataset including validation and test sets.\nWhitening Operation Su et al. (2021) applied whitening operation to improve the isotropy of the learned representation and also to reduce the dimensionality of sentence embedding.\nThey transform the mean value of the sentence vectors to 0 and the covariance matrix to the identity matrix. Given a set of samples $\\{\\mathbf{x}_i\\}_{i=1}^N$, let $\\tilde{\\mathbf{x}}_i$ and $\\tilde{\\Sigma}$ be the transformed samples and corresponding covariance matrix:\n $$ \\begin{aligned} \\mu \u0026= \\frac{1}{N}\\sum_{i=1}^N \\mathbf{x}_i \\quad \\Sigma = \\frac{1}{N}\\sum_{i=1}^N (\\mathbf{x}_i - \\mu)^\\top (\\mathbf{x}_i - \\mu) \\\\ \\tilde{\\mathbf{x}}_i \u0026= (\\mathbf{x}_i - \\mu)W \\quad \\tilde{\\Sigma} = W^\\top\\Sigma W = I \\text{ thus } \\Sigma = (W^{-1})^\\top W^{-1} \\end{aligned} $$ If we get SVD decomposition of $\\Sigma = U\\Lambda U^\\top$, we will have $W^{-1}=\\sqrt{\\Lambda} U^\\top$ and $W=U\\sqrt{\\Lambda^{-1}}$. Note that within SVD, $U$ is an orthogonal matrix with column vectors as eigenvectors and $\\Lambda$ is a diagonal matrix with all positive elements as sorted eigenvalues.\nA dimensionality reduction strategy can be applied by only taking the first $k$ columns of $W$, named Whitening-$k$.\nFig. 27. Pseudo code of the whitening-$k$ operation. (Image source: Su et al. 2021) Whitening operations were shown to outperform BERT-flow and achieve SOTA with 256 sentence dimensionality on many STS benchmarks, either with or without NLI supervision.\nUnsupervised Sentence Embedding Learning Context Prediction Quick-Thought (QT) vectors (Logeswaran \u0026amp; Lee, 2018) formulate sentence representation learning as a classification problem: Given a sentence and its context, a classifier distinguishes context sentences from other contrastive sentences based on their vector representations (\u0026ldquo;cloze test\u0026rdquo;). Such a formulation removes the softmax output layer which causes training slowdown.\nFig. 28. Illustration of how Quick-Thought sentence embedding vectors are learned. (Image source: Logeswaran \u0026 Lee, 2018) Let $f(.)$ and $g(.)$ be two functions that encode a sentence $s$ into a fixed-length vector. Let $C(s)$ be the set of sentences in the context of $s$ and $S(s)$ be the set of candidate sentences including only one sentence $s_c \\in C(s)$ and many other non-context negative sentences. Quick Thoughts model learns to optimize the probability of predicting the only true context sentence $s_c \\in S(s)$. It is essentially NCE loss when considering the sentence $(s, s_c)$ as the positive pairs while other pairs $(s, s')$ where $s' \\in S(s), s'\\neq s_c$ as negatives.\n $$ \\mathcal{L}_\\text{QT} = - \\sum_{s \\in \\mathcal{D}} \\sum_{s_c \\in C(s)} \\log p(s_c \\vert s, S(s)) = - \\sum_{s \\in \\mathcal{D}} \\sum_{s_c \\in C(s)}\\frac{\\exp(f(s)^\\top g(s_c))}{\\sum_{s'\\in S(s)} \\exp(f(s)^\\top g(s'))} $$ Mutual Information Maximization IS-BERT (Info-Sentence BERT) (Zhang et al. 2020; code) adopts a self-supervised learning objective based on mutual information maximization to learn good sentence embeddings in the unsupervised manners.\nFig. 29. Illustration of Info-Sentence BERT. (Image source: Zhang et al. 2020) IS-BERT works as follows:\n Use BERT to encode an input sentence $s$ to a token embedding of length $l$, $\\mathbf{h}_{1:l}$.\n Then apply 1-D conv net with different kernel sizes (e.g. 1, 3, 5) to process the token embedding sequence to capture the n-gram local contextual dependencies: $\\mathbf{c}_i = \\text{ReLU}(\\mathbf{w} \\cdot \\mathbf{h}_{i:i+k-1} + \\mathbf{b})$. The output sequences are padded to stay the same sizes of the inputs.\n The final local representation of the $i$-th token $\\mathcal{F}_\\theta^{(i)} (\\mathbf{x})$ is the concatenation of representations of different kernel sizes.\n The global sentence representation $\\mathcal{E}_\\theta(\\mathbf{x})$ is computed by applying a mean-over-time pooling layer on the token representations $\\mathcal{F}_\\theta(\\mathbf{x}) = \\{\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}) \\in \\mathbb{R}^d\\}_{i=1}^l$.\n Since the mutual information estimation is generally intractable for continuous and high-dimensional random variables, IS-BERT relies on the Jensen-Shannon estimator (Nowozin et al., 2016, Hjelm et al., 2019) to maximize the mutual information between $\\mathcal{E}_\\theta(\\mathbf{x})$ and $\\mathcal{F}_\\theta^{(i)} (\\mathbf{x})$.\n $$ I^\\text{JSD}_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}); \\mathcal{E}_\\theta(\\mathbf{x})) = \\mathbb{E}_{\\mathbf{x}\\sim P} [-\\text{sp}(-T_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}); \\mathcal{E}_\\theta(\\mathbf{x})))] \\\\ - \\mathbb{E}_{\\mathbf{x}\\sim P, \\mathbf{x}' \\sim\\tilde{P}} [\\text{sp}(T_\\omega(\\mathcal{F}_\\theta^{(i)} (\\mathbf{x}'); \\mathcal{E}_\\theta(\\mathbf{x})))] $$ where $T_\\omega: \\mathcal{F}\\times\\mathcal{E} \\to \\mathbb{R}$ is a learnable network with parameters $\\omega$, generating discriminator scores. The negative sample $\\mathbf{x}'$ is sampled from the distribution $\\tilde{P}=P$. And $\\text{sp}(x)=\\log(1+e^x)$ is the softplus activation function.\nThe unsupervised numbers on SentEval with IS-BERT outperforms most of the unsupervised baselines (Sep 2020), but unsurprisingly weaker than supervised runs. When using labelled NLI datasets, IS-BERT produces results comparable with SBERT (See Fig. 25 \u0026amp; 30).\nFig. 30. The performance of IS-BERT on the SentEval benchmark. (Image source: Zhang et al. 2020) Citation Cited as:\n Weng, Lilian. (May 2021). Contrastive representation learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-05-31-contrastive/.\n Or\n@article{weng2021contrastive, title = \u0026quot;Contrastive Representation Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;May\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-05-31-contrastive/\u0026quot; } References [1] Sumit Chopra, Raia Hadsell and Yann LeCun. \u0026ldquo;Learning a similarity metric discriminatively, with application to face verification.\u0026quot; CVPR 2005.\n[2] Florian Schroff, Dmitry Kalenichenko and James Philbin. \u0026ldquo;FaceNet: A Unified Embedding for Face Recognition and Clustering.\u0026quot; CVPR 2015.\n[3] Hyun Oh Song et al. \u0026ldquo;Deep Metric Learning via Lifted Structured Feature Embedding.\u0026quot; CVPR 2016. [code]\n[4] Ruslan Salakhutdinov and Geoff Hinton. \u0026ldquo;Learning a Nonlinear Embedding by Preserving Class Neighbourhood Structure\u0026rdquo; AISTATS 2007.\n[5] Michael Gutmann and Aapo Hyvärinen. \u0026ldquo;Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.\u0026quot; AISTATS 2010.\n[6] Kihyuk Sohn et al. \u0026ldquo;Improved Deep Metric Learning with Multi-class N-pair Loss Objective\u0026rdquo; NIPS 2016.\n[7] Nicholas Frosst, Nicolas Papernot and Geoffrey Hinton. \u0026ldquo;Analyzing and Improving Representations with the Soft Nearest Neighbor Loss.\u0026quot; ICML 2019\n[8] Tongzhou Wang and Phillip Isola. \u0026ldquo;Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.\u0026quot; ICML 2020. [code]\n[9] Zhirong Wu et al. \u0026ldquo;Unsupervised feature learning via non-parametric instance-level discrimination.\u0026quot; CVPR 2018.\n[10] Ekin D. Cubuk et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv preprint arXiv:1805.09501 (2018).\n[11] Daniel Ho et al. \u0026ldquo;Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules.\u0026quot; ICML 2019.\n[12] Ekin D. Cubuk \u0026amp; Barret Zoph et al. \u0026ldquo;RandAugment: Practical automated data augmentation with a reduced search space.\u0026quot; arXiv preprint arXiv:1909.13719 (2019).\n[13] Hongyi Zhang et al. \u0026ldquo;mixup: Beyond Empirical Risk Minimization.\u0026quot; ICLR 2017.\n[14] Sangdoo Yun et al. \u0026ldquo;CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features.\u0026quot; ICCV 2019.\n[15] Yannis Kalantidis et al. \u0026ldquo;Mixing of Contrastive Hard Negatives\u0026rdquo; NeuriPS 2020.\n[16] Ashish Jaiswal et al. \u0026ldquo;A Survey on Contrastive Self-Supervised Learning.\u0026quot; arXiv preprint arXiv:2011.00362 (2021)\n[17] Jure Zbontar et al. \u0026ldquo;Barlow Twins: Self-Supervised Learning via Redundancy Reduction.\u0026quot; arXiv preprint arXiv:2103.03230 (2021) [code]\n[18] Alec Radford, et al. \u0026ldquo;Learning Transferable Visual Models From Natural Language Supervision\u0026rdquo; arXiv preprint arXiv:2103.00020 (2021)\n[19] Mathilde Caron et al. \u0026ldquo;Unsupervised Learning of Visual Features by Contrasting Cluster Assignments (SwAV).\u0026quot; NeuriPS 2020.\n[20] Mathilde Caron et al. \u0026ldquo;Deep Clustering for Unsupervised Learning of Visual Features.\u0026quot; ECCV 2018.\n[21] Prannay Khosla et al. \u0026ldquo;Supervised Contrastive Learning.\u0026quot; NeurIPS 2020.\n[22] Aaron van den Oord, Yazhe Li \u0026amp; Oriol Vinyals. \u0026ldquo;Representation Learning with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1807.03748 (2018).\n[23] Jason Wei and Kai Zou. \u0026ldquo;EDA: Easy data augmentation techniques for boosting performance on text classification tasks.\u0026quot; EMNLP-IJCNLP 2019.\n[24] Sosuke Kobayashi. \u0026ldquo;Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations.\u0026quot; NAACL 2018\n[25] Hongchao Fang et al. \u0026ldquo;CERT: Contrastive self-supervised learning for language understanding.\u0026quot; arXiv preprint arXiv:2005.12766 (2020).\n[26] Dinghan Shen et al. \u0026ldquo;A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation.\u0026quot; arXiv preprint arXiv:2009.13818 (2020) [code]\n[27] Tianyu Gao et al. \u0026ldquo;SimCSE: Simple Contrastive Learning of Sentence Embeddings.\u0026quot; arXiv preprint arXiv:2104.08821 (2020). [code]\n[28] Nils Reimers and Iryna Gurevych. \u0026ldquo;Sentence-BERT: Sentence embeddings using Siamese BERT-networks.\u0026quot; EMNLP 2019.\n[29] Jianlin Su et al. \u0026ldquo;Whitening sentence representations for better semantics and faster retrieval.\u0026quot; arXiv preprint arXiv:2103.15316 (2021). [code]\n[30] Yan Zhang et al. \u0026ldquo;An unsupervised sentence embedding method by mutual information maximization.\u0026quot; EMNLP 2020. [code]\n[31] Bohan Li et al. \u0026ldquo;On the sentence embeddings from pre-trained language models.\u0026quot; EMNLP 2020.\n[32] Lajanugen Logeswaran and Honglak Lee. \u0026ldquo;An efficient framework for learning sentence representations.\u0026quot; ICLR 2018.\n[33] Joshua Robinson, et al. \u0026ldquo;Contrastive Learning with Hard Negative Samples.\u0026quot; ICLR 2021.\n[34] Ching-Yao Chuang et al. \u0026ldquo;Debiased Contrastive Learning.\u0026quot; NeuriPS 2020.\n","permalink":"https://lilianweng.github.io/posts/2021-05-31-contrastive/","summary":"The goal of contrastive representation learning is to learn such an embedding space in which similar sample pairs stay close to each other while dissimilar ones are far apart. Contrastive learning can be applied to both supervised and unsupervised settings. When working with unsupervised data, contrastive learning is one of the most powerful approaches in self-supervised learning.\nContrastive Training Objectives In early versions of loss functions for contrastive learning, only one positive and one negative sample are involved.","title":"Contrastive Representation Learning"},{"content":"Large pretrained language models are trained over a sizable collection of online data. They unavoidably acquire certain toxic behavior and biases from the Internet. Pretrained language models are very powerful and have shown great success in many NLP tasks. However, to safely deploy them for practical real-world applications demands a strong safety control over the model generation process.\nMany challenges are associated with the effort to diminish various types of unsafe content:\n First, there are a variety of unsafe content types, such as toxicity, abusiveness, hate speech, biases, stereotypes, cyberbullying, identity attacks and more, which may or may not demand different treatment. Second, there is no clearly and widely agreed-upon categorization and definition of unsafe behavior in pretrained language models. Individual perceptions could vary a lot due to different social backgrounds. In this post, we delve into the issue of toxicity in language models. As I\u0026rsquo;m still struggling to find a concrete definition of toxic content, I list a couple in the literature below.\n [Perspective API] A rude, disrespectful, or unreasonable comment; likely to make people leave a discussion.\n [Kurita et al. 2019] Content that can offend or harm its recipients, including hate speech, racism, and offensive language.\n [Pavlopoulos et al. 2020] We use the term \u0026lsquo;toxic\u0026rsquo; as an umbrella term, but we note that the literature uses several terms for different kinds of toxic language or related phenomena: \u0026lsquo;offensive\u0026rsquo;, \u0026lsquo;abusive\u0026rsquo;, \u0026lsquo;hateful\u0026rsquo;, etc.\n Overall, toxicity is a broad term to describe several types of unsafe content. Methodologies in this post can be applied given some form of definition of toxicity; e.g. presented in the instruction for annotators. How to properly define the concept of toxicity and thus collect accurate annotation labels is out of the scope of this post.\nCategorization of Toxic Content How to categorize toxic content is not a straightforward task. Which content should be considered toxic and what types of toxic content exist can be very subjective. Language that does not look offensive to one group might seem inappropriate to another.\nOne popular categorization of offensive language is proposed by Zampieri et al. (2019), a three-level hierarchical taxonomy considering both the type and the target of offense. The Offensive Language Identification Dataset (OLID) dataset is collected based on this taxonomy.\nFig. 1. The three-level hierarchical taxonomy for categorizing offensive language, proposed by Zampieri et al. (2019). Level A: \u0026ldquo;Is it offensive?\u0026rdquo; [OFF] Offensive: Inappropriate language, insults, or threats. [NOT] Not offensive: No offense or profanity. Level B: \u0026ldquo;Is the offensive text targeted?\u0026rdquo; [TIN] Targeted Insult: Targeted insult or threat towards an individual, a group or other. [UNT] Untargeted: Non-targeted profanity and swearing. Level C: What is the target? [IND] The offense targets an individual, often defined as \u0026ldquo;cyberbullying\u0026rdquo;. [GRP] The offense targets a group of people based on ethnicity, gender, sexual orientation, religion, or other common characteristic, often defined as \u0026ldquo;hate speech\u0026rdquo;. [OTH] The target can belong to other categories, such as an organization, an event, an issue, etc. Data Collection Preparing a dataset of samples labelled as \u0026ldquo;safe\u0026rdquo; vs \u0026ldquo;unsafe\u0026rdquo; is the foundation for training a toxic language classifier and further providing signals for model detoxification.\nHuman Annotations Vidgen \u0026amp; Derczynski (2020) summarized that training data annotations for toxicity detection on the high level can be collected by:\n Expert coding: An expert has enough knowledge or training to complete the annotation tasks with good quality, such as a researcher who studies prejudice, a student with moderate level of training, or a NLP practitioner. It is more expensive but produces high-quality data. Crowdsourcing: Crowdsourcing platform pairs a large number of non-expert annotators with tasks. It is easier to scale up but demands more attention on quality control. Professional moderators: Professional moderators are experienced, well-trained on the tasks, but their goals are likely to optimize for the output specific to the platform. Synthetic data: Training dataset can also be manually created by relevant content creators to cover a broad range of toxic content types. Crowdsourcing is the most common approach among them (Davidson et al. 2017, Zampieri et al. 2019) and there are several good practices to improve the data quality:\n Test data: A small set of annotations collected from a few experts can be used as test questions (Zampieri et al. 2019) to filter out human annotators on the crowdsourcing platform who cannot achieve a certain threshold. Clear guidelines: Detailed instructions are useful to guide annotators to produce aligned and consistent labels. Without any guideline, annotators are encouraged to apply their personal perceptions, which could be problematic because (1) subjective interpretation of toxic content varies across individuals greatly and (2) it is tricky to mark certain types of noise like sarcasm and irony without any guideline. Majority vote: It is very common that we need labels from multiple annotators per sample and take the majority vote. Understanding annotators' identities: Demographic background has a big impact on the annotator\u0026rsquo;s understanding of the task. We should aim to recruit diverse and qualified annotators. Semi-supervised Dataset Khatri et al. (2018) proposed a simple approach to bootstrap a large amount of semi-supervised dataset for learning toxic content classifiers. Their approach relies on a small annotated dataset and a large unlabelled dataset.\n First, they gather a blacklist of 800+ words covering topics of profanity, hate, sexual content and insults. A black list of profanities may have high precision and low recall, but it can provide weak supervised signals. Subreddits are sorted by the percentage of blacklisted words. Then sensitive examples are sampled from the top subreddits and non-sensitive ones from the bottom, respectively. Train a weak binary classifier to further select more samples from the sorted subreddits, Sensitive: contain blacklisted words or toxic classifier confidence \u0026gt; 0.8; Non-sensitive: not contain blacklisted words and toxic classifier confidence \u0026lt; 0.3 Given this large expanded dataset, train a new classifier named \u0026ldquo;Two-stage bootstrap\u0026rdquo; (TS bootstrap). Their experiments showed that the TS bootstrap classifier achieved pretty good numbers on F1 score, accuracy and recall and it could also transfer to out-of-domain test data.\nFig. 2. The two-stage bootstrap classifier is trained on a dataset bootstrapped by a weak toxic binary classifier on Reddit data. (Image source: Khatri et al. 2018) SOLID (Semi-Supervised Offensive Language Identification Dataset; Rosenthal et al. 2020) contains 9+ M tweets annotated with the same taxonomy system as for OLID. SOLID treats OLID as a seed and extends it via a semi-supervised technique called democratic co-training. Democratic co-training (Zhou \u0026amp; Goldman, 2004) creates a large dataset from noisy labels provided by a collection of diverse models trained on a small supervised dataset. SOLID is constructed by:\n First, train a diverse set of supervised models on the labeled dataset OLID. The paper experimented with PMI (n-gram-based similarity), FastText (shallow neural model similar to BoW model), LSTM and BERT. For each sample in the unannotated dataset, each model predicts a confidence score for the target class. The scores are aggregated by taking avg() or min(). Samples with high confidence are added into the dataset. BERT model performance does not improve when the supervised dataset is large enough for a simple task, but can benefit from a big semi-supervised dataset if the original supervised dataset is too small for the task.\nToxicity Detection Given a supervised dataset, we can train a text classifier from scratch or fine-tune a pretrained language model to perform the classification task. But what if training samples are not good or sufficient enough? What if we don’t have access to such a supervised dataset?\nAdversarial Attacks To create a toxicity detection model that is robust to adversarial attacks, Dinan et al. (2019) proposed an iterative \u0026ldquo;build it, break it, fix it\u0026rdquo; strategy to improve the dialogue system safety with humans in the loop.\n Build it: A BERT model is trained to classify toxic comments on the Jigsaw dataset. Break it: Crowdsourced workers are asked to write toxic messages that are mistakenly labelled as \u0026ldquo;safe\u0026rdquo; by the model. Fix it: The model is re-trained on the combination of the original dataset and newly collected adversarial samples. Repeat: Redeploy the robustified model and repeat a new round from step 1. Fig. 3. The illustration of iteratively improving a toxic content detection model via the \"build it, break it, fix it\" process. (Image source: Dinan et al. 2019) One baseline in their experiments is to replace the adversarial collection in the \u0026ldquo;break it\u0026rdquo; step with the standard collection where workers are asked to submit \u0026ldquo;offensive\u0026rdquo; messages directly . Compared to the standard collection, the adversarial collection has less explicit profanity and more negations to trick the model. The tasks become more challenging in the later rounds.\nAdversarial models are more robust against adversarial attacks than baseline models trained on the standard collection. The third round adversarial model has worse performance on the standard task than the standard model, likely due to overfitting. I’m curious about how the model performance would be like if it is trained on both adversarial and standard collection, but I didn’t find it in the paper.\nFig. 4. The comparison of performance on standard and adversarial tasks of models trained on standard ($S\\_i$) and adversarial data collection ($A\\_i$). The subscript $i$ indicates the number of training rounds. (Image source: Dinan et al. 2019) Another type of adversarial attack is to trick the detection model to mistakenly classify a toxic sentence as safe by replacing or scrambling a subset of characters. Kurita et al. (2019) developed a method of generating such model-agnostic adversarial attacks, incorporating several types of character-level perturbations:\n Character scrambling: randomly permute character positions. Homoglyph substitution: replace one or multiple letters with similar looking international letters. Dictionary-based near-neighbor replacement: find closest but distinct token in terms of Levenshtein distance. Distractor injection: inject distractor tokens by repeating random selected sequences of non-toxic tokens. Adversarial noise combining token obfuscation and distractor tokens leads to substantial performance degradation of a toxic classifier. Character-level perturbation degrades performance more than distractors.\nThe paper proposed two ways to resolve adversarial attacks:\n Adversarial training refers to training the model on a dataset with noise. However, you need to know the details of the incoming attacks in advance. And there is no guarantee that training samples with arbitrary noise would generalize to the test set. CDAE (contextual denoising autoencoder) uses character-level and contextual information to denoise obfuscated tokens. CDAE takes a noise sample to predict the denoised version. Still, you need to know what types of character-level perturbation can be applied to create noise samples. CDAE performs comparable to BERT, but not substantially better. Perspective API perspective API (www.perspectiveapi.com) is the most widely used commercial API for toxic content detection. Perspective trains machine learning models to provide scores for several different attributes: toxicity, severe toxicity, insult, profanity, identity attack, threat, and sexually explicit. Each score is a number between [0, 1], indicating how likely the message contains a given attribute (i.e. confidence of a binary classifier) and it does not signify the severity of the attribute.\nFig. 5. The overview of Perspective API scores. (Image source: About Perspective API) Gehman et al. (2020) measured the Perspective API toxicity scores of unprompted generations sampled from several pretrained language models. \u0026ldquo;Unprompted\u0026rdquo; means that the generation is only conditioned on the start-of-sentence tokens, without injecting any additional context. Noticeably, all the tested models get to the expected maximum toxicity \u0026gt; 0.5 after 100 generations. They also pointed out that training datasets for large LMs contain an non-negligible amount of toxic content.\nFig. 6. Perspective API toxicity scores of unprompted generations. Each model generates a pool of 10k samples and the expected maximum toxicity score is estimated via bootstrapping. (Image source: Gehman et al. 2020) They collected the RealToxicityPrompt dataset for studying toxicity in conditional language model generation. It contains 100k naturally occurring prompts with associated toxicity scores from Perspective API. Some prompts that do not contain any toxic language still can trigger very offensive completion.\nDespite of its popularity, Perspective API contains known biases, as summarized by Gehman et al. (2020):\n \u0026hellip; exhibit biases against minorities and suffer from low agreement in annotations, partially due to annotator identity influencing their perception of hate speech and differences in annotation task setup.\n Notably, recent work has found that systems are overestimating the prevalence of toxicity in text that contains a minority identity mention (e.g., \u0026ldquo;I’m a gay man\u0026rdquo;) or text by racial minorities (e.g., text in African American English). This is partially due to detectors’ over-reliance on lexical cues of toxicity (including swearwords, slurs, and other \u0026ldquo;bad\u0026rdquo; words).\n Prompt-based Detection Self-diagnosis (Schick, et al. 2021) is a process of exploiting the capacity of a pretrained language model to detect socially undesired attributes in its own outputs. The diagnosis depends on a predefined prompt template where the attributes are described in short text and measures the normalized probability of the model outputting \u0026ldquo;yes\u0026rdquo; versus \u0026ldquo;no\u0026rdquo;. Note that self-diagnosis does not need to access a labelled dataset for training.\n# Self-diagnosis template \u0026quot;{sample text}\u0026quot; Question: Does the above text contain {undesired attribute}? Answer: {yes/no} They use RealToxicityPrompts dataset and Perspective API for evaluation in the experiments. The self-diagnosis performance is positively correlated with the model size.\nFig. 7. Self-diagnosis abilities for identifying undesired attributes. The ground truth is provided by Perspective API. (Image source: Schick, et al. 2021) Detoxification Blacklisting Bad word filtering is a pretty intuitive and effective way to avoid explicit profane words in the language model generation. At decoding time, we can manually reduce the probabilities of blocked words to avoid sampling them. However, it is not perfect, as it is still possible to have unsafe content composed of safe tokens.\nVocabulary shifting (Gehman et al. 2020) learns a 2-dimensional representation of toxicity versus non-toxicity for every token in the vocabulary of the pretrained model. Then the representation that encodes the non-toxicity is used to boost the likelihood of non-toxic tokens at decoding time.\nPrompt-based Detox Self-debiasing (Schick et al. 2021) follows the similar idea as in self-diagnosis. It is a process for using the internal knowledge of a pretrained language model to reduce the probability of undesired attributes in the model generation.\n# Self-debiasing template, denoted as sdb(.) The following text contains {undesired attribute s}: {sample text x} Given an input prompt $\\mathbf{x}$, a textual description of undesired attributes $s$, and the language model $M$, self-debiasing computes the difference between the probability of next words without and with the self-debiasing template $\\text{sdb}(.)$:\n $$ \\Delta(w, \\mathbf{x}, s) = p_M(w\\vert\\mathbf{x}) - p_M(w\\vert\\text{sdb}(\\mathbf{x}, s)) $$ Because $\\text{sdb}(.)$ is expected to boost the probabilities of undesired words, $\\Delta(w, \\mathbf{x}, s)$ should be negative for undesirable words.\nIn self-diasing decoding, a scaling function of the probability difference $\\alpha(\\Delta(w, \\mathbf{x}, s)): \\mathbb{R}\\to[0,1]$ is used to alter the true sampling distribution,\n $$ \\tilde{p}_M(w\\vert\\mathbf{x}) \\propto \\alpha(\\Delta(w, \\mathbf{x}, s)) p_M(w\\vert\\mathbf{x}) $$ In the paper, they used a soft variant where the probabilities of the words with negative $\\Delta$ are reduced w.r.t. the magnitude of $\\Delta(w, \\mathbf{x}, s)$:\n $$ \\alpha(x)=\\begin{cases} 1 \u0026 \\text{ if } x\\geq 0 \\\\ e^{\\lambda\\cdot x} \u0026 \\text{ otherwise} \\end{cases} $$ Fig. 8. Self-diasing decoding can reduce the probabilities of undesirable attributes. The scores are provided by Perspective API. (Image source: Schick et al. 2021) There are a couple of major limitations in self-debiasing detoxification:\n The evaluation solely relies on Perspective API, so it cannot capture bias \u0026amp; toxicity attributes that are not covered by Perspective API, such as gender biases. Using human evaluation is another alternative but the scale is limited. Self-debiasing sometimes acts too aggressively and filters out harmless words and it does not maintain the same level of perplexity as the original model. The approach is constrained by the internal capacity of the model. For example, if the model is not aware of certain biases, it would not be able to correct them. Text Style Transfer Unsupervised style transfer can be used to translate offensive sentences into innocuous ones (Santos et al. 2018). The approach should work for non-parallel datasets, meaning that we only have access to two separate datasets of offensive and non-offensive samples, but not paired versions. To preserve the content when transferring the text into another style, a cycle consistency loss (Zhu et al. 2017) is adopted.\nFig. 9. The training process of a neural text style transfer algorithm using non-parallel data. (Image source: Santos et al. 2018) Let $s_i$ be the desired style ($i=0$ for offensive and $i=1$ for non-offensive), and $\\mathbf{x}^i_k$ be the $k$-th sample of style $s_i$, $k = 1, \\dots, n$. Both the encoder $E$ and decoder $G$ take a sample (or hidden state) along with a style label. The classifier $C$ predicts a probability distribution over the style labels given an input sample.\nFollowing the illustration in Fig. 9:\n The top branch of forward transfer is auto encoder: ​$E(\\mathbf{x}^i_k, s_i) \\to H^i_k \\to G(H^i_k, s_i) \\to \\hat{\\mathbf{x}}^{i\\to i}_k$. Two losses are computed: Reconstruction loss measures how well the decoder can reconstruct the sample back: $$ \\mathcal{L}_\\text{self} = \\mathbb{E}_{\\mathbf{x}^i_k \\sim \\mathcal{X}} [-\\log p_G(\\mathbf{x}_k^i \\mid E(\\mathbf{x}^i_k, s_i), s_i)] $$ The bottom branch of forward transfer: $E(\\mathbf{x}^i_k, s_i) \\to H^i_k \\to G(H^i_k, s_j) \\to \\hat{\\mathbf{x}}^{i\\to j}_k$ Classification loss measures the effectiveness of style transfer: $$ \\mathcal{L}_\\text{style_fwd} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_j \\mid \\hat{\\mathbf{x}}^{i\\to j}_k)] $$ The back transfer uses cycle consistency loss: $E(\\hat{\\mathbf{x}}^{i\\to j}_k, s_j) \\to H^{i\\to j}_k \\to G(H^{i\\to j}_k, s_i) \\to \\hat{\\mathbf{x}}^{i\\to j \\to i}_k$ The cycle consistency loss controls how well the transferred sample can be converted back to the original form to encourage content preservation: $$ \\mathcal{L}_\\text{cycle} = \\mathbb{E}_{\\mathbf{x}^i_k \\sim \\mathcal{X}} [-\\log p_G(\\mathbf{x}_k^i \\mid E(\\hat{\\mathbf{x}}^{i \\to j}_k, s_j), s_i)] $$ - The classification loss ensures that the back-transferred sample has the correct label: $$ \\mathcal{L}_\\text{style_back} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_i \\mid G(E(\\hat{\\mathbf{x}}^{i\\to j}_k, s_j), s_i))] $$ There is an additional supervised classification loss for training an accurate classifier: $$ \\mathcal{L}_\\text{class} = \\mathbb{E}_{\\hat{\\mathbf{x}}^{i\\to j}_k \\sim \\hat{\\mathcal{X}}} [-\\log p_C(s_i \\mid \\hat{\\mathbf{x}}^i_k)] $$ The final training objective is as follows and the encoder, decoder and classifier are jointly trained:\n $$ \\mathcal{L}(\\theta_E, \\theta_G, \\theta_C) = \\min_{E, G, C} \\mathcal{L}_\\text{self} + \\mathcal{L}_\\text{style_fwd} + \\mathcal{L}_\\text{cycle} + \\mathcal{L}_\\text{style_back}+ \\mathcal{L}_\\text{class} $$ Style Transformer (Dai et al. 2019) also aims to learn unsupervised text style transfer. Different from the encoder-decoder model in Santos et al. 2018, it learns a Transformer-based style transfer function $f_\\theta(\\mathbf{x}, s)$ for a given input sample $\\mathbf{x}$ and a desired style control variable $s$.\nFig. 10. The comparison of style transformer and previous models that depend on disentangled latent representation. (Image source: Dai et al. 2019) Without access to the parallel corpus, the style transformer adopts a discriminator to create supervision from non-parallel dataset.\nLet $s$ and $\\hat{s}$ be two mutually exclusive style variables and $\\mathbf{x}$ is a sample of style $s$, style transformer computes several losses:\n Self reconstruction loss: $\\mathcal{L}_\\text{self} = - p_\\theta (\\mathbf{x} \\vert \\mathbf{x}, s)$ Cycle-consistency loss: $\\mathcal{L}_\\text{cycle} = - p_\\theta (\\mathbf{x} \\vert f_\\theta(\\mathbf{x}, \\hat{s}), s)$ Style controlling loss: This is necessary because otherwise the model would simply learn to copy the input over. $$ \\mathcal{L}_\\text{style} = - p_\\phi(\\text{class} = 1 \\vert f_\\theta(\\mathbf{x}, \\hat{s}), \\hat{s}) $$ , where the discriminator is a simple binary classifier trained to optimize the negative log-likelihood of the correct style. The discriminator is trained by labelling\n $\\{(\\mathbf{x}, s), (f_\\theta(\\mathbf{x}, s), s), (f_\\theta(\\mathbf{x}, \\hat{s}), \\hat{s})\\}$ as positive class 1 $\\{(\\mathbf{x}, \\hat{s}), (f_\\theta(\\mathbf{x}, s), \\hat{s}), (f_\\theta(\\mathbf{x}, \\hat{s}), s)\\}$ as negative class 0. Fig. 11. The training process of Style Transformer. (Image source: Dai et al. 2019) Driven by the research question \u0026ldquo;Can we fine-tune a pre-trained language model to suggest civil rephrasings of rude comments using a dataset solely annotated in toxicity?\u0026rdquo;, Laugier et al. (2021) fine-tuned a pretrained text-to-text transformer with a denoising and cyclic auto-encoder loss.\nLet $s$ be the attribute of $\\mathbf{x}$ (e.g. \u0026ldquo;civil\u0026rdquo;) and $\\bar{s}$ be the other opposite attribute (e.g. \u0026ldquo;toxic\u0026rdquo;). These two attributes are mutually exclusive. The goal is to learn a mapping function $f_\\theta$ such that it translates $x$ to a new fluent sequence $y$ with target attribute $a$ while preserving $x$\u0026rsquo;s content.\nThe encoder-decoder model is trained with the loss:\n $$ \\mathcal{L} = \\lambda_\\text{DAE} \\mathcal{L}_\\text{DAE} + \\lambda_\\text{cycle} \\mathcal{L}_\\text{cycle} $$ The denoising auto-encoder loss is the loss for denoising auto-encoders, where $\\eta$ is a masking function same as in BERT training: $$ \\mathcal{L}_\\text{DAE} = \\mathbb{E}_{\\mathbf{x} \\sim \\mathcal{X}} [−\\log p_\\theta(\\mathbf{x} \\mid \\eta(\\mathbf{x}), s)] $$ The cycle consistency loss (Zhu et al. 2017) has $\\tilde{\\theta}$ to produce a non-differentiable pseudo-prediction $\\hat{\\mathbf{y}}$ and it does not take gradient backpropagation. $$ \\mathcal{L}_\\text{cycle} = \\mathbb{E}_{\\mathbf{x} \\sim \\mathcal{X}} [−\\log p_\\theta(\\mathbf{x} \\mid f_{\\tilde{\\theta}}(\\mathbf{x}, \\bar{s}), s)] $$ They used the above loss to fine-tune a T5 model, resulting in a model named CAE-T5. The conditioning is implemented like CTRL via control code (\u0026ldquo;civil\u0026rdquo; or \u0026ldquo;toxic\u0026rdquo;) prepended to the start of a sequence.\nAutomatic evaluation of the text style transferred results relies on three metrics:\n Accuracy: Classification accuracy measures how successful the style transfer is. Fluency: Fluency is commonly measured by perplexity by another separately trained LM on non-toxic samples. Content preservation: It is the content similarity between transferred and original sentences, measured by BLEU or embedding based content similarity. Human evaluation is also necessary but more costly.\nCompared to the baseline (Shen et al. 2017), the style transfer method by Santos et al. 2018 achieves better classification accuracy, better content preservation, but worse perplexity. CAE-T5 has worse classification accuracy, competitive content preservation, and better perplexity compared to a set of baselines including Style Transformer.\nControllable Generation We can try to avoid toxic outputs via controllable text generation. There are several popular approaches for steering a pretrained language model toward desired styles, topics or safety criteria:\n Apply guided decoding strategies and select desired outputs at test time. Optimize for the most desired outcomes via good prompt design. Fine-tune the base model or steerable layers to do conditioned content generation. Read more in my last post on controllable neural text generation, introducing methods like AutoPrompt, CTRL, PPLM, GeDi and many more.\nGehman et al. (2020) experimented with both data-based (supervised fine-tuning, CTRL training) and decoding-based (vocabulary shifting, blocked word filtering, PPLM) methods for language model detoxification. They found that toxicity control tokens (CTRL) and swear word filters are less successful than more computationally or data-intensive methods like fine-tuning on non-toxic corpora and PPLM.\nFig. 12. Table list expected maximum toxicity score over 25 generations (left) and the empirical probability of generating toxic text over 25 generations (right) for several detoxification methods. Scores are provided by Perspective API. (Image source: Gehman et al., 2020) System-level Safety Solution Xu et al. (2020) presented a thorough system-level design for building safe chatbots.\nFig. 13. Illustration of a safe chat bot system. (Image source: Xu et al. 2020) They consider four general strategies in the recipes for making the bot safer:\n Detect unsafe content: Adopt a classifier for detecting unsafe language on both the input and output side, as an extra safety layer on top of the language model. The classifier is trained on an enhanced version of the Jigsaw toxic comment dataset (safe vs unsafe binary labels), extended with adversarial human attacks (Dinan et al. 2019) and semi-supervision (Khatri et al. 2018). The safety classifier can be used on both the user input and the model output. If it detects unsafe content, the system is configured to return a canned, predefined response (e.g \u0026ldquo;I\u0026rsquo;m sorry I\u0026rsquo;m not sure what to say.\u0026quot;), or decide to change topics. It is worthy noting that this approach relies on a high-quality classifier. The conversation experience would be drastically disrupted with too many false positives. Bot adversarial dialogue (BAD) safety: The idea is to collect data on humans adversarially probing the system to make mistakes and then use the data for further training. During annotation, human labellers can tag the bot\u0026rsquo;s response with an unsafe-safe rating based on the percentage of population who may consider it as unsafe. This probing data collection is used to train a multi-turn safety classifier, predicting whether a response is offensive given the dialogue context. Safe generation: Train a model that is less likely to output unsafe responses. A predefined list of unsafe words/n-grams can be blocked at decoding time. The pretraining data is filtered by the above safety classifier, or filtered based on known authors. The problem with pre-training only with safe datasets is that if the model has never seen toxic language during training, it would not know how to respond at test time (OOD; e.g. may just copy the offensive content). They instead prepare a collection of training samples where the last utterance is labelled as \u0026ldquo;unsafe\u0026rdquo; and then attach a safe response following that unsafe attack. Then the model is fine-tuned on the \u0026ldquo;baked-in\u0026rdquo; safety data. Do CTRL style training by assigning \u0026ldquo;safe\u0026rdquo; vs \u0026ldquo;unsafe\u0026rdquo; label using the safety classifier. Avoid sensitive topics: In order to avoid sensitive topics (politics, religion, drug use, medical advice, and NSFW and relationships/dating), they trained a multi-class classifier to detect those topics using crowdsourced lists of subreddits. The classifier can be periodically re-trained to capture the changes within topics over time. A small validation set is collected by recruiting crowdsourced workers to discuss one of the target topics. Gender bias mitigation: They used CTRL style training to mitigate gender biases. Precisely, given a gendered word list, tag the training samples with $F^0 M^0$, $F^0 M^+$, $F^+ M^+$, and $F^+ M^0$ labels, indicating whether the response contains female / male words ($+$ contains, $-$ does not contain). At test time, the system runs with a control label $F^0 M^0$ to avoid outputting gender specific words. Appendix: Datasets (*Only datasets in English are listed here.)\nHate Speech and Offensive Language Dataset (2017): contains about 25k tweets, each labelled manually as one of three categories: hate speech, offensive but not hate speech, or neither offensive nor hate speech. [Download]\nJigsaw Toxic Comments Classification Dataset (2018): contains about 160k examples extracted from Wikipedia discussion pages, each annotated for 7 classes: toxic, severe toxic, obscene, threat, insult, identity hate and non-toxic. The labelling process involved 5000 crowdsourced annotators. [Download]\nJigsaw Unintended Bias in Toxicity Classification Dataset (2019): contains about 2 Millions comments from the Civil Comments platform, which shut down in 2017. This data is annotated for toxicity, toxicity sub-types, and mentions of identities, which enables evaluation of unintended bias with respect to identity mentions. [Download]\nOLID (Offensive Language Identification Dataset; 2019): contains 14,100 English tweets, annotated according to the three-level taxonomy as described here. [Download]\nSOLID (Semi-Supervised Offensive Language Identification Dataset; 2020): contains 9+ Millions tweets annotated following OLID\u0026rsquo;s three level taxonomy. [Download]\nRealToxicityPrompts dataset (2020): contains 100k sentence snippets from the web with Perspective API toxicity scores for studying the risk of neural toxic degeneration in language models. [Download]\nCitation Cited as:\n Weng, Lilian. (Mar 2021). Reducing toxicity in language models. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/.\n Or\n@article{weng2021toxic, title = \u0026quot;Reducing Toxicity in Language Models.\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Mar\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/\u0026quot; } References [1] Vidgen, et al. \u0026ldquo;Challenges and frontiers in abusive content detection.\u0026quot; Workshop on Abusive Language Online 2019.\n[2] Zampieri et al. \u0026ldquo;Predicting the type and target of offensive posts in social media.\u0026quot; NAACL 2019.\n[3] Vidgen \u0026amp; Deczynski. \u0026ldquo;Directions in abusive language training data, a systematic review: Garbage in, garbage out.\u0026quot; PLoS ONE 15(12): e0243300 (2020).\n[4] Davidson et al. \u0026ldquo;Automated hate speech detection and the problem of offensive language.\u0026quot; ICWSM 2017.\n[5] Khatri et al. \u0026ldquo;Detecting offensive content in open-domain conversations using two stage semi-supervision.\u0026quot; NeuriIPS CONVAI Workshop 2018.\n[6] Rosenthal et al. \u0026ldquo;A Large-Scale Semi-Supervised Dataset for Offensive Language Identification\u0026rdquo; arXiv:2004.14454 (2020).\n[7] Pavlopoulos et al. \u0026ldquo;Toxicity Detection: Does Context Really Matter?\u0026quot; arXiv:2006.00998 (2020).\n[8] Dinan et al. \u0026ldquo;Build it, break it, fix it for dialogue safety: Robustness from adversarial human attack.\u0026quot; arXiv:1908.06083 (2019).\n[9] Kurita et al. \u0026ldquo;Towards Robust Toxic Content Classification\u0026rdquo; arXiv:1912.06872 (2019)\n[10] Santos et al. \u0026ldquo;Fighting offensive language on social media with unsupervised text style transfer.\u0026quot; arXiv:1805.07685 (2018)\n[11] Dai et al. \u0026ldquo;Style Transformer: Unpaired Text Style Transfer without Disentangled Latent Representation\u0026rdquo; ACL 2019.\n[12] Laugier et al. \u0026ldquo;Civil Rephrases Of Toxic Texts With Self-Supervised Transformers\u0026rdquo; arXiv:2102.05456 (2021). code\n[13] Schick et al. \u0026ldquo;Self-Diagnosis and Self-Debiasing: A Proposal for Reducing Corpus-Based Bias in NLP\u0026rdquo; arXiv:2103.00453 (2021).\n[14] Gehman et al. \u0026ldquo;RealToxicityPrompts: Evaluating Neural Toxic Degeneration in Language Models\u0026rdquo; EMNLP 2020.\n[15] Xu et al. \u0026ldquo;Recipes for Safety in Open-domain Chatbots\u0026rdquo; arXiv:2010.07079 (2020).\n","permalink":"https://lilianweng.github.io/posts/2021-03-21-lm-toxicity/","summary":"Large pretrained language models are trained over a sizable collection of online data. They unavoidably acquire certain toxic behavior and biases from the Internet. Pretrained language models are very powerful and have shown great success in many NLP tasks. However, to safely deploy them for practical real-world applications demands a strong safety control over the model generation process.\nMany challenges are associated with the effort to diminish various types of unsafe content:","title":"Reducing Toxicity in Language Models"},{"content":"[Updated on 2021-02-01: Updated to version 2.0 with several work added and many typos fixed.] [Updated on 2021-05-26: Add P-tuning and Prompt Tuning in the \u0026ldquo;prompt design\u0026rdquo; section.] [Updated on 2021-09-19: Add \u0026ldquo;unlikelihood training\u0026rdquo;.]\nThere is a gigantic amount of free text on the Web, several magnitude more than labelled benchmark datasets. The state-of-the-art language models (LM) are trained with unsupervised Web data in large scale. When generating samples from LM by iteratively sampling the next token, we do not have much control over attributes of the output text, such as the topic, the style, the sentiment, etc. Many applications would demand a good control over the model output. For example, if we plan to use LM to generate reading materials for kids, we would like to guide the output stories to be safe, educational and easily understood by children.\nHow to steer a powerful unconditioned language model? In this post, we will delve into several approaches for controlled content generation with an unconditioned langage model. Note that model steerability is still an open research question. Each introduced method has certain pros \u0026amp; cons.\n Apply guided decoding strategies and select desired outputs at test time. Optimize for the most desired outcomes via good prompt design. Fine-tune the base model or steerable layers to do conditioned content generation. In the following discussion, we assume we have access to a pretrained generative language model $p_\\theta$. The model has learned the distribution over token sequences by optimizing for the next token prediction: $ \\mathcal{L}_\\text{ML} = - \\sum_t \\log p_\\theta(x_t \\vert x_{\u0026lt;t}) $.\nDecoding Strategies By adopting different decoding methods, we can place restrictions or preferences on the sampling process to alter the generated samples without modifying any model weights. Even though decoding strategies do not change the values of any trainable parameter, it is a quite important component.\nCommon Decoding Methods Since the final layer of the model predicts logits $o$ over the vocabulary space, the next token can be sampled by applying softmax with temperature $T$. The probability of sampling the $i$-th token is\n $$ p_i \\propto \\frac{\\exp(o_i / T)}{\\sum_j \\exp(o_j/T)} $$ A low temperature would make the distribution sharper and a high value makes it softer.\nGreedy search: Always pick the next token with the highest probability, equivalent to setting temperature $T=0$. However, it tends to create repetitions of phrases, even for well-trained models.\nBeam search: It essentially does breadth-first search, one token per tree level, but with a limited bandwidth. At each level of the search tree, beam search keeps track of $n$ (named \u0026ldquo;beam width\u0026rdquo;) best candidates and expands all the successors of these candidates in the next level. Beam search could stop expanding a node if it hits the EOS (end-of-sentence) token.\nHowever, maximization-based decoding does not guarantee high-quality generation.\n Fig. 1. The probability assigned to the next token by beam search versus by humans. The human selected tokens have much higher variance in predicted probability and thus more surprising. (Image source: Holtzman et al. 2019) Top-k sampling (Fan et al., 2018): At each sampling step, only the top $k$ most likely tokens are selected and the probability mass is redistributed among them. In Fan et al., 2018, the authors proposed to use top-k random sampling where the next token is randomly selected among the top $k$ most likely candidates and they argued that this approach can generate more novel and less repetitive content than beam search.\nNucleus sampling (Holtzman et al. 2019): Also known as \u0026ldquo;Top-p sampling\u0026rdquo;. One drawback of top-k sampling is that the predefined number $k$ does not take into consideration how skewed the probability distribution might be. The nucleus sampling selects the smallest set of top candidates with the cumulative probability exceeding a threshold (e.g. 0.95) and then the distribution is rescaled among selected candidates.\nBoth top-k and nucleus sampling have less repetitions with a proper set of hyperparameters.\nPenalized sampling (Keskar et al. 2019): To avoid the common failure case of generating duplicate substrings, the CTRL paper proposed a new sampling method to penalize repetitions by discounting the scores of previously generated tokens. The probability distribution for the next token with repetition penalty is defined as:\n $$ p_i = \\frac{\\exp(o_i / (T \\cdot \\mathbb{1}(i \\in g)))}{\\sum_j \\exp(o_j / (T \\cdot \\mathbb{1}(j \\in g)))} \\quad \\mathbb{1}(c) = \\theta \\text{ if the condition }c\\text{ is True else }1 $$ where $g$ contains a set of previously generated tokens, $\\mathbb{1}(.)$ is an identity function. $\\theta=1.2$ is found to yield a good balance between less repetition and truthful generation.\nGuided Decoding All the above standard decoding strategies sample tokens according to the predicted probability, with no additional information. Our preferences on topic or sentiment can be baked into the candidate ranking function to guide the sample generation by altering the candidate ranking score. The ranking score for token selection at each decoding step can be set as a combination of LM log-likelihood and a set of desired feature discriminators. The features are designed to quantify human preferences by heuristics (Ghazvininejad et al., 2017), supervised learning (Holtzman et al., 2018) or RL (Li et al., 2017).\nGhazvininejad et al. (2017) built a system called \u0026ldquo;Hafez\u0026rdquo; for generating poetry in desired style by adjusting sampling weights in beam search at decoding steps. The likelihood of sampling for the next token $x_{t+1}$ at step $t$ is augmented by a scoring function:\n $$ \\text{score}(x_{t+1}, b_t) = \\text{score}(b_t) + \\log p(x_{t+1}) + \\color{green}{\\sum_i \\alpha_i f_i(x_{t+1})} $$ where $\\log p(x_{t+1})$ is the log-likelihood predicted by LM. $\\text{score}(b_t)$ is the accumulated score of the already-generated words in the current beam state $b_t$. The green part can incorporate many different features for steering the style of the output. A set of feature functions $f_i(.)$ define the preferences and the associated weights $alpha_i$ work like \u0026ldquo;control knobs\u0026rdquo; that can be easily customized at decoding time. Features can measure a variety of attributes and can be easily combined; for example,\n whether $x_{t+1}$ exists in a bag of desired or banned topical words. whether $x_{t+1}$ indicates certain sentiments. whether $x_{t+1}$ is a repeated token (and thus $f_i$ needs to take the history as input too). the length of $x_{t+1}$ if longer or shorter words are in particular preferred. Similar to Hafez, Baheti et al. (2018) manually designed features for ranking and altered the sampling distribution by appending similarity scores between topic distribution or embeddings of the context and the completion.\nHoltzman et al. (2018) adopted a set of learned discriminators, each specializing in a different principle of communication guided by Grice’s maxims: quality, quantity, relation and manner. The discriminators learn to encode these desired principles by measuring repetition, entailment, relevance, and lexical diversity, respectively. Given some ground truth completion, all the discriminator models are trained to minimize the ranking log-likelihood, $\\log\\sigma(f_i(y_g) - f_i(y))$, because the gold continuation $y_g$ is expected to obtain a higher score than the generated one $y$. Here the weight coefficients $\\alpha_i$ are also learned to minimize the score difference between the golden standard and the generated completion. Discriminative Adversarial Search (DAS; Scialom et al., 2020) is inspired by GAN and trains the discriminator to tell apart human created text from machine generated text. The discriminator predicts a label for each token instead of for the entire sequence. The discriminator logprob is added to the score to guide sampling towards the human-written style.\nMeister et al. (2020) studied beam search in a regularized decoding framework:\n $$ \\mathbf{y}^* = \\arg\\max_{\\mathbf{y}\\in\\mathcal{Y}} \\big( \\underbrace{\\log p_\\theta(\\mathbf{y}\\vert\\mathbf{x})}_\\text{MAP} - \\underbrace{\\lambda\\mathcal{R}(\\mathbf{y})}_\\text{regularizer} \\big) $$ Since we expect maximum probability to have minimum surprise, the surprisal of a LM at time step $t$ can be defined as follows:\n $$ \\begin{aligned} u_0(\\texttt{BOS}) \u0026= 0 \\text{ ; BOS is a placeholder token for the beginning of a sentence.}\\\\ u_t(y) \u0026= -\\log P_\\theta(y \\vert \\mathbf{x}, \\mathbf{y}_{The MAP (maximum a posteriori) part demands for sequences with maximum probability given context, while the regularizer introduces other constraints. It is possible a global optimal strategy may need to have a high-surprisal step occasionally so that it can shorten the output length or produce more low-surprisal steps afterwards.\nBeam search has gone through the test of time in the field of NLP. The question is: If we want to model beam search as exact search in a regularized decoding framework, how should $\\mathcal{R}(\\mathbf{y})$ be modeled? The paper proposed a connection between beam search and the uniform information density (UID) hypothesis.\n \u0026ldquo;The uniform information density hypothesis (UID; Levy and Jaeger, 2007) states that—subject to the constraints of the grammar—humans prefer sentences that distribute information (in the sense of information theory) equally across the linguistic signal, e.g., a sentence.\u0026rdquo;\n In other words, it hypothesizes that humans prefer text with evenly distributed surprisal. Popular decoding methods like top-k sampling or nuclear sampling actually filter out high-surprisal options, thus implicitly encouraging the UID property in output sequences.\nThe paper experimented with several forms of regularizers:\n Greedy: $\\mathcal{R}_\\text{greedy}(\\mathbf{y}) = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - \\min_{y' \\in \\mathcal{V}} u_t(y') \\big)^2$; if set $\\lambda \\to \\infty$, we have greedy search. Note that being greedy at each individual step does not guarantee global optimality. Variance regularizer: $\\mathcal{R}_\\text{var}(\\mathbf{y}) = \\frac{1}{\\vert\\mathbf{y}\\vert}\\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - \\bar{u} \\big)^2$ , where $\\bar{u}$ is the average surprisal over all timesteps. It directly encodes the UID hypothesis. Local consistency: $\\mathcal{R}_\\text{local}(\\mathbf{y}) = \\frac{1}{\\vert\\mathbf{y}\\vert}\\sum_{t=1}^{\\vert\\mathbf{y}\\vert} \\big(u_t(y_t) - u_{t-1}(y_{t-1}) \\big)^2$; this decoding regularizer encourages adjacent tokens to have similar surprisal. Max regularizer: $\\mathcal{R}_\\text{max}(\\mathbf{y}) = \\max_t u_t(y_t)$ penalizes the maximum compensation of surprisal. Squared regularizer: $\\mathcal{R}_\\text{square}(\\mathbf{y}) = \\sum_{t=1}^{\\vert\\mathbf{y}\\vert} u_t(y_t)^2$ encourages all the tokens to have surprisal close to 0. An experiment with greedy regularizers showed that larger $\\lambda$ results in better performance (e.g. measured by BLEU for NMT task) and lower std dev of surprisal.\nFig. 2. The plot of BLEU and std. dev of surprisals as functions of the strength of the regularizer $\\lambda$. The subgraph in grey shows the relationship between BLEU and surprisal std. dev. (Image source: Meister et al. 2020) A default beam search would have text generation of decreased quality when beam size increases. Regularized beam search greatly helps alleviate this issue. A combined regularizer further improves the performance. In their experiments for NMT, they found $\\lambda=5$ for greedy and $\\lambda=2$ for squared work out as the optimal combined regularizer.\nFig. 3. The plot of BLEU of a function of beam size (left) and BLEU scores for translations created by different regularized decoding strategies. (Image source: Meister et al. 2020) Guided decoding essentially runs a more expensive beam search where the sampling probability distribution is altered by side information about human preferences.\nTrainable Decoding Given a trained language model, Gu et al (2017) proposed a trainable greedy decoding algorithm to maximize an arbitrary objective for sampling sequences. The idea is based on the noisy, parallel approximate decoding (NPAD). NPAD injects unstructured noise into the model hidden states and runs noisy decoding multiple times in parallel to avoid potential degradation. To take a step further, trainable greedy decoding replaces the unstructured noise with a learnable random variable, predicted by a RL agent that takes the previous hidden state, the previous decoded token and the context as input. In other words, the decoding algorithm learns a RL actor to manipulate the model hidden states for better outcomes.\nGrover et al. (2019) trained a binary classifier to distinguish samples from data distribution and samples from the generative model. This classifier is used to estimate importance weights for constructing a new unnormalized distribution. The proposed strategy is called likelihood-free importance weighting (LFIW).\nLet $p$ be the real data distribution and $p_\\theta$ be a learned generative model. A classical approach for evaluating the expectation of a given function $f$ under $p$ using samples from $p_\\theta$ is to use importance sampling.\n $$ \\mathbb{E}_{\\mathbf{x}\\sim p} [f(\\mathbf{x})] = \\mathbb{E}_{\\mathbf{x}\\sim p_\\theta} \\Big[\\frac{p(\\mathbf{x})}{p_\\theta(\\mathbf{x})} f(\\mathbf{x})\\Big] \\approx \\frac{1}{N} \\sum_{i=1}^N w(\\mathbf{x}_i)f(\\mathbf{x}_i) $$ However, $p(\\mathbf{x})$ can only be estimated via finite datasets. Let $c_\\phi: \\mathcal{X} \\to [0,1]$ be a probabilistic binary classifier for predicting whether a sample $\\mathbf{x}$ is from the true data distribution ($y=1$). The joint distribution over $\\mathcal{X}\\times\\mathcal{Y}$ is denoted as $q(\\mathbf{x}, y)$.\n $$ q(\\mathbf{x}\\vert y) = \\begin{cases} p_\\theta(\\mathbf{x}) \u0026 \\text{ if }y=0\\text{; predicted to be generated data} \\\\ p(\\mathbf{x}) \u0026 \\text{ otherwise; from the true data distribution} \\end{cases} $$ Then if $c_\\phi$ is Bayes optimal, the importance weight can be estimated by:\n $$ w_\\phi(\\mathbf{x}) = \\frac{p(\\mathbf{x})}{p_\\theta(\\mathbf{x})} = \\frac{q(\\mathbf{x} \\vert y=1)}{q(\\mathbf{x} \\vert y=0)} = \\frac{q(y=0)}{q(y=1)} \\frac{q(y=1 \\vert \\mathbf{x})}{q(y=0 \\vert \\mathbf{x})} = \\gamma \\frac{c_\\phi(\\mathbf{x})}{1 - c_\\phi(\\mathbf{x})} $$ where $\\gamma = \\frac{q(y=0)}{q(y=1)} \u0026gt; 0$ is a fixed odd ratio.\nSince we cannot learn a perfect optimal classifier, the importance weight would be an estimation $\\hat{w}_\\phi$. A couple of practical tricks can be applied to offset cases when the classifier exploits artifacts in the generated samples to make very confident predictions (i.e. very small importance weights):\n Self-normalization: normalize the weight by the sum $\\hat{w}_\\phi(\\mathbf{x}_i) / \\sum_{j=1}^N \\hat{w}_\\phi(\\mathbf{x}_j)$. Flattening: add a power scaling parameter $\\alpha \u0026gt; 0$, $\\hat{w}_\\phi(\\mathbf{x}_i)^\\alpha$. Clipping: specify a lower bound $\\max(\\hat{w}_\\phi(\\mathbf{x}_i), \\beta)$. To sample from an importance resampled generative model, $\\mathbf{x}\\sim p_{\\theta, \\phi}(\\mathbf{x}) \\propto p_\\theta(\\mathbf{x})\\hat{w}_\\phi(\\mathbf{x})$, they adopt SIR (Sampling-Importance-Resampling),\nFig. 4. The algorithm for sampling from a generative model according to importance weights $\\hat{w}(\\mathbf{x}\\_i)$ using SIR. (Image source: Grover et al., 2019)) Deng et al., 2020 proposed to learn a EBM to steer a LM in the residual space, $P_\\theta(x) \\propto P_\\text{LM}(x)\\exp(-E_\\theta(x))$, where $P_\\theta$ is the joint model; $E_\\theta$ is the residual energy function to be learned. If we know the partition function $Z$, we can model the generative model for generative a sequence $x_{p+1}, \\dots, x_T$ as:\n $$ P_\\theta(x_{p+1:T}\\vert x_{1:p}) = \\frac{P_\\text{LM}(x_{p+1:T}\\vert x_{1:p}) \\exp(-E_\\theta(x_{1:T}))}{Z_\\theta(x_{1:p})} $$ The goal is to learn the parameters of the energy function $E_\\theta$ such that the joint model $P_\\theta$ gets closer to the desired data distribution. The residual energy function is trained by noise contrastive estimation (NCE), considering $P_\\theta$ as the model distribution and $P_\\text{LM}$ as the noise distribution:\n $$ \\theta = \\arg\\max_{\\theta} \\mathbb{E}_{x^+ \\sim P_\\text{data}} \\log\\frac{1}{1+\\exp(E_\\theta(x^+))} + \\mathbb{E}_{x^- \\sim P_\\text{LM}} \\log\\frac{1}{1+\\exp(-E_\\theta(x^-))} $$ However, the partition function is intractable in practice. The paper proposed a simple way to first sample from the original LM and then to resample from them according to the energy function. This is unfortunately quite expensive.\nFig. 5. Top k samples from the base LM are resampled according to the residual energy function. (Image source: Deng et al., 2020) Smart Prompt Design Large language models have been shown to be very powerful on many NLP tasks, even with only prompting and no task-specific fine-tuning (GPT2, GPT3. The prompt design has a big impact on the performance on downstream tasks and often requires time-consuming manual crafting. For example, factual questions can gain a big boost with smart prompt design in \u0026ldquo;closed-book exam\u0026rdquo; (Shin et al., 2020, Jiang et al., 2020)). I’m expecting to see an increasing amount of literature on automatic smart prompt design.\nGradient-based Search AutoPrompt (Shin et al., 2020; code) is a method to automatically create prompts for various tasks via gradient-based search. AutoPrompt constructs a prompt by combining the original task inputs $x$ with a collection of trigger tokens $x_\\text{trig}$ according to a template $\\lambda$. The trigger tokens are shared across all inputs and thus universally effective.\nFig. 6. The overview of AutoPrompt. The trigger tokens are retrieved to optimize for the target outputs across all inputs. (Image source: Shin et al., 2020) The universal trigger tokens are identified using a gradient-guided search strategy same as in Wallace et al., 2019. The universal setting means that the trigger tokens $x_\\text{trig}$ can optimize for the target output $\\tilde{y}$ for all inputs from a dataset:\n $$ x_\\text{trig} = \\arg\\min_{x’_\\text{trig}} \\mathbb{E}_{x\\sim\\mathcal{X}} [\\mathcal{L}(\\tilde{y}, f(x’_\\text{trig}; x))] $$ The search operates in the embedding space. The embedding of every trigger token $e_{\\text{trig}_i}$ is first initialized to some default value and then gets updated to minimize the first-order Taylor expansion of the task-specific loss around the current token embedding:\n $$ e^{(t+1)}_\\text{trig} = \\arg\\min_{e\\in\\mathcal{V}} [e - e^{(t)}_{\\text{trig}_i}]^\\top \\nabla_{e^{(t)}_{\\text{trig}_i}} \\mathcal{L} $$ where $\\mathcal{V}$ refers to the embedding matrix of all the tokens. $\\nabla_{e^{(t)}_{\\text{trig}_i}} \\mathcal{L}$ is the average gradient of the task loss over a batch at iteration $t$. We can brute-force the optimal $e$ by a $\\vert \\mathcal{V} \\vert d$-dimensional dot product, which is cheap and can be computed in parallel.\nFig. 7. We search for trigger tokens by updating their embeddings with the gradient of the task loss per batch. (Image source: Wallace et al., 2019) The above token replacement method can be augmented with beam search. When looking for the optimal token embedding $e$, we can pick top-$k$ candidates instead of a single one, searching from left to right and score each beam by $\\mathcal{L}$ on the current data batch.\nFig. 8. Example prompts discovered by AutoPrompt for different tasks. (Image source: Shin et al., 2020) Smart prompt design essentially produces efficient context that can lead to desired completion. Motivated by this observation, Li \u0026amp; Liang (2021) proposed Prefix-Tuning which assigns a small number of trainable parameters at the beginning of an input sequence (named \u0026ldquo;prefix\u0026rdquo;) to steer a LM, $[\\text{PREFIX}; x; y]$. Let $\\mathcal{P}_\\text{idx}$ be a set of prefix indices and $\\text{dim}(h_i)$ be the embedding size. The prefix parameters $P_\\theta$ has the dimension $\\vert\\mathcal{P}_\\text{idx}\\vert \\times \\text{dim}(h_i) $ and the hidden state takes the form:\n $$ h_i = \\begin{cases} P_\\theta[i,:], \u0026 \\text{if }i \\in \\mathcal{P}_\\text{idx}\\\\ \\text{LM}_\\phi(z_i, h_{Note that only $P_\\theta$ is trainable and the LM parameters $\\phi$ is frozen during training.\nFig. 9. Illustrations of fine-tuning versus prefix-tuning. (Image source: Li \u0026 Liang 2021) The prefix parameters do not tie to any embeddings associated with the real words and thus they are more expressive for steering the context. Direct optimizing $P_\\theta$ unfortunately results in poor performance. To reduce the difficulty associated with high dimensionality training, the matrix $P_\\theta$ is reparameterized by a smaller matrix $P'_\\theta \\in \\mathbb{R}^{\\vert\\mathcal{P}_\\text{idx}\\vert \\times c}$ and a large feed forward network $\\text{MLP}_\\theta \\in \\mathbb{R}^{c\\times \\text{dim}(h_i)}$.\nThe performance increases with the prefix length $\\vert\\mathcal{P}_\\text{idx}\\vert$ up to some value. And this value varies with tasks.\nFig. 10. Task performance, summarization (left) and table-to-text (right), as a function of prefix length. (Image source: Li \u0026 Liang 2021) A few other interesting learnings from their ablation studies include:\n Tuning only the embedding layer (without prefix) is not sufficiently expressive. Placing the trainable parameter between $x$ and $y$, $[x; \\text{INFIX}; y]$, slightly underperforms prefix-tuning, likely because it only affects the context for $y$ while prefix affects both. Random initialization of $P_\\theta$ leads to low performance with high variance. In contrast, initializing $P_\\theta$ with activations of real words improves generation, even the words are irrelevant to the task. Fine-tuned models achieve better task performance but they can fail in the low data regime. Both AutoPrompt and Prefix-Tuning were found to outperform fine-tuning in the regime where the training dataset is small (i.e. $10^2-10^3$ samples). As an alternative to fine-tuning, prompt design or learning the context embedding is much cheaper. AutoPrompt improves the accuracy for sentiment classification a lot more than manual prompts and achieves similar performance as linear probing. For the NLI task, AutoPrompt obtains higher accuracy than linear probing. It is able to retrieve facts more accurately than manual prompts too. In low data regime, Prefix-Tuning achieves performance comparable with fine-tuning on table-to-text generation and summarization.\nTwo successive works, P-tuning (Liu et al. 2021; code) and Prompt Tuning (Lester et al. 2021), follow the similar idea of explicit training continuous prompt embeddings but with a few different choices over the trainable parameters and architecture. Different from Prefix-Tuning which concatenates continuous prompt tokens in every hidden state layer of the transformer, both P-tuning and Prompt Tuning non-invasively add continuous prompts only in the input to work well.\nLet $[P_i]$ be the $i$-th token in the prompt template of P-tuning (Liu et al. 2021), we can denote a prompt as a sequence $T=\\{[P_{0:i}], \\mathbf{x}, [P_{i+1:m}], \\mathbf{y}\\}$. Each token $[P_i]$ does not have to be a real token in the model vocabulary (\u0026ldquo;pseudo-token\u0026rdquo;), and thus the encoded template $T^e$ looks like the following and the pseudo-token hidden state can be optimized with gradient descent.\n $$ T^e = \\{ h_0, \\dots, h_i, \\text{embed}(\\mathbf{x}), h_{i+1}, \\dots, h_m, \\text{embed}(\\mathbf{y})\\} $$ Fig. 11. The illustration of P-tuning. Sometimes, adding a few task-related anchor tokens, such as “capital” in the figure, can bring further improvement. (Image source: Liu et al. 2021) There are two major optimization challenges in P-tuning:\n Discreteness: The word embedding of a pretrained language model are highly discrete. It is hard to optimize $h_i$ if they are intialized at random. Association: $h_i$ should be dependent on each other. Thus they develop a mechanism to model this dependency by training a light-weighted LSTM-based prompt encoder: $$ h_i = \\text{MLP}([\\text{LSTM}(h_{0:i}): \\text{LSTM}(h_{i:m})]) $$ P-tuning is more flexible than prefix-tuning, as it inserts trainable tokens in the middle of a prompt not just at the beginning. The usage of task-specific anchor tokens is like combining manual prompt engineering with trainable prompts.\nPrompt Tuning (Lester et al. 2021) largely simplifies the idea of prefix tuning by only allowing an additional $k$ tunable tokens per downstream task to be prepended to the input text. The conditional generation is $p_{\\theta, \\theta_P}(Y \\vert [P; X])$, where $P$ is the \u0026ldquo;pseudo prompt\u0026rdquo; with parameters $\\theta_P$ trainable via back-propagation. Both $X$ and $P$ are embedding vectors and we have $X \\in \\mathbb{R}^{n \\times d^e}, P \\in \\mathbb{R}^{k \\times d^e}$ and $[P;X] \\in \\mathbb{R}^{(n+k) \\times d^e}$, where $d^e$ is the embedding space dimensionality.\n Prompt tuning produces competitive results as model fine-tuning when the model gets large (billions of parameters and up). This result is especially interesting given that large models are expensive to fine-tune and execute at inference time. With learned task-specific parameters, prompt tuning achieves better transfer learning when adapting to new domains. It outperforms fine-tuning on domain shift problems. They also showed that prompt ensembling of multiple prompts for the same task introduces further improvement. Fig. 12. The illustration of how Prompt Tuning works. (Image source: Lester et al. 2021) The experiments investigated several prompt initialization schemes:\n Random initialization by uniformly sampling from [-0.5, 0.5]; Sample embeddings of top 5000 common tokens; Use the embedding values of the class label strings. If we don\u0026rsquo;t have enough class labels to initialize the soft-prompt, we fall back to scheme 2. Random initialization performs noticeably worse than the other two options. Fig. 13. The effect of (a) different prompt initialization schemes and (b) different prompt lengths. (Image source: Lester et al. 2021) The pre-training objectives also have a big impact on the quality of prompt tuning. T5’s “span corruption” is not a good option here.\nPrompt tuning is found to be less likely to overfit to a specific dataset. To evaluate the robustness to data shifting problem, they trained the model on one dataset of one task and evaluated it on the test dataset but in a different domain. Prompt tuning is more resilient and can generalize to different domains better.\nFig. 14. Prompt tuning is more resilient to domain shift between train and test sets. (Image source: Lester et al. 2021) Heuristic-based Search Paraphrasing is a quick way to explore more prompts similar to the known version, which can be done via back-translation. Using back-translation, the initial prompt is translated into $B$ candidates in another language and then each is translated back into $B$ candidates in the original language. The resulting total $B^2$ candidates are scored and ranked by their round-trip probabilities.\nRibeiro et al (2018) identified semantically equivalent adversaries (SEA) by generating a variety of paraphrases $\\{x'\\}$ of input $x$ until it triggers a different prediction of target function $f$:\n $$ \\begin{aligned} SEA(x, x') \u0026= \\mathbb{1}[\\text{SemEq}(x, x') \\land f(x) \\neq f(x')] \\\\ \\text{where SemEq}(x, x') \u0026= \\mathbb{1}[\\min\\Big(1, \\frac{p(x'\\vert x)}{p(x\\vert x)} \\Big) \\geq \\tau] \\end{aligned} $$ The rules extracted from SEA are considered as \u0026ldquo;bugs\u0026rdquo; in the model. Applying those rules as data augmentation in model training helps robustify the model and fix bugs.\nJiang et al (2020) attempts to validate whether a trained language model knows certain knowledge by automatically discovering better prompts to query. Within the scope of knowledge retrieval where factual knowledge is represented in the form of a triple $\\langle x, r, y \\rangle$ (subject, relation, object). The prompts can be mined from training sentences (e.g. Wikipedia description) or expanded by paraphrase.\nInterestingly some small modifications in the prompts may lead to big gain, as shown in Fig. X.\nFig. 15. Small modifications in prompt templates can lead to big performance gains: replacement in blue, insertion in green, deletion in red. (Image source: Jiang et al., 2020) Fine-tuning Fine-tuning is an intuitive way to guide a LM to output desired content, commonly by training on supervised datasets or by RL. We can fine-tune all the weights in the model or restrict the fine-tuning to only top or additional layers.\nConditional Training Conditional training aims to learn a generative model conditioned on a control variable $z$, $p(y \\vert x, z)$.\nFan et al (2018) trained a conditional language model for 2-step story generation. First, a model outputs the story sketch and then a story writing model creates a story following that sketch. The mechanism of conditioning on the sketch is implemented by a fusion model architecture. The fusion model enforces a form of residual learning that allows the story writing model to focus on learning what the first sketch generation model is missing. Also for story generation, Peng et al (2018) experimented with an ending valence-conditioned story generator LM, $p(x_t \\vert x_{\u0026lt;t}, z)$ where $z$ is the label of the story ending (sad, happy or neutral). Their language model is a bidirectional LSTM and the label is mapped into a learned embedding which then blends into the LSTM cell.\nCTRL (Keskar et al., 2019; code) aims to train a language model conditioned control code $z$ using controllable datasets. CTRL learns the conditioned distribution $p(x \\vert z)$ by training on raw text sequences with control code prefixes, such as [horror], [legal], etc. Then the learned model is able to generate text with respect to the prompt prefix. The training data contains Wikipedia, OpenWebText, books, Amazon reviews, reddit corpus and many more, where each dataset is assigned with a control code and subreddit in the reddit corpus has its own topic as control code.\nFig. 16. Datasets used for training CTRL and associated control codes. (Image source: Edited from Table 7 in Keskar et al., 2019) The control code also can be used for domain annotation given tokens, because $p(z \\vert x) \\propto p(x \\vert z) p(z)$, assuming the prior over domains is uniform. One limitation of CTRL is the lack of control for what not to generate (e.g. avoid toxicity).\nFig. 17. The examples of conditioned sample generation by CTRL. (Image source: Keskar et al., 2019) Note that CTRL trains a transformer model from scratch. However, labelling all the text within the same dataset with the same control code (e.g. All the wikipedia articles have \u0026ldquo;wikipedia\u0026rdquo; as control code) feels quite constrained. Considering that often we need highly customized control codes but only have a limited amount of labelled data, I would expect fine-tuning an unconditional LM with a small labelled dataset in the same way as CTRL to work out well too. Although how much data is needed and how good the sample quality might be are subject to experimentation.\nRL Fine-tuning Fine-tuning a sequential model with RL regarding any arbitrary and possibly non-differentiable reward function has been proved to work well years ago (Ranzato et al., 2015). RL fine-tuning can resolve several problems with teacher forcing method. With teacher forcing, the model only minimizes a maximum-likelihood loss at each individual decoding step during training but it is asked to predict the entire sequence from scratch at test time. Such a discrepancy between train and test could lead to exposure bias and accumulated error. In contrast, RL fine-tuning is able to directly optimize task-specific metrics on the sequence level, such as BLEU for translation (Ranzato et al., 2015, Wu et al., 2016, Nguyen et al., 2017), ROUGE for summarization (Ranzato et al., 2015, Paulus et al., 2017, Wu and Hu, 2018) and customized metric for story generation (Tambwekar et al., 2018).\nRanzato et al (2015) applied REINFORCE to train RNN models for sequence generation tasks. The model is first trained to predict the next token using cross-entropy loss (ML loss) and then fine-tuned alternatively by both ML loss and REINFORCE (RL loss). At the second fine-tuning stage, the number of training steps for next-token prediction is gradually decreasing until none and eventually only RL loss is used. This sequence-level RL fine-tuning was shown by experiments to lead to great improvements over several supervised learning baselines back then.\nGoogle implemented the similar approach in their neural machine translation system (Wu et al., 2016) and Paulus et al (2017) adopted such approach for summarization task. The training objective contains two parts, ML loss for next token prediction, $\\mathcal{L}_\\text{ML} = \\sum_{(x, y^*)\\sim\\mathcal{D}} \\log p_\\theta(y^* \\vert x)$, and RL loss $\\mathcal{L}_\\text{RL}$ for maximizing the expected reward where the reward per sequence is measured by BLEU or ROUGE. The model is first trained with $\\mathcal{L}_\\text{ML}$ until convergence and then fine-tuned with a linear combination of two losses, $\\mathcal{L}_\\text{mix} = \\alpha \\mathcal{L}_\\text{ML} + (1 - \\alpha)\\mathcal{L}_\\text{RL}$.\nThe RL loss of Google NMT is to maximize the expected BLEU score:\n $$ \\mathcal{L}_\\text{RL} = - \\sum_{(x, y^*)\\sim\\mathcal{D}} \\mathbb{E}_{y\\sim p_\\theta(.\\vert x)} [R(y, y^*)] $$ where $y$ is the predicted sequence and $y^*$ is the ground truth.\nPaulus et al (2017) added an extra weighting term based on the reward difference between two output sequences, $y$ by sampling the next token according to the predicted probability and $\\hat{y}$ by greedily taking the most likely token. This RL loss maximizes the conditional likelihood of the sampled sequence $y$ if it obtains a higher reward than the greedy baseline $\\hat{y}$:\n $$ \\mathcal{L}_\\text{RL} = \\sum_{(x, y^*)\\sim\\mathcal{D}} (R(\\hat{y}, y^*) - R(y, y^*)) \\sum_{t=1}^{n'} \\log p(y_t \\vert y_{RL Fine-tuning with Human Preferences Reward learning is critical for defining human preferences. Quantitative measurement like BLEU or ROUGE computes the overlap of words and n-gram phrases between sequences and does not always correlate with better quality by human judges. Reward learning from human feedback (Christiano et al., 2017) is a better way to align what we measure with what we actually care about. Human feedback has been applied to learn a reward function for applications like story generation (Yi et al., 2019) and summarization (Böhm et al., 2019, Ziegler et al., 2019, Stiennon et al., 2020).\nIn order to generate more coherent conversation, Yi et al (2019) collected 4 types of binary human feedback given a conversation pair (user utterance, system response), whether the system response is (1) comprehensive, (2) on topic, (3) interesting and (4) leading to continuation of the conversation. An evaluator is trained to predict human feedback and then is used to rerank the beam search samples, to finetune the model or to do both. (Actually they didn’t use RL fine-tuning but rather use the evaluator to provide a discriminator loss in supervised fine-tuning.)\nLet\u0026rsquo;s define a learned reward function $R_\\psi(x, y)$ parameterized by $\\psi$ as a measurement for the quality of output $y$ given the input $x$.\nTo learn the ground truth reward $R^*$ defined by human judgements, Böhm et al (2019) compared two loss functions:\n(1) Regression loss: simply minimizing the mean squared error.\n $$ \\mathcal{L}^\\text{MSE}_\\text{rm} = [R^*(x, y) - R_\\psi(x, y)]^2 $$ (2) Preference loss: learning to agree with the ground truth reward,\n $$ \\begin{aligned} \\mathcal{L}^\\text{pref}_\\text{rm} =\u0026 - \\sum_{i,j} \\big(\\mathbb{1}[R^*(x, y_i) R^*(x, y_j)] \\log P(y_i \\succ y_j) + \\\\ \u0026\\mathbb{1}[R^*(x, y_j) R^*(x, y_i)] \\log P(y_j \\succ y_i) \\big)\\\\ \\text{where }P(y_i \\succ y_j) =\u0026 \\frac{\\exp(R_\\psi(x, y_i))}{\\exp(R_\\psi(x, y_i)) + \\exp(R_\\psi(x, y_j))} \\end{aligned} $$ Their experiments showed that the preference loss achieves the best performance, where the reward model is a thin MLP layer on top of BERT sentence embedding.\nZiegler et al (2019) collected human labels by asking humans to select the best candidate $y_b$ out of a few options $\\{y_i\\}$ given the input $x \\sim \\mathcal{D}$. The candidates are sampled by $y_0, y_1 \\sim p(.\\vert x), y_2, y_3 \\sim \\pi(.\\vert x)$. We should be aware that human labeling might have very high disagreement when the ground truth is fuzzy.\nFig. 18. The overview of the training framework for fine-tuning a language model policy with reward learned from human feedback. (Image source: Ziegler et al., 2019) The reward model is implemented by a pretrained language model with an extra random linear layer of the final embedding output. It it trained to minimize the loss:\n $$ \\mathcal{L}_\\text{rm} = -\\mathbb{E}_{(x, \\{y_i\\}, b) \\sim \\mathcal{D}} \\Big[ \\log \\frac{\\exp(R_\\psi(x, y_b))}{\\sum_i \\exp(R_\\psi(x, y_i))} \\Big] $$ To keep the scale consistent during training, the reward model is normalized to have mean 0 and variance 1.\nDuring RL fine-tuning, the policy $\\pi$, initialized by a pretrained language model $p$, is optimized via PPO with the above learned reward model. To avoid the policy\u0026rsquo;s deviating from its original behavior too much, a KL penalty is added:\n $$ R(x, y) = R_\\psi(x, y) - \\beta\\log\\frac{\\pi(y \\vert x)}{p(y \\vert x)} $$ If running online data collection, human label collection process is continued during RL fine-tuning and thus the human labelers can review results generated by the latest policy. The number of human labels are evenly spread out during the training process. Meanwhile the reward model is also retrained periodically. Online data collection turns out to be important for the summarization task but not for the text continuation task. In their experiments, jointly training the reward model and the policy with shared parameters did not work well and can lead to overfitting due to the big imbalance between dataset sizes.\nIn the following work (Stiennon et al., 2020), the human label collection was further simplified to select the best option between a pair of summaries, $y_b \\in\\{y_0, y_1\\}$ The reward model loss was updated to optimize the log odds of the selected summary:\n $$ \\mathcal{L}_\\text{rm} = \\mathbb{E}_{(x, y_0, y_1, b)\\sim\\mathcal{D}} [\\log(\\sigma(r_\\theta(x, y_b) − r_\\theta(x, y_{1−b})))] $$ Fig. 19. The overview of fine-tuning the language model policy from human feedback for summarization, including (1) human feedback collection, (2) reward model training, and (3) policy training. (Image source: Stiennon et al., 2020) Guided Fine-tuning with Steerable Layer Instead of fine-tuning the entire model, only fine-tuning a small extra set of parameters while the base model stays fixed is computationally cheaper.\nIn computer vision, plug-and-play generative networks (PPGN; Nguyen et al., 2017) generate images with different attributes by plugging a discriminator $p(a \\vert x)$ into a base generative model $p(x)$. Then the sample with a desired attribute $a$ can be sampled from $p(x \\vert a) \\propto p(a \\vert x)p(x)$. Inspired by PPGN, the plug-and-play language model (PPLM; Dathathri et al., 2019) combines one or multiple simple attribute models with a pretrained language model for controllable text generation.\nGiven an attribute $a$ and the generated sample $x$, let an attribute model be $p(a\\vert x)$. To control content generation, the current latent representation at time $t$, $H_t$ (containing a list of key-value pairs per layer), can be shifted by $\\Delta H_t$ in the direction of the sum of two gradients:\n One toward higher log-likelihood of the attribute $a$ under $p(a \\vert x)$ \u0026mdash; so that the output content acquires a desired attribute. The other toward higher log-likelihood of the unmodified language model $p(x)$ \u0026mdash; so that the generated text is still in fluent and smooth natural language. To shift the output, at decoding time, PPLM runs one forward → one backward → one forward, three passes in total:\n First a forward pass is performed to compute the likelihood of attribute $a$ by $p(a\\vert x)$; Let $\\Delta H_t$ be a stepwise update to the hidden state $H_t$ such that $(H_t + \\Delta H_t)$ shifts the distribution of generated text closer to having the attribute $a$. $\\Delta H_t$ is initialized at zero. Then a backward pass updates the LM hidden states using normalized gradients from the attribute model $\\nabla_{\\Delta H_t} \\log p(a \\vert H_t + \\Delta H_t)$ as $$ \\Delta H_t \\leftarrow \\Delta H_t + \\alpha \\frac{\\nabla_{\\Delta H_t} \\log p(a|H_t + \\Delta H_t)}{\\| \\nabla_{\\Delta H_t} \\log p(a|H_t + \\Delta H_t) \\|^\\gamma} $$ where $\\gamma$ is a normalization scaling coefficient, set per layer. $\\alpha$ is step size. This update can be repeated $m \\in [3, 10]$ times 3. The final forward pass recomputes a new distribution over the vocabulary, generated from the updated latents $\\tilde{H}_t = H_t + \\Delta H_t$. The next token is sampled from the updated distribution.\nFig. 20. The overview of how PPLM runs three passes to update the model output to increase the likelihood of a desired attribute. (Image source: Dathathri et al., 2019) Multiple attribute models can be mix-and-matched during generation with customized weights, acting as a set of \u0026ldquo;control knobs\u0026rdquo;. The PPLM paper explored two types of attribute models:\n The simplest attribution model is based on a predefined bag of words (BoW), $\\{w_1, \\dots, w_k\\}$, that specifies a topic of interest. $$ \\log p(a \\vert x) = \\log\\big( \\sum_{i=1}^k p_{t+1} [w_i] \\big) $$ To encourage the model to output the desired words at least once but not at every step, they normalize the gradient by the maximum gradient norm. Interestingly, they found that increasing the probability of generating words in the bag also increases the probability of generating related but not identical words about the same topic. 2. The discriminator attribute models are based on learned classifiers which define preferences by a distribution instead of hard samples.\nTo ensure the fluency in language, PPLM applied two additional designs:\n Minimizing the KL diverge between modified and unmodified LM, commonly seen in other RL fine-tuning approaches (see above). It performs post-norm fusion to constantly tie the generated text to the unconditional LM $p(x)$, $x_{t+1} \\sim \\frac{1}{\\beta}(\\tilde{p}_{t+1}^{\\gamma_\\text{gm}} p_{t+1}^{1-\\gamma_\\text{gm}})$, where $p_{t+1}$ and $\\tilde{p}_{t+1}$ are the unmodified and modified output distributions, respectively. $\\beta$ is a normalizing factor. $\\gamma_\\text{gm} \\in [0.8, 0.95]$ balances between prediction from before and after models. Fig. 21. Examples of controllable text generation by PPLM. (Image source: Dathathri et al., 2019) Interestingly, they found a large variance in the extent of controllability across topics. Some topics (religion, science, politics) are easier to control for compared to others (computers, space).\nOne obvious drawback of PPLM is that due to multiple passes at every decoding step, the test time computation becomes much more expensive.\nSimilar to PPLM, DELOREAN (DEcoding for nonmonotonic LOgical REAsoNing; Qin et al., 2020) incorporates the future context by back-propagation. Given input text $\\mathbf{x}$, DELOREAN aims to generate continuation completion $\\mathbf{y} = [y_1, \\dots, y_N]$ such that $y$ satisfies certain constraints defined by a context $z$. To keep the generation differentiable, a soft representation of $y$ is tracked, $\\tilde{\\mathbf{y}}=(\\tilde{y}_1, \\dots, \\tilde{y}_N)$ where $\\tilde{y}_i \\in \\mathbb{R}^V$ are logits over the vocabulary. $\\tilde{\\mathbf{y}}^{(t)}$ is the soft representation at iteration $t$.\nGiven the representation $\\tilde{y}^{(t-1)}$ at iteration $t$, it runs the following procedures:\n Backward: The constraint is represented as a loss function $\\mathcal{L}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t-1)}, z))$. The logits are updated via gradient descent: $\\tilde{y}^{(t), b}_n = \\tilde{y}_n^{(t-1)} - \\lambda \\nabla_{\\tilde{y}_n} \\mathcal{L}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t-1)}, z)$. Forward: Run forward pass to ensure the generated text is fluent. $\\tilde{y}^{(t),f}_n = \\text{LM}(\\mathbf{x}, \\tilde{\\mathbf{y}}^{(t)}_{1:n-1})$. Then linearly combine two logits together to create a new representation $\\tilde{y}^{(t)}_n = \\gamma \\tilde{y}^{(t), f}_n + (1-\\gamma) \\tilde{y}^{(t), b}_n$. Note that each $\\tilde{y}^{(t)}_n$ is needed to sample the next $\\tilde{y}^{(t),f}_{n+1}$. Side-tuning (Zhang et al., 2019) trains a light-weighted side network that learns a residual on top of the original model outputs without modifying the pre-trained model weights. Unlike PPLM, no gradient update is applied on the hidden states. It is a simple yet effective approach for incremental learning. The base model is treated as a black-box model and does not necessarily have to be a neural network. Side-tuning setup assumes the base and side models are fed exactly the same input and the side model is independently learned.\nFig. 22. Comparison of fixed weights, fine-tuning and side-tuning. (Image source: Zhang et al., 2019) The paper explored different strategies of fusing predictions from the base and side models: product is the worst while sum ($\\alpha$-blending), MLP, and FiLM are comparable. Side-tuning is able to achieve better performance, when it is trained with intermediate amounts of data and when the base network is large.\nAuxiliary tuning (Zeldes et al., 2020) supplements the original pre-trained model with an auxiliary model that shifts the output distribution according to the target task. The base and auxiliary model outputs are merged on the logits level. The combined model is trained to maximize the likelihood $p(x_t\\vert x_{\u0026lt;t}, z)$ of target output.\nThe conditional probability of $p(x_t\\vert x_{\u0026lt;t}, z)$ can be decomposed into two parts:\n $p(x_t\\vert x_{\u0026lt;t})$ assigns high probabilities to fluent sequences of tokens; a shift on $p(x_t\\vert x_{\u0026lt;t})$ towards $p(x_t\\vert x_{\u0026lt;t}, z)$. $$ p(x_t\\vert x_{By Bayesian rule, we have\n $$ p(x_t\\vert x_{And therefore the auxiliary model $\\text{logits}_\\text{aux}(x_t \\vert x_{\u0026lt;t}, z))$ effectively should learn to predict $p(z \\vert x_{\\leq t})$. In the experiments of Zeldes et al., 2020, the auxiliary model can re-use the intermediate layers of the pre-trained LM for feature extraction.\nFig. 23. The auxiliary model is trained by reusing features extracted from multiple layers of the base model. (Image source: Zeldes et al., 2020) GeDi (Kruse et al., 2020) guides the text generation by Generative Discriminator. The discriminator is implemented as a class conditional language model (CC-LM), $p_\\theta(x_{1:t} \\vert z)$. The discriminator guides generation at each decoding step by computing classification probabilities for all possible next tokens via Bayes rule by normalizing over two contrastive class-conditional distributions:\n One conditioned on the control code $z$ for desired attribute. The other conditioned on the anti-control code $\\bar{z}$ for undesired attributes. GeDi relies on the contract between $p_\\theta(x_{1:t} \\vert z)$ and $p_\\theta(x_{1:t} \\vert \\bar{z})$ to compute the probability of the sequence belonging to the desired class. The discriminator loss is to maximize the probability of desired attribute $z$:\n $$ \\begin{aligned} p_\\theta(z \\vert x_{1:t}) \u0026= \\frac{p(z) p_\\theta(x_{1:\\tau} \\vert z)^{\\alpha/\\tau}}{\\sum_{z' \\in \\{z, \\bar{z}\\}} p(z') p_\\theta(x_{1:\\tau} \\vert z')^{\\alpha/\\tau} } \\\\ \\mathcal{L}_\\text{desc} \u0026= -\\frac{1}{N} \\sum_{i=1}^N \\log p_\\theta(z^{(i)} \\vert x^{(i)}_{1:\\tau_i}) \\\\ \u0026= -\\frac{1}{N} \\sum_{i=1}^N \\log \\frac{p(z) p_\\theta(x^{(i)}_{1:\\tau_i} \\vert z^{(i)})^{\\alpha/t_i}}{\\sum_{z' \\in \\{z, \\bar{z}\\} } p(z')p_\\theta(x^{(i)}_{1:\\tau_i} \\vert z')^{\\alpha/\\tau_i}} \\end{aligned} $$ where $p(z) = \\exp(b_z) / \\sum_{z'} \\exp(b_{z'})$ and $b_z$ is a learned class prior. The probabilities are normalized by the current sequence length $\\tau$ to robustify generation sequences of variable lengths. $\\tau_i$ is the sequence length of the $i$-th input $x^{(i)}$ in the dataset.\nFig. 24. An illustration of how GeDi works via Bayesian rule. (Image source: Kruse et al., 2020) They finetuned a GPT2-medium model with control code similar to how CTRL is trained to form a CC-LM using a linear combination of discriminative loss and generative loss. This discriminator model is then used as GiDe to guide generation by a larger language model like GPT2-XL.\nOne way of decoding from GeDi is to sample from a weighted posterior $p^w(x_{t+1}\\vert x_{1:t}, z) \\propto p(z \\vert x_{1:t+1})^w p(x_{t+1} \\vert x_{1:t})$ where $w\u0026gt;1$ applies additional bias toward the desired class $z$. In the sampling process, only tokens with the class or next-token probability larger than a certain threshold are selected.\nGeDi guided generation in their experiments showed strong controllability and ran 30x faster than PPLM.\nDistributional Approach Generation with Distributional Control (GDC; Khalifa, et al. 2020) frames controlled text generation as the optimization of a probability distribution with a constraint. It involves two major steps.\nStep 1: Learn a EBM of the target model\nLet\u0026rsquo;s label a pretrained LM as $a$ and a target LM with desired features as $p$. The desired features can be defined by a set of pre-defined real-valued feature functions $\\phi_i(x), i=1,\\dots,k$ over $x \\in X$, denoted as a vector $\\boldsymbol{\\phi}$. When sequences $x \\in X$ are sampled according to the desired model $p$, the expectations of features $\\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x)$ should be close to $\\bar{\\boldsymbol{\\mu}}$ , named \u0026ldquo;moment constraints\u0026rdquo;. The feature function $\\phi_i$ can have distinct values (e.g. identity function for binary classifier) or continuous probabilities. In the meantime, the fine-tuned model $p$ should not diverge from $a$ too much by maintaining a small KL divergence measure.\nIn summary, given a pretrained model $a$, we would like to find a target model $p$ such that:\n $$ \\begin{aligned} \\bar{\\boldsymbol{\\mu}} \u0026= \\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x) \\\\ p \u0026= \\arg\\min_{c \\in \\mathcal{C}} D_\\text{KL}(c, a) \\end{aligned} $$ where $\\mathcal{C}$ is the set of all distributions over $X$ that satisfy the moment constraints.\nAccording to theorems in Information Geometry, $p$ can be approximated by an EBM (energy-based model; an unnormalized probability distribution) $P$ in the form of exponential function, such that $p(x) \\propto P(x)$ and $p(x)=\\frac{1}{Z}P(x)$ where $Z=\\sum_x P(x)$. The energy-based model can be approximated by:\n $$ P(x)=a(x)\\exp\\big(\\sum_i \\lambda_i \\phi_i(x)\\big)=a(x)\\exp(\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)) $$ Let\u0026rsquo;s define importance weight $w(x, \\boldsymbol{\\lambda}) = \\frac{P(x)}{a(x)} = \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle$. Given a large number of sequences sampled from the pretrained model $x_1, \\dots, x_N \\sim a(x)$,\n $$ \\begin{aligned} \\mu(\\boldsymbol{\\lambda}) \u0026= \\mathbb{E}_{x\\sim p}\\boldsymbol{\\phi}(x) = \\mathbb{E}_{x\\sim a} \\frac{p(x)}{a(x)}\\boldsymbol{\\phi}(x) = \\frac{1}{Z}\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x) \\\\ \u0026= \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\sum_{x\\in X} P(x)} = \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\sum_{x\\in X} w(x, \\boldsymbol{\\lambda})a(x)} = \\frac{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x)}{\\mathbb{E}_{x\\sim a} w(x, \\boldsymbol{\\lambda})} \\\\ \u0026\\simeq \\frac{\\sum_{i=1}^N w(x_i,\\boldsymbol{\\lambda}) \\boldsymbol{\\phi}(x_i)}{\\sum_{i=1}^N w(x_i, \\boldsymbol{\\lambda})} = \\frac{\\sum_{i=1}^N \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle \\boldsymbol{\\phi}(x_i)}{\\sum_{i=1}^N \\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle} \\end{aligned} $$ Using SGD over the objective $|\\boldsymbol{\\mu}(\\boldsymbol{\\lambda}) - \\bar{\\boldsymbol{\\mu}}|^2_2$, we can obtain an estimated value for $\\boldsymbol{\\lambda}$ and a representation of $P(x)=a(x)\\exp\\langle\\boldsymbol{\\lambda}\\cdot\\boldsymbol{\\phi}(x)\\rangle$. $P(x)$ is a sequential EBM because $a$ is an autoregressive model.\nStep 2: Learn the target probability distribution\nThe EBM $P(x)$ can compute ratios of probabilities of two sequences, but cannot sample from $p(x)$ with knowing $Z$. In order to sample from a sequential EBM, the paper proposed to use Distributional Policy Gradient (DPG; but not this DPG) with the objective to obtain an autoregressive policy $\\pi_\\theta$ to approximate a target distribution $p$ by minimizing the cross entropy $H(p, \\pi_\\theta)$. DPG runs through a sequence of iterations. Within each iteration, the proposed distribution $q$ is used for sampling and we can correct the cross entropy loss with importance weights too:\n $$ \\begin{aligned} \\nabla_\\theta H(p, \\pi_\\theta) \u0026= - \\nabla_\\theta \\mathbb{E}_{x\\sim p} \\log \\pi_\\theta(x) = - \\mathbb{E}_{x\\sim p} \\nabla_\\theta \\log \\pi_\\theta(x) \\\\ \u0026= - \\mathbb{E}_{x\\sim q} \\frac{p(x)}{q(x)} \\nabla_\\theta \\log \\pi_\\theta(x) = - \\frac{1}{Z}\\mathbb{E}_{x\\sim q} \\frac{P(x)}{q(x)} \\nabla_\\theta \\log \\pi_\\theta(x) \\end{aligned} $$ To learn such a $\\pi_\\theta$, the paper adopts a KL-adaptive version of DPG: It only updates $q$ when the estimated policy $\\pi_\\theta$ gets closer to $p$. This adaptive step is important for fast convergence.\nFig. 25. The algorithm of distributional policy gradient to make it possible to sample from a EBM $P(x)$, where $q$ is initialized to be $a$. (Image source: Khalifa, et al. 2020) This approach can be used to model various constraints in controllable text generation:\n Pointwise constraints: $\\phi_i$ is a binary feature; such as constraining the presence or absence of words, or classifier-based constraints. Distributional constraints: $\\phi_i$ represents a probability distribution; such as constraining the probability of gender, topic, etc. Their experiments showed great progress in debiasing a GPT-2 model that was trained on Wikipedia Biographies corpus. The percentage of generated biographies on females increased from 7.4% to 35.6%. Hybrid constraints: combine multiple constraints by simply summing them up. Fig. 26. Debiasing experiments using GDC with various constraints. (Image source: Khalifa, et al. 2020) Compared to other baselines, GDC using pointwise constraints diverges less from the base model $a$ and produces smoother curves.\nFig. 27. Compare pointwise constrained GDC with several baselines. Low Self-BLEU-5 and high Dist-1 indicate high diversity. (Image source: Khalifa, et al. 2020) REINFORCE that optimizes the reward $\\phi$ directly ($\\text{REINFORCE}$ in Fig. X.) without constraints converges fast but has a high deviation from the original model. REINFORCE that optimizes $P(x)$ ($\\text{REINFORCE}_{P(x)}$ in Fig. X.) has low sample diversity. Compared to Ziegler et al., 2019 GDC has smoother learning curves and produces a richer vocabulary. Unlikelihood Training The standard way of maximizing the log-likelihood loss in language model training leads to incorrect token distribution, which cannot be fixed with only smart decoding methods. Such models tend to output high-frequency words too often and low-frequency words too rarely, especially when using deterministic decoding (e.g. greedy, beam search). In other words, they are overconfident in their predictions.\nUnlikelihood training (Welleck \u0026amp; Kulikov et al. 2019] tries to combat this and incorporates preference to unwanted content into the training objective directly. It combines two updates:\n A routine maximized likelihood update to assign true tokens with high probability; A new type of unlikelihood update to avoid unwanted tokens with high probability. Given a sequence of tokens $(x_1, \\dots, x_T)$ and a set of negative candidate tokens $\\mathcal{C}^t = \\{c_1, \\dots , c_m\\}$ at step $t$, where each token $x_i, c_j \\in \\mathcal{V}$, the combined loss for step $t$ is defined as:\n $$ \\mathcal{L}^t_\\text{UL}(p_\\theta (. \\vert x_{One approach for constructing $\\mathcal{C}^t$ is to randomly select candidates from model-generated sequences.\nThe unlikelihood training can be extended to be on the sequence-level, where the negative continuation is defined by a sequence of per-step negative candidate sets. They should be designed to penalize properties that we don\u0026rsquo;t like. For example, we can penalize repeating n-grams as follows:\n $$ \\mathcal{C}^t_\\text{repeat-n} = \\{x_t\\} \\text{ if }(x_{t-i}, \\dots, x_{t+j}) \\in x_{Their experiments used unlikelihood training to avoid repetitions in language model outputs and indeed showed better results on less repetition and more unique tokens compared to standard MLE training.\nCitation Cited as:\n Weng, Lilian. (Jan 2021). Controllable neural text generation. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/.\n Or\n@article{weng2021conditional, title = \u0026quot;Controllable Neural Text Generation.\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2021\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/\u0026quot; } References [1] Patrick von Platen. \u0026ldquo;How to generate text: using different decoding methods for language generation with Transformers\u0026rdquo; Hugging face blog, March 18, 2020.\n[2] Angela Fan, et al. \u0026ldquo;Hierarchical Neural Story Generation/\u0026quot; arXiv preprint arXiv:1805.04833 (2018).\n[3] Ari Holtzman et al. \u0026ldquo;The Curious Case of Neural Text Degeneration.\u0026quot; ICLR 2020.\n[4] Marjan Ghazvininejad et al. \u0026ldquo;Hafez: an interactive poetry generation system.\u0026quot; ACL 2017.\n[5] Ari Holtzman et al. \u0026ldquo;Learning to write with cooperative discriminators.\u0026quot; ACL 2018.\n[6] Ashutosh Baheti et al. \u0026ldquo;Generating More Interesting Responses in Neural Conversation Models with Distributional Constraints.\u0026quot; EMNLP 2018.\n[7] Jiatao Gu et al. \u0026ldquo;Trainable greedy decoding for neural machine translation.\u0026quot; EMNLP 2017.\n[8] Kyunghyun Cho. \u0026ldquo;Noisy Parallel Approximate Decoding for Conditional Recurrent Language Model.\u0026quot; arXiv preprint arXiv:1605.03835. (2016).\n[9] Marco Tulio Ribeiro et al. \u0026ldquo;Semantically equivalent adversarial rules for debugging NLP models.\u0026quot; ACL 2018.\n[10] Eric Wallace et al. \u0026ldquo;Universal Adversarial Triggers for Attacking and Analyzing NLP.\u0026quot; EMNLP 2019. [code]\n[11] Taylor Shin et al. \u0026ldquo;AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts.\u0026quot; EMNLP 2020. [code]\n[12] Zhengbao Jiang et al. \u0026ldquo;How Can We Know What Language Models Know?\u0026quot; TACL 2020.\n[13] Nanyun Peng et al. \u0026ldquo;Towards Controllable Story Generation.\u0026quot; NAACL 2018.\n[14] Nitish Shirish Keskar, et al. \u0026ldquo;CTRL: A Conditional Transformer Language Model for Controllable Generation\u0026rdquo; arXiv preprint arXiv:1909.05858 (2019).[code]\n[15] Marc’Aurelio Ranzato et al. \u0026ldquo;Sequence Level Training with Recurrent Neural Networks.\u0026quot; ICLR 2016.\n[16] Yonghui Wu et al. \u0026ldquo;Google\u0026rsquo;s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.\u0026quot; CoRR 2016.\n[17] Romain Paulus et al. \u0026ldquo;A Deep Reinforced Model for Abstractive Summarization.\u0026quot; ICLR 2018.\n[18] Paul Christiano et al. \u0026ldquo;Deep Reinforcement Learning from Human Preferences.\u0026quot; NIPS 2017.\n[19] Sanghyun Yi et al. \u0026ldquo;Towards coherent and engaging spoken dialog response generation using automatic conversation evaluators.\u0026quot; INLG 2019.\n[20] Florian Böhm et al. \u0026ldquo;Better rewards yield better summaries: Learning to summarise without references.\u0026quot; EMNLP 2019. [code]\n[21] Daniel M Ziegler et al. \u0026ldquo;Fine-tuning language models from human preferences.\u0026quot; arXiv preprint arXiv:1909.08593 (2019). [code]\n[22] Nisan Stiennon, et al. \u0026ldquo;Learning to summarize from human feedback.\u0026quot; arXiv preprint arXiv:2009.01325 (2020).\n[23] Sumanth Dathathri et al. \u0026ldquo;Plug and play language models: a simple approach to controlled text generation.\u0026quot; ICLR 2020. [code]\n[24] Jeffrey O Zhang et al. \u0026ldquo;Side-tuning: Network adaptation via additive side networks\u0026rdquo; ECCV 2020.\n[25] Ben Kruse et al. \u0026ldquo;GeDi: Generative Discriminator Guided Sequence Generation.\u0026quot; arXiv preprint arXiv:2009.06367.\n[26] Yoel Zeldes et al. \u0026ldquo;Technical Report: Auxiliary Tuning and its Application to Conditional Text Generatio.\u0026quot; arXiv preprint arXiv:2006.16823.\n[27] Thomas Scialom, et al. \u0026ldquo;Discriminative Adversarial Search for Abstractive Summarization\u0026rdquo; ICML 2020.\n[28] Clara Meister, et al. \u0026ldquo;If beam search is the answer, what was the question?\u0026quot; EMNLP 2020.\n[29] Xiang Lisa Li and Percy Liang. \u0026ldquo;Prefix-Tuning: Optimizing Continuous Prompts for Generation.\u0026quot; arXiv preprint arXiv:2101.00190 (2021).\n[30] Lianhui Qin, et al. \u0026ldquo;Back to the Future: Unsupervised Backprop-based Decoding for Counterfactual and Abductive Commonsense Reasoning.\u0026quot; arXiv preprint arXiv:2010.05906 (2020).\n[31] Muhammad Khalifa, et al. \u0026ldquo;A Distributional Approach to Controlled Text Generation\u0026rdquo; Accepted by ICLR 2021.\n[32] Aditya Grover, et al. \u0026ldquo;Bias correction of learned generative models using likelihood-free importance weighting.\u0026quot; NeuriPS 2019.\n[33] Yuntian Deng et al. \u0026ldquo;Residual Energy-Based Models for Text Generation.\u0026quot; ICLR 2020.\n[34] Brian Lester et al. “The Power of Scale for Parameter-Efficient Prompt Tuning.” arXiv preprint arXiv:2104.08691 (2021).\n[35] Xiao Liu et al. “GPT Understands, Too.” arXiv preprint arXiv:2103.10385 (2021).\n[36] Welleck \u0026amp; Kulikov et al. “Neural Text Generation with Unlikelihood Training” arXiv:1908.04319 (2019).\n","permalink":"https://lilianweng.github.io/posts/2021-01-02-controllable-text-generation/","summary":"[Updated on 2021-02-01: Updated to version 2.0 with several work added and many typos fixed.] [Updated on 2021-05-26: Add P-tuning and Prompt Tuning in the \u0026ldquo;prompt design\u0026rdquo; section.] [Updated on 2021-09-19: Add \u0026ldquo;unlikelihood training\u0026rdquo;.]\nThere is a gigantic amount of free text on the Web, several magnitude more than labelled benchmark datasets. The state-of-the-art language models (LM) are trained with unsupervised Web data in large scale. When generating samples from LM by iteratively sampling the next token, we do not have much control over attributes of the output text, such as the topic, the style, the sentiment, etc.","title":"Controllable Neural Text Generation"},{"content":"[Updated on 2020-11-12: add an example on closed-book factual QA using OpenAI API (beta).\nA model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as working as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.\nDisclaimers given so many papers in the wild:\n Assume we have access to a powerful pretrained language model. We do not cover how to use structured knowledge base (e.g. Freebase, WikiData) here. We only focus on a single-turn QA instead of a multi-turn conversation style QA. We mostly focus on QA models that contain neural networks, specially Transformer-based language models. I admit that I missed a lot of papers with architectures designed specifically for QA tasks between 2017-2019😔 What is Open-Domain Question Answering? Open-domain Question Answering (ODQA) is a type of language tasks, asking a model to produce answers to factoid questions in natural language. The true answer is objective, so it is simple to evaluate model performance.\nFor example,\nQuestion: What did Albert Einstein win the Nobel Prize for? Answer: The law of the photoelectric effect. The \u0026ldquo;open-domain\u0026rdquo; part refers to the lack of the relevant context for any arbitrarily asked factual question. In the above case, the model only takes as the input the question but no article about \u0026ldquo;why Einstein didn\u0026rsquo;t win a Nobel Prize for the theory of relativity\u0026rdquo; is provided, where the term \u0026ldquo;the law of the photoelectric effect\u0026rdquo; is likely mentioned. In the case when both the question and the context are provided, the task is known as Reading comprehension (RC).\nAn ODQA model may work with or without access to an external source of knowledge (e.g. Wikipedia) and these two conditions are referred to as open-book or closed-book question answering, respectively.\nWhen considering different types of open-domain questions, I like the classification by Lewis, et al., 2020, in increasing order of difficulty:\n A model is able to correctly memorize and respond with the answer to a question that has been seen at training time. A model is able to answer novel questions at test time and choose an answer from the set of answers it has seen during training. A model is able to answer novel questions which have answers not contained in the training dataset. Fig. 1. Overview of three frameworks discussed in this post. Notation Given a question $x$ and a ground truth answer span $y$, the context passage containing the true answer is labelled as $z \\in \\mathcal{Z}$, where $\\mathcal{Z}$ is an external knowledge corpus. Wikipedia is a common choice for such an external knowledge source.\nConcerns of QA data fine-tuning Before we dive into the details of many models below. I would like to point out one concern of fine-tuning a model with common QA datasets, which appears as one fine-tuning step in several ODQA models. It could be concerning, because there is a significant overlap between questions in the train and test sets in several public QA datasets.\nLewis, et al., (2020) (code) found that 58-71% of test-time answers are also present somewhere in the training sets and 28-34% of test-set questions have a near-duplicate paraphrase in their corresponding training sets. In their experiments, several models performed notably worse when duplicated or paraphrased questions were removed from the training set.\nOpen-book QA: Retriever-Reader Given a factoid question, if a language model has no context or is not big enough to memorize the context which exists in the training dataset, it is unlikely to guess the correct answer. In an open-book exam, students are allowed to refer to external resources like notes and books while answering test questions. Similarly, a ODQA system can be paired with a rich knowledge base to identify relevant documents as evidence of answers.\nWe can decompose the process of finding answers to given questions into two stages,\n Find the related context in an external repository of knowledge; Process the retrieved context to extract an answer. Fig. 2. The retriever-reader QA framework combines information retrieval with machine reading comprehension. Such a retriever + reader framework was first proposed in DrQA (\u0026ldquo;Document retriever Question-Answering\u0026rdquo; by Chen et al., 2017; code). The retriever and the reader components can be set up and trained independently, or jointly trained end-to-end.\nRetriever Model Two popular approaches for implementing the retriever is to use the information retrieval (IR) system that depends on (1) the classic non-learning-based TF-IDF features (\u0026ldquo;classic IR\u0026rdquo;) or (2) dense embedding vectors of text produced by neural networks (\u0026ldquo;neural IR\u0026rdquo;).\nClassic IR DrQA (Chen et al., 2017) adopts an efficient non-learning-based search engine based on the vector space model. Every query and document is modelled as a bag-of-word vector, where each term is weighted by TF-IDF (term frequency $\\times$ inverse document frequency).\n $$ \\begin{aligned} \\text{tf-idf}(t, d, \\mathcal{D}) \u0026= \\text{tf}(t, d) \\times \\text{idf}(t, \\mathcal{D}) \\\\ \\text{tf}(t, d) \u0026= \\log(1 + \\text{freq}(t, d)) \\\\ \\text{idf}(t, \\mathcal{D}) \u0026= \\log \\Big( \\frac{\\vert\\mathcal{D}\\vert}{\\vert d\\in\\mathcal{D}: t\\in d\\vert} \\Big) \\end{aligned} $$ where $t$ is a unigram or bigram term in a document $d$ from a collection of documents $\\mathcal{D}$ . $\\text{freq}(t, d)$ measures how many times a term $t$ appears in $d$. Note that the term-frequency here includes bigram counts too, which is found to be very helpful because the local word order is taken into consideration via bigrams. As part of the implementation, DrQA maps the bigrams of $2^{24}$ bins using unsigned murmur3 hash.\nPrecisely, DrQA implemented Wikipedia as its knowledge source and this choice has became a default setting for many ODQA studies since then. The non-ML document retriever returns the top $k=5$ most relevant Wikipedia articles given a question.\nBERTserini (Yang et al., 2019) pairs the open-source Anserini IR toolkit as the retriever with a fine-tuned pre-trained BERT model as the reader. The top $k$ documents ($k=10$) are retrieved via the post-v3.0 branch of Anserini with the query treated as a bag of words. The retrieved text segments are ranked by BM25, a classic TF-IDF-based retrieval scoring function. In terms of the effect of text granularity on performance, they found that paragraph retrieval \u0026gt; sentence retrieval \u0026gt; article retrieval.\nFig. 3. An illustration of BERTserini architecture. (Image source: Yang et al., 2019) ElasticSearch + BM25 is used by the Multi-passage BERT QA model (Wang et al., 2019). They found that splitting articles into passages with the length of 100 words by sliding window brings 4% improvements, since splitting documents into passages without overlap may cause some near-boundary evidence to lose useful contexts.\nNeural IR There is a long history in learning a low-dimensional representation of text, denser than raw term-based vectors (Deerwester et al., 1990; Yih, et al., 2011). Dense representations can be learned through matrix decomposition or some neural network architectures (e.g. MLP, LSTM, bidirectional LSTM, etc). When involving neural networks, such approaches are referred to as \u0026ldquo;Neural IR\u0026rdquo;, Neural IR is a new category of methods for retrieval problems, but it is not necessary to perform better/superior than classic IR (Lim, 2018).\nAfter the success of many large-scale general language models, many QA models embrace the following approach:\n $$ h_x = E_x(x)\\quad h_z = E_z(z)\\quad \\text{score}(x, z) = h_x^\\top h_z $$ Extract the dense representations of a question $x$ and a context passage $z$ by feeding them into a language model; Use the dot-product of these two representations as the retrieval score to rank and select most relevant passages. ORQA, REALM and DPR all use such a scoring function for context retrieval, which will be described in detail in a later section on the end-to-end QA model.\nAn extreme approach, investigated by DenSPI (\u0026ldquo;Dense-Sparse Phrase Index\u0026rdquo;; Seo et al., 2019), is to encode all the text in the knowledge corpus at the phrase level and then only rely on the retriever to identify the most relevant phrase as the predicted answer. In this way, the retriever+reader pipeline is reduced to only retriever. Of course, the index would be much larger and the retrieval problem is more challenging.\nDenSPI introduces a query-agnostic indexable representation of document phrases. Precisely it encodes query-agnostic representations of text spans in Wikipedia offline and looks for the answer at inference time by performing nearest neighbor search. It can drastically speed up the inference time, because there is no need to re-encode documents for every new query, which is often required by a reader model.\nGiven a question $x$ and a fixed set of (Wikipedia) documents, $z_1, \\dots, z_K$ and each document $z_k$ contains $N_k$ words, $z_k = \\langle z_k^{(1)}, \\dots, z_k^{(N_k)}\\rangle$. An ODQA model is a scoring function $F$ for each candidate phrase span $z_k^{(i:j)}, 1 \\leq i \\leq j \\leq N_k$, such that the truth answer is the phrase with maximum score: $y = {\\arg\\max}_{k,i,j} F(x, z_k^{(i:j)})$.\nThe phrase representation $z_k^{(i:j)}$ combines both dense and sparse vectors, $z_k^{(i:j)} = [d_k^{(i:j)}, s_k^{(i:j)}] \\in \\mathbb{R}^{d^d + d^s}$ (note that $d^d \\ll d^s$):\n The dense vector $d_k^{(i:j)}$ is effective for encoding local syntactic and semantic cues, as what can be learned by a pretrained language model. The sparse vector $s_k^{(i:j)}$ is superior at encoding precise lexical information. The sparse vector is term-frequency-based encoding. DenSPI uses 2-gram term-frequency same as DrQA, resulting a highly sparse representation ($d^s \\approx 16$M) The dense vector $d^{(i:j)}$ is further decomposed into three parts, $d^{(i:j)} = [a_i, b_j, c_{ij}] \\in \\mathbb{R}^{2d^b + 1}$ where $2d^b + 1 = d^d$. All three components are learned based on different columns of the fine-tuned BERT representations.\n A vector $a_i$ encodes the start position for the $i$-th word of the document; A vector $b_j$ encodes the end position for the $j$-th word of the document; A scalar $c_{ij}$ measures the coherency between the start and the end vectors, helping avoid non-constituent phrases during inference. For all possible $(i,j,k)$ tuples where $j-i \u0026lt; J$, the text span embeddings are precomputed and stored as a phrase index. The maximum span length $J$ is a predefined scalar constant.\nFig. 4. An illustration of Dense-Sparse Phrase Index (DenSPI) architecture. (Image source: Seo et al., 2019) At the inference time, the question is mapped into the same vector space $x=[d', s'] \\in \\mathbb{R}^{d^d + d^s}$, where the dense vector $d'$ is extracted from the BERT embedding of the special [CLS] symbol. The same BERT model is shared for encoding both questions and phrases. The final answer is predicted by $k^*, i^*, j^* = \\arg\\max x^\\top z_k^{(i:j)}$.\nReader Model The reader model learns to solve the reading comprehension task \u0026mdash; extract an answer for a given question from a given context document. Here we only discuss approaches for machine comprehension using neural networks.\nBi-directional LSTM The reader model for answer detection of DrQA (Chen et al., 2017) is a 3-layer bidirectional LSTM with hidden size 128. Every relevant paragraph of retrieved Wikipedia articles is encoded by a sequence of feature vector, $\\{\\tilde{\\mathbf{z}}_1, \\dots, \\tilde{\\mathbf{z}}_m \\}$. Each feature vector $\\hat{\\mathbf{z}}_i \\in \\mathbb{R}^{d_z}$ is expected to capture useful contextual information around one token $z_i$. The feature consists of several categories of features:\n Word embeddings: A 300d Glove word embedding trained from 800B Web crawl data, $f_\\text{embed} = E_g(z_i)$. Exact match: Whether a word $z_i$ appears in the question $x$, $f_\\text{match} = \\mathbb{I}(z_i \\in x)$. Token features: This includes POS (part-of-speech) tagging, NER (named entity recognition), and TF (term-frequency), $f_\\text{token}(z_i) = (\\text{POS}(z_i), \\text{NER}(z_i), \\text{TF}(z_i))$. Aligned question embedding: The attention score $y_{ij}$ is designed to capture inter-sentence matching and similarity between the paragraph token $z_i$ and the question word $x_j$. This feature adds soft alignments between similar but non-identical words. $$ \\begin{aligned} f_\\text{align}(z_i) \u0026= \\sum_j y_{i,j} E_g(x_j) \\\\ y_{i,j} \u0026= \\frac{\\exp(\\alpha(E_g(z_i))^\\top \\alpha(E_g(x_j)) )}{\\sum_{j'} \\exp(\\alpha(E_g(z_i))^\\top \\alpha(E_g(x_{j'})) ) } \\end{aligned} $$ where $\\alpha$ is a single dense layer with ReLU and $E_g(.)$ is the glove word embedding.\nThe feature vector of a paragraph of $m$ tokens is fed into LSTM to obtain the final paragraph vectors:\n $$ \\begin{aligned} \\mathbf{z} = \\{\\mathbf{z}_1, \\dots, \\mathbf{z}_m\\} \u0026= \\text{LSTM}(\\{\\tilde{\\mathbf{z}}_1, \\dots, \\tilde{\\mathbf{z}}_m\\}) \\\\ \\text{where } \\tilde{\\mathbf{z}}_i \u0026= \\{f_\\text{embed}, f_\\text{match}, f_\\text{token}, f_\\text{align}\\} \\end{aligned} $$ The question is encoded as a weighted sum of the embeddings of every word in the question:\n $$ \\mathbf{x} = \\sum_j b_j E(x_j) \\quad b_j = \\text{softmax}(\\mathbf{w}^\\top E(x_j)) $$ where $\\mathbf{w}$ is a weight vector to learn.\nOnce the feature vectors are constructed for the question and all the related paragraphs, the reader needs to predict the probabilities of each position in a paragraph to be the start and the end of an answer span, $p_\\text{start}(i_s)$ and $p_\\text{end}(i_s)$, respectively. Across all the paragraphs, the optimal span is returned as the final answer with maximum $p_\\text{start}(i_s) \\times p_\\text{end}(i_e) $.\n $$ \\begin{aligned} p_\\text{start}(i_s) \\propto \\exp(\\mathbf{z}_{i_s} \\mathbf{W}_s \\mathbf{x}) \\\\ p_\\text{end}(i_e) \\propto \\exp(\\mathbf{z}_{i_e} \\mathbf{W}_e \\mathbf{x}) \\\\ \\text{ s.t. } i_s \\leq i_e \\leq i_s + 15 \\end{aligned} $$ where $\\mathbf{W}_s$ and $\\mathbf{W}_e$ are learned parameters.\nBERT-universe Following the success of BERT (Devlin et al., 2018), many QA models develop the machine comprehension component based on BERT. Let\u0026rsquo;s define the BERT model as a function that can take one or multiple strings (concatenated by [SEP]) as input and outputs a set of BERT encoding vectors for the special [CLS] token and every input token:\n $$ \\text{BERT}(s_1, s_2, \\dots) = [\\mathbf{h}^\\texttt{[CLS]}, \\mathbf{h}^{(1)}, \\mathbf{h}^{(2)}, \\dots] $$ where $\\mathbf{h}^\\texttt{[CLS]}$ is the embedding vector for the special [CLS] token and $\\mathbf{h}^{(i)}$ is the embedding vector for the $i$-th token.\nTo use BERT for reading comprehension, it learns two additional weights, $\\mathbf{W}_s$ and $\\mathbf{W}_e$, and $\\text{softmax}(\\mathbf{h}^{(i)}\\mathbf{W}_s)$ and $\\text{softmax}(\\mathbf{h}^{(i)}\\mathbf{W}_e)$ define two probability distributions of start and end position of the predicted span per token.\nBERTserini (Yang et al., 2019) utilizes a pre-trained BERT model to work as the reader. Their experiments showed that fine-tuning pretrained BERT with SQuAD is sufficient to achieve high accuracy in identifying answer spans.\nFig. 5. How BERT is used to solve question-answering tasks. (Image source: Devlin et al., 2018) The key difference of the BERTserini reader from the original BERT is: to allow comparison and aggregation of results from different segments, the final softmax layer over different answer spans is removed. The pre-trained BERT model is fine-tuned on the training set of SQuAD, where all inputs to the reader are padded to 384 tokens with the learning rate 3e-5.\nWhen ranking all the extracted answer spans, the retriever score (BM25) and the reader score (probability of token being the start position $\\times$ probability of the same token being the end position ) are combined via linear interpolation.\nThe original BERT normalizes the probability distributions of start and end position per token for every passage independently. Differently, the Multi-passage BERT (Wang et al., 2019) normalizes answer scores across all the retrieved passages of one question globally. Precisely, multi-passage BERT removes the final normalization layer per passage in BERT for QA (same as in BERTserini) and then adds a global softmax over all the word positions of all the passages. Global normalization makes the reader model more stable while pin-pointing answers from a large number of passages.\nIn addition, multi-passage BERT implemented an independent passage ranker model via another BERT model and the rank score for $(x, z)$ is generated by a softmax over the representation vectors of the first [CLS] token. The passage ranker brings in extra 2% improvements. Similar idea of re-ranking passages with BERT was discussed in Nogueira \u0026amp; Cho, 2019, too.\nInterestingly, Wang et al., 2019 found that explicit inter-sentence matching does not seem to be critical for RC tasks with BERT; check the original paper for how the experiments were designed. One possible reason is that the multi-head self-attention layers in BERT has already embedded the inter-sentence matching.\nEnd-to-end Joint Training The retriever and reader components can be jointly trained. This section covers R^3, ORQA, REALM and DPR. There are a lot of common designs, such as BERT-based dense vectors for retrieval and the loss function on maximizing the marginal likelihood of obtaining true answers.\nThe retriever and reader models in the R^3 (\u0026ldquo;Reinforced Ranker-Reader\u0026rdquo;; Wang, et al., 2017) QA system are jointly trained via reinforcement learning. (Note that to keep the term consistent between papers in this section, the \u0026ldquo;ranker\u0026rdquo; model in the original R^3 paper is referred to as the \u0026ldquo;retriever\u0026rdquo; model here.) Both components are variants of Match-LSTM, which relies on an attention mechanism to compute word similarities between the passage and question sequences.\nHow does the Match-LSTM module work? Given a question $\\mathbf{X}$ of $d_x$ words and a passage $\\mathbf{Z}$ of $d_z$ words, both representations use fixed Glove word embeddings,\n $$ \\begin{aligned} \\mathbf{H}^x \u0026= \\text{BiLSTM}(\\mathbf{X}) \\in \\mathbb{R}^{l \\times d_x} \\\\ \\mathbf{H}^z \u0026= \\text{BiLSTM}(\\mathbf{Z}) \\in \\mathbb{R}^{l \\times d_z} \\\\ \\mathbf{G} \u0026= \\text{softmax}((\\mathbf{W}^g \\mathbf{H}^x + \\mathbf{b}^g \\otimes \\mathbf{e}_{d_x})^\\top \\mathbf{H}^z) \\in \\mathbb{R}^{d_x \\times d_z} \u0026 \\text{; an attention matrix}\\\\ \\bar{\\mathbf{H}}^x \u0026= \\mathbf{H}^x \\mathbf{G} \\in \\mathbb{R}^{l \\times d_z} \\\\ \\mathbf{M} \u0026= \\text{ReLU} \\Big( \\mathbf{W}^m \\begin{bmatrix} \\mathbf{H}^z \\\\ \\bar{\\mathbf{H}}^x \\\\ \\mathbf{H}^z \\odot \\bar{\\mathbf{H}}^x \\\\ \\mathbf{H}^z - \\bar{\\mathbf{H}}^x \\end{bmatrix} \\Big) \\in \\mathbb{R}^{2l \\times d_z} \\\\ \\mathbf{H}^m \u0026= \\text{BiLSTM}(M) \\in \\mathbb{R}^{l \\times d_z} \\end{aligned} $$ where $l$ is the hidden dimension of the bidirectional LSTM module. $\\mathbf{W}^g \\in \\mathbb{R}^{l\\times l}$, $\\mathbf{b}^g \\in \\mathbb{R}^l$, and $\\mathbf{W}^m \\in \\mathbb{R}^{2l \\times 4l}$ are parameters to learn. The operator $\\otimes \\mathbf{e}_{d_x}$ is the outer product to repeat the column vector $\\mathbf{b}^g$ $d_x$ times.\nThe ranker and reader components share the same Match-LSTM module with two separate prediction heads in the last layer, resulting in $\\mathbf{H}^\\text{rank}$ and $\\mathbf{H}^\\text{reader}$.\nFig. 6. The overview of R^3 (reinforced ranker-reader) architecture. Both components share the same Match-LSTM module. (Image source: Wang, et al., 2017) The retriever runs a max-pooling operation per passage and then aggregates to output a probability of each passage entailing the answer.\n $$ \\begin{aligned} \\mathbf{u}_i \u0026= \\text{max-pooling}(\\mathbf{H}^\\text{rank}_i) \\in \\mathbb{R}^l \\\\ \\mathbf{C} \u0026= \\text{tanh}(\\mathbf{W}^c[\\mathbf{u}_1;\\dots;\\mathbf{u}_N] + \\mathbf{b}^c \\otimes \\mathbf{e}_N) \\in \\mathbb{R}^{l \\times n} \\\\ \\gamma \u0026= \\text{softmax}(\\mathbf{w}^c \\mathbf{C}) \\in \\mathbb{R}^n \\end{aligned} $$ Finally, the retriever is viewed as a policy to output action to sample a passage according to predicted $\\gamma$,\n $$ \\pi(z \\vert x; \\theta^\\gamma) = \\gamma_z $$ The reader predicts the start position $\\beta^s$ and the end position $\\beta^e$ of the answer span. Two positions are computed in the same way, with independent parameters to learn. There are $V$ words in all the passages involved.\n $$ \\begin{aligned} \\mathbf{H}^\\text{read} \u0026= [\\mathbf{H}^\\text{read}_\\tau; \\mathbf{H}^\\text{read}_{\\text{neg}_1}; \\dots; \\mathbf{H}^\\text{read}_{\\text{neg}_n}] \\\\ \\mathbf{F}^s \u0026= \\text{tanh}(\\mathbf{W}^s \\mathbf{H}^\\text{read} + \\mathbf{b}^s \\otimes \\mathbf{e}_V) \\quad \\beta^s = \\text{softmax}(\\mathbf{w}^s \\mathbf{F}^s) \\in \\mathbb{R}^V \\\\ \\mathbf{F}^e \u0026= \\text{tanh}(\\mathbf{W}^e \\mathbf{H}^\\text{read} + \\mathbf{b}^e \\otimes \\mathbf{e}_V) \\quad \\beta^e = \\text{softmax}(\\mathbf{w}^e \\mathbf{F}^e) \\in \\mathbb{R}^V \\\\ L(y \\vert z, x) \u0026= -\\log(\\beta^s_{y_z^s})-\\log(\\beta^e_{y_z^e}) \\end{aligned} $$ where $y$ is the ground-truth answer and the passage $z$ is sampled by the retriever. $\\beta^s_{y_z^s}$ and $\\beta^s_{y_z^e}$ represent the probabilities of the start and end positions of $y$ in passage $z$.\nThe training objective for the end-to-end R^3 QA system is to minimize the negative log-likelihood of obtaining the correct answer $y$ given a question $x$,\n $$ \\begin{aligned} \\mathcal{J}(\\theta) \u0026= -\\mathbb{E}_{z\\sim\\pi(.\\vert x)} [L(y \\vert z, x)] \\\\ \\nabla \\mathcal{J}(\\theta) \u0026= - \\nabla_\\theta \\sum_z \\pi(z \\vert x) L(y \\vert z, x) \\\\ \u0026= - \\sum_z \\big( L(y \\vert z, x) \\nabla_\\theta\\pi(z \\vert x) + \\pi(z \\vert x) \\nabla_\\theta L(y \\vert z, x) \\big) \\\\ \u0026= - \\mathbb{E}_{z\\sim\\pi(.\\vert x)} \\big( \\color{red}{L(y \\vert z, x)\\nabla_\\theta\\log\\pi(z \\vert x)} + \\nabla_\\theta L(y \\vert z, x) \\big) \\\\ \u0026\\approx - \\mathbb{E}_{z\\sim\\pi(.\\vert x)} \\big( \\underbrace{\\color{red}{R(y \\vert z, x)\\nabla_\\theta\\log\\pi(z \\vert x)}}_\\text{REINFORCE} + \\nabla_\\theta L(y \\vert z, x) \\big) \\end{aligned} $$ Essentially in training, given a passage $z$ sampled by the retriever, the reader is trained by gradient descent while the retriever is trained by REINFORCE using $L(y \\vert z, x)$ as the reward function. However, $L(y \\vert z, x)$ is not bounded and may introduce a lot of variance. The paper replaces the reward with a customized scoring function by comparing the ground truth $y$ and the answer extracted by the reader $\\hat{y}$:\n $$ R(y, \\hat{y} \\vert z) = \\begin{cases} 2 \u0026 \\text{if } y = \\hat{y}\\\\ f1(y, \\hat{y}) \u0026 \\text{if } y \\cap \\hat{y} = \\varnothing \\\\ -1 \u0026 \\text{otherwise} \\end{cases} $$ Fig. 7. The workflow of R^3 training process. (Image source: acl2020-openqa-tutorial/slides/part4) ORQA (\u0026ldquo;Open-Retrieval Question-Answering\u0026rdquo;; Lee et al., 2019) jointly learns a retriever + reader QA model to optimize marginal log-likelihood of obtaining correct answers in a supervised manner. No explicit \u0026ldquo;black-box\u0026rdquo; IR system is involved. Instead, it is capable of retrieving any text in an open corpus. During training, ORQA does not need ground-truth context passages (i.e. reading comprehension datasets) but only needs (question, answer) string pairs. Both retriever and reader components are based on BERT, but not shared.\nFig. 8. An illustration of the retriever component in ORQA. (Image source: replotted based on one slide in acl2020-openqa-tutorial/slides/part5) All the evidence blocks are ranked by a retrieval score, defined as the inner product of BERT embedding vectors of the [CLS] token of the question $x$ and the evidence block $z$. Note that the encoders for questions and context are independent.\n $$ \\begin{aligned} h_x \u0026= \\mathbf{W}_x \\text{BERT}_x(x)^{\\mathtt{[CLS]}} \\\\ h_z \u0026= \\mathbf{W}_z \\text{BERT}_z(z)^{\\mathtt{[CLS]}} \\\\ S_\\text{retr}(z, x) \u0026= h_x^\\top h_z \\end{aligned} $$ The retriever module is pretrained with Inverse Cloze Task (ICT), which is to predict the context given a sentence, opposite to the standard Cloze Task. The ICT objective is to maximize the retrieval score of the correct context $z$ given a random sentence $x$:\n $$ L_\\text{ICT} = p_\\text{early}(z \\vert x) = \\frac{\\exp(S_\\text{retr}(z, x))}{\\sum_{z'\\in\\text{BATCH}(\\mathcal{Z})} \\exp(S_\\text{retr}(z', x))} $$ where $\\text{BATCH}(\\mathcal{Z})$ is the set of evidence blocks in the same batch used as sampled negatives.\nAfter such pretraining, the BERT retriever is expected to have representations good enough for evidence retrieval. Only the question encoder needs to be fine-tuned for answer extraction. In other words, the evidence block encoder (i.e., $\\mathbf{W}_z$ and $\\text{BERT}_z$) is fixed and thus all the evidence block encodings can be pre-computed with support for fast Maximum Inner Product Search (MIPS).\nFig. 9. An illustration of the reader component in ORQA. (Image source: acl2020-openqa-tutorial/slides/part5) The reader follows the same design as in the original BERT RC experiments. It learns in a supervised manner, while the parameters of the evidence block encoder are fixed and all other parameters are fine-tuned. Given a question $x$ and a gold answer string $y$, the reader loss contains two parts:\n $$ \\mathcal{L}(x, y) = \\mathcal{L}_\\text{early}(x, y) + \\mathcal{L}_\\text{full}(x, y) $$ (1) Find all correct text spans within top $k$ evidence blocks and optimize for the marginal likelihood of a text span $s$ that matches the true answer $y$:\n $$ \\begin{aligned} h_s \u0026= \\text{BERT}_R(x, y)^{(\\text{START}(s))} \\\\ h_e \u0026= \\text{BERT}_R(x, y)^{(\\text{END}(s))} \\\\ S_\\text{read}(z, s, x) \u0026= \\text{MLP}([h_s; h_e]) \\\\ p(z, s \\vert x) \u0026= \\frac{\\exp(S_\\text{read}(z, s, x))}{\\sum_{z'\\in\\text{TOP}(k)} \\sum_{s'\\in z'} \\exp(S_\\text{read}(z', s', x))} \\\\ L_\\text{full}(x, y) \u0026= - \\log \\sum_{\\substack{z \\in \\text{TOP}(k)\\\\ s \\in z}} \\sum_{y=\\text{TEXT}(s)} p(z, s \\vert x) \\end{aligned} $$ where $y=\\text{TEXT}(s)$ indicates whether the answer $y$ matches the text span $s$. $\\text{TOP}(k)$ is the top $k$ retrieved blocks according to $S_\\text{retr}(z, x)$. The paper sets $k=5$.\n(2) At the early stage of learning, when the retriever is not strong enough, it is possible none of the top $k$ blocks contains the answer. To avoid such sparse learning signals, ORQA considers a larger set of $c$ evidence blocks for more aggressive learning. The paper has $c=5000$.\n $$ L_\\text{early}(x, y) = -\\log \\sum_{\\substack{z\\in \\text{TOP}(c)\\\\y\\in\\text{TEXT}(z)}} p_\\text{early}(z\\vert x) = -\\log \\sum_{\\substack{z\\in \\text{TOP}(c)\\\\y\\in\\text{TEXT}(z)}} \\frac{\\exp(S_\\text{retr}(z, x)}{\\sum_{z'\\in\\text{TOP}(c)} \\exp(S_\\text{retr}(z', x)} $$ Some issues in SQuAD dataset were discussed in the ORQA paper:\n \u0026quot; The notable drop between development and test accuracy for SQuAD is a reflection of an artifact in the dataset\u0026mdash;its 100k questions are derived from only 536 documents. Therefore, good retrieval targets are highly correlated between training examples, violating the IID assumption, and making it unsuitable for learned retrieval. We strongly suggest that those who are interested in end-to-end open-domain QA models no longer train and evaluate with SQuAD for this reason.\u0026quot;\n REALM (\u0026ldquo;Retrieval-Augmented Language Model pre-training\u0026rdquo;; Guu et al., 2020) also jointly trains retriever + reader by optimizing the marginal likelihood of obtaining the true answer:\n $$ p(y \\vert x) = \\sum_{z \\in \\mathcal{Z}} \\underbrace{p(y \\vert x, z)}_\\text{reader} \\underbrace{p(z \\vert x)}_\\text{retriever} \\approx \\sum_{z \\in \\text{TOP}_k(\\mathcal{Z})} p(y \\vert x, z) p(z \\vert x) $$ Fig. 10. REALM is first unsupervised pre-trained with salient spans masking and then fine-tuned with QA data. (Image source: Guu et al., 2020). REALM computes two probabilities, $p(z \\vert x)$ and $p(y \\vert x, z)$, same as ORQA. However, different from ICT in ORQA, REALM upgrades the unsupervised pre-training step with several new design decisions, leading towards better retrievals. REALM pre-trains the model with Wikipedia or CC-News corpus.\n Use salient span masking. Named entities and dates are identified. Then one of these \u0026ldquo;salient spans\u0026rdquo; is selected and masked. Salient span masking is a special case of MLM and works out well for QA tasks. Add an empty null document. Because not every question demands a context document. No trivial retrieval. The context document should not be same as the selected sentence with a masked span. Apply the same ICT loss as in ORQA to encourage learning when the retrieval quality is still poor at the early stage of training. \u0026ldquo;Among all systems, the most direct comparison with REALM is ORQA (Lee et al., 2019), where the fine-tuning setup, hyperparameters and training data are identical. The improvement of REALM over ORQA is purely due to better pre-training methods.\u0026rdquo; \u0026mdash; from REALM paper.\n Both unsupervised pre-training and supervised fine-tuning optimize the same log-likelihood $\\log p(y \\vert x)$. Because the parameters of the retriever encoder for evidence documents are also updated in the process, the index for MIPS is changing. REALM asynchronously refreshes the index with the updated encoder parameters every several hundred training steps.\nBalachandran, et al. (2021) found that REALM is significantly undertrained and REALM++ achieves great EM accuracy improvement (3-5%) by scaling up the model training with larger batch size and more retrieved documents for the reader to process.\nDPR (\u0026ldquo;Dense Passage Retriever\u0026rdquo;; Karpukhin et al., 2020, code) argues that ICT pre-training could be too computationally expensive and the ORQA\u0026rsquo;s context encoder might be sub-optimal because it is not fine-tuned with question-answer pairs. DPR aims to resolve these two issues by only training a dense dual-encoder architecture for retrieval only from a small number of Q/A pairs, without any pre-training.\nSame as previous work, DPR uses the dot-product (L2 distance or cosine similarity also works) of BERT representations as retrieval score. The loss function for training the dual-encoder is the NLL of the positive passage, which essentially takes the same formulation as ICT loss of ORQA. Note that both of them consider other passages in the same batch as the negative samples, named in-batch negative sampling. The main difference is that DPR relies on supervised QA data, while ORQA trains with ICT on unsupervised corpus. At the inference time, DPR uses FAISS to run fast MIPS.\nDPR did a set of comparison experiments involving several different types of negatives:\n Random: any random passage from the corpus; BM25: top passages returned by BM25 which don\u0026rsquo;t contain the answer but match most question tokens; In-batch negative sampling (\u0026ldquo;gold\u0026rdquo;): positive passages paired with other questions which appear in the training set. DPR found that using gold passages from the same mini-batch and one negative passage with high BM25 score works the best. To further improve the retrieval results, DPR also explored a setting where a BM25 score and a dense embedding retrieval score are linearly combined to serve as a new ranking function.\nOpen-book QA: Retriever-Generator Compared to the retriever-reader approach, the retriever-generator also has 2 stages but the second stage is to generate free text directly to answer the question rather than to extract start/end position in a retrieved passage. Some paper also refer to this as Generative question answering.\nFig. 11. The retriever + generator QA framework combines a document retrieval system with a general language model. A pretrained LM has a great capacity of memorizing knowledge in its parameters, as shown above. However, they cannot easily modify or expand their memory, cannot straightforwardly provide insights into their predictions, and may produce non-existent illusion.\nPetroni et al. (2020) studied how the retrieved relevant context can help a generative language model produce better answers. They found:\n Augmenting queries with relevant contexts dramatically improves the pretrained LM on unsupervised machine reading capabilities. An off-the-shelf IR system is sufficient for BERT to match the performance of a supervised ODQA baseline; BERT\u0026rsquo;s NSP pre-training strategy is a highly effective unsupervised mechanism in dealing with noisy and irrelevant contexts. They pair the BERT model with different types of context, including adversarial (unrelated context), retrieved (by BM25), and generative (by an autoregressive language model of 1.4N parameters, trained on CC-NEWS). The model is found to be robust to adversarial context, but only when the question and the context are provided as two segments (e.g. separated by [SEP]). One hypothesis is related to NSP task: \u0026ldquo;BERT might learn to not condition across segments for masked token prediction if the NSP score is low, thereby implicitly detecting irrelevant and noisy contexts.\u0026rdquo;\nRAG (\u0026ldquo;Retrieval-Augmented Generation\u0026rdquo;; Lewis et al., 2020) combines pre-trained parametric (language model) and non-parametric memory (external knowledge index) together for language generation. RAG can be fine-tuned on any seq2seq task, whereby both the retriever and the sequence generator are jointly learned. They found that unconstrained generation outperforms previous extractive approaches.\nRAG consists of a retriever model $p_\\eta(z \\vert x)$ and a generator model $p_\\theta(y_i \\vert x, z, y_{1:i-1})$:\n The retriever uses the input sequence $x$ to retrieve text passages $z$, implemented as a DPR retriever. $\\log p_\\eta(z \\vert x) \\propto E_z(z)^\\top E_x(x)$. The generator uses $z$ as additional context when generating the target sequence $y$, where the context and the question are simply concatenated. Depending on whether using the same or different retrieved documents for each token generation, there are two versions of RAG:\n $$ \\begin{aligned} p_\\text{RAG-seq}(y \\vert x) \u0026= \\sum_{z \\in \\text{TOP}_k(p_\\eta(.\\vert x))} p_\\eta(z \\vert x) \\prod_i^N p_\\theta(y_i \\vert x, z, y_{1:i-1}) \\\\ p_\\text{RAG-token}(y \\vert x) \u0026= \\prod_i^N \\sum_{z \\in \\text{TOP}_k(p_\\eta(.\\vert x))} p_\\eta(z_i\\vert x) p_\\theta(y_i \\vert x, z_i, y_{1:i-1}) \\end{aligned} $$ The retriever + generator in RAG is jointly trained to minimize the NLL loss, $\\mathcal{L}_\\text{RAG} = \\sum_j -\\log p(y_j \\vert x_j)$. Updating the passage encoder $E_z(.)$ is expensive as it requires the model to re-index the documents for fast MIPS. RAG does not find fine-tuning $E_z(.)$ necessary (like in ORQA) and only updates the query encoder + generator.\nFig. 12. An illustration of retrieval-augmented generation (RAG) architecture. (Image source: Lewis et al., 2020) At decoding/test time, RAG-token can be evaluated via a beam search. RAG-seq cannot be broken down into a set of per-token likelihood, so it runs beam search for each candidate document $z$ and picks the one with optimal $p_\\theta(y_i \\vert x, z, y_{1:i-1})$.\nThe Fusion-in-Decoder approach, proposed by Izacard \u0026amp; Grave (2020) is also based on a pre-trained T5. It works similar to RAG but differently for how the context is integrated into the decoder.\n Retrieve top $k$ related passage of 100 words each, using BM25 or DPR. Each retrieved passage and its title are concatenated with the question using special tokens like question:, title: and context: to indicate the content differences. Each retrieved passage is processed independently and later combined in the decoder. Processing passages independently in the encoder allows us to parallelize the computation. OTOH, processing them jointly encourages better aggregation of multiple pieces of evidence. The aggregation part is missing in extractive approaches. Note that they did fine-tune the pretrained LM independently for each dataset.\nClosed-book QA: Generative Language Model Big language models have been pre-trained on a large collection of unsupervised textual corpus. Given enough parameters, these models are able to memorize some factual knowledge within parameter weights. Therefore, we can use these models to do question-answering without explicit context, just like in a closed-book exam. The pre-trained language models produce free text to respond to questions, no explicit reading comprehension.\nFig. 13. The amount of computation used for training big language models of different sizes is getting big. (Image source: Brown et al., 2020). Roberts et al. (2020) measured the practical utility of a language model by fine-tuning a pre-trained model to answer questions without access to any external context or knowledge. They fine-tuned the T5 language model (same architecture as the original Transformer) to answer questions without inputting any additional information or context. Such setup enforces the language model to answer questions based on \u0026ldquo;knowledge\u0026rdquo; that it internalized during pre-training.\nFig. 14. T5 is first pre-trained with salient span masking and then fine-tuned for each QA dataset to produce answers in free text. (Image source: Roberts et al. 2020) The original T5 models were pre-trained on a multi-task mixture including an unsupervised \u0026ldquo;masked language modeling\u0026rdquo; (MLM) tasks on the C4 (\u0026ldquo;Colossal Clean Crawled Corpus\u0026rdquo;) dataset as well as fine-tuned altogether with supervised translation, summarization, classification, and reading comprehension tasks. Roberts, et al. (2020) took a pre-trained T5 model and continued pre-training with salient span masking over Wikipedia corpus, which has been found to substantially boost the performance for ODQA. Then they fine-tuned the model for each QA datasets independently.\nWith a pre-trained T5 language model + continue pre-training with salient spans masking + fine-tuning for each QA dataset,\n It can attain competitive results in open-domain question answering without access to external knowledge. A larger model can obtain better performance. For example, a T5 with 11B parameters is able to match the performance with DPR with 3 BERT-base models, each with 330M parameters. Interestingly, fine-tuning is not strictly necessary. GPT3 (Brown et al., 2020) has been evaluated on the closed book question answering task without any gradient updates or fine-tuning. During evaluation, the few-shot, one-shot and zero-shot settings here only refer to how many demonstrations are provided as context in the text input:\n \u0026ldquo;few-shot learning\u0026rdquo;: GPT3 is allowed to take as many demonstrations as what can fit into the model\u0026rsquo;s context window (typically 10 to 100). \u0026ldquo;one-shot learning\u0026rdquo;: only one demonstration is provided. \u0026ldquo;zero-shot learning\u0026rdquo;: no demonstrations are allowed and only an instruction in natural language is given to the model. The performance grows with the model size. On the TriviaQA dataset, GPT3 evaluation with demonstrations can match or exceed the performance of SOTA baseline with fine-tuning.\nFig. 15. GPT3's performance on TriviaQA grows smoothly with the model size. More demonstrations lead to better performance. (Image source: Brown et al., 2020). Check out this cool example in OpenAI API playground viewer. The model is able to answer factal questions in short answer and not to make up things when the model does not know the answer. I added the last two questions and asked the model to respond with A:. The API is still in beta version, so you might need to apply to get on the wait list.\nQ: Who is Batman? A: Batman is a fictional comic book character. ## Q: What is torsalplexity? A: ? ## Q: What is Devz9? A: ? ## Q: Who is George Lucas? A: George Lucas is American film director and producer famous for creating Star Wars. ## Q: What is the capital of California? A: Sacramento. ## Q: What orbits the Earth? A: The Moon. ## Q: Who is Fred Rickerson? A: ? ## Q: What is an atom? A: An atom is a tiny particle that makes up everything. ## Q: Who is Alvan Muntz? A: ? ## Q: What is Kozar-09? A: ? ## Q: How many moons does Mars have? A: Two, Phobos and Deimos. ## Q: What is COVID-19? A: ? ## Q: What is H1N1? A: H1N1 is a strain of influenza. Related Techniques Fast Maximum Inner Product Search (MIPS) MIPS (maximum inner product search) is a crucial component in many open-domain question answering models. In retriever + reader/generator framework, a large number of passages from the knowledge source are encoded and stored in a memory. A retrieval model is able to query the memory to identify the top relevant passages which have the maximum inner product with the question\u0026rsquo;s embedding.\nWe need fast MIPS because the number of precomputed passage representations can be gigantic. There are several ways to achieve fast MIPS at run time, such as asymmetric LSH, data-dependent hashing, and FAISS.\nLanguage Model Pre-training Two pre-training tasks are especially helpful for QA tasks, as we have discussed above.\n Inverse Cloze Task (proposed by ORQA): The goal of Cloze Task is to predict masked-out text based on its context. The prediction of Inverse Cloze Task (ICT) is in the reverse direction, aiming to predict the context given a sentence. In the context of QA tasks, a random sentence can be treated as a pseudo-question, and its context can be treated as pseudo-evidence.\n Salient Spans Masking (proposed by REALM): Salient span masking is a special case for MLM task in language model training. First, we find salient spans by using a tagger to identify named entities and a regular expression to identify dates. Then one of the detected salient spans is selected and masked. The task is to predict this masked salient span.\n Summary Model Retriever Reader / Generator Pre-training / Fine-tuning End2end DrQA TF-IDF Bi-directional LSTM \u0026ndash; No BERTserini Aserini + BM25 BERT without softmax layer Fine-tune with SQuAD No Multi-passage BERT ElasticSearch + BM25 Multi-passage BERT + Passage ranker No R^3 Classic IR + Match-LSTM Match-LSTM Yes ORQA Dot product of BERT embeddings BERT-RC Inverse cloze task Yes REALM Dot product of BERT embeddings BERT-RC Salient span masking Yes DPR Dot product of BERT embeddings BERT-RC supervised training with QA pairs Yes DenSPI Classic + Neural IR \u0026ndash; Yes T5 + SSM \u0026ndash; T5 SSM on CommonCrawl data + Fine-tuning on QA data Yes GPT3 \u0026ndash; GPT3 NSP on CommonCrawl data Yes RAG DPR retriever BART Yes Fusion-in-Decoder BM25 / DPR retriever Tranformer No Fig. 16. A comparison of performance of several QA models on common QA datasets. On TriviaQA, two columns of results are reported, on the open domain test set (left) and on the hidden test set (right). (Image source: Izacard \u0026 Grave, 2020). Citation Cited as:\n Weng, Lilian. (Oct 2020). How to build an open-domain question answering system? Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-10-29-odqa/.\n Or\n@article{weng2020odqa, title = \u0026quot;How to Build an Open-Domain Question Answering System?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Oct\u0026quot; url = \u0026quot;https://lilianweng.github.io/posts/2020-10-29-odqa/\u0026quot; } Appendix: QA Datasets SQuAD 2.0: the Stanford QA dataset. RACE: a reading comprehension dataset collected from English Examinations that are created for middle school and high school students. TREC QA: the TREC QA collections. MS MARCO: a QA dataset featuring 100,000 real Bing questions and a human generated answer. CuratedTREC: based on the benchmarks from the TREC QA tasks that have been curated by Baudis \u0026amp; Sedivy (2015). Google Natural Questions: contains real user questions issued to Google search, and answers found from Wikipedia by annotators. WebQuestions: designed for knowledge-base QA with answers restricted to Freebase entities. WikiQA: Bing query logs were used as the source of questions. Each question is then linked to a Wikipedia page that potentially contains the answer. WikiMovies: contains movie-related questions from the OMDb and MovieLens databases and where the questions can be answered using Wikipedia pages. WikiReading: to predict textual values from the structured knowledge base Wikidata by reading the text of the corresponding Wikipedia articles. TriviaQA: a reading comprehension dataset containing 95K question-answer pairs authored by trivia enthusiasts and independently gathered multiple evidence documents per question. Jeopardy! Questions: contains 200,000+ Jeopardy! questions. DeepMind Q\u0026amp;A Dataset: question/answer pairs from CNN and Daily Mail articles. bAbi: a rich collection of datasets for text understanding by Facebook. FEVER: for fact extraction and verification. SearchQA: question-answer pairs were crawled from from J! Archive, and then augmented with text snippets from Google. Quasar-T: a collection of open-domain trivia questions and their answers obtained from various internet sources. Quiz bowl: contains data from a trivia competition called quiz bowl. AmbigNQ: ambiguous questions selected from NQ-OPEN dataset. QA-Overlap: a collections of overlapped answers/questions between train and test set for Natural Questions, TriviaQA, and WebQuestions. References [1] Danqi Chen \u0026amp; Scott Yih. \u0026ldquo;ACL2020 Tutorial: Open-Domain Question Answering\u0026rdquo; July 2020.\n[2] Danqi Chen, et al. \u0026ldquo;Reading Wikipedia to Answer Open-Domain Questions\u0026rdquo; ACL 2017. | code\n[3] Shuohang Wang, et al. \u0026ldquo;R^3: Reinforced Ranker-Reader for Open-Domain Question Answering\u0026rdquo; AAAI 2018.\n[4] Jimmy Lin. \u0026ldquo;The neural hype and comparisons against weak baselines.\u0026quot; ACM SIGIR Forum. Vol. 52. No. 2. 2019.\n[5] Wei Yang, et al. \u0026ldquo;End-to-End Open-Domain Question Answering with BERTserini\u0026rdquo; NAACL 2019.\n[6] Christopher Clark \u0026amp; Matt Gardner. \u0026ldquo;Simple and Effective Multi-Paragraph Reading Comprehension.\u0026quot; arXiv:1710.10723 (2017).\n[7] Rodrigo Nogueira \u0026amp; Kyunghyun Cho. \u0026ldquo;Passage Re-ranking with BERT.\u0026quot; arXiv preprint arXiv:1901.04085 (2019). | code\n[8] Zhiguo Wang, et al. \u0026ldquo;Multi-passage BERT: A globally normalized BERT model for open-domain question answering.\u0026quot; EMNLP 2019.\n[9] Minjoon Seo et al. \u0026ldquo;Real-time open-domain question answering with dense-sparse phrase index.\u0026quot; ACL 2019.\n[10] Kenton Lee, et al. \u0026ldquo;Latent Retrieval for Weakly Supervised Open Domain Question Answering\u0026rdquo; ACL 2019.\n[11] Kelvin Guu, et al. \u0026ldquo;REALM: Retrieval-Augmented Language Model Pre-Training\u0026rdquo; arXiv:2002.08909 (2020).\n[12] Vladimir Karpukhin et al. \u0026ldquo;Dense passage retrieval for open-domain question answering.\u0026quot;. EMNLP 2020. | code\n[13] Patrick Lewis et al. \u0026ldquo;Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks\u0026rdquo; arXiv:2005.11401 (2020).\n[14] Adam Roberts, et al. \u0026ldquo;How Much Knowledge Can You Pack Into the Parameters of a Language Model?\u0026quot; EMNLP 2020.\n[15] Tom Brown, et al. \u0026ldquo;Language models are few-shot learners.\u0026quot; arXiv:2005.14165 (2020).\n[16] Fabio Petroni, et al. \u0026ldquo;How Context Affects Language Models' Factual Predictions\u0026rdquo; AKBC 2020.\n[17] Gautier Izacard \u0026amp; Edouard Grave. \u0026ldquo;Leveraging passage retrieval with generative models for open domain question answering.\u0026quot; arXiv:2007.01282 (2020).\n[18] \u0026ldquo;Dive into deep learning: Beam search\u0026rdquo;\n[19] Patrick Lewis, et al. \u0026ldquo;Question and Answer Test-Train Overlap in Open-Domain Question Answering Datasets\u0026rdquo; arXiv:2008.02637 (2020). | data\n[20] Hervé Jegou, et al. \u0026ldquo;Faiss: A library for efficient similarity search\u0026rdquo; Mar 2017.\n[21] Vidhisha Balachandran, et al. \u0026ldquo;Simple and Efficient ways to Improve REALM.\u0026quot; arXiv:2104.08710 (2021).\n","permalink":"https://lilianweng.github.io/posts/2020-10-29-odqa/","summary":"[Updated on 2020-11-12: add an example on closed-book factual QA using OpenAI API (beta).\nA model that can answer any question with regard to factual knowledge can lead to many useful and practical applications, such as working as a chatbot or an AI assistant🤖. In this post, we will review several common approaches for building such an open-domain question answering system.\nDisclaimers given so many papers in the wild:\n Assume we have access to a powerful pretrained language model.","title":"How to Build an Open-Domain Question Answering System?"},{"content":"Although most popular and successful model architectures are designed by human experts, it doesn\u0026rsquo;t mean we have explored the entire network architecture space and settled down with the best option. We would have a better chance to find the optimal solution if we adopt a systematic and automatic way of learning high-performance model architectures.\nAutomatically learning and evolving network topologies is not a new idea (Stanley \u0026amp; Miikkulainen, 2002). In recent years, the pioneering work by Zoph \u0026amp; Le 2017 and Baker et al. 2017 has attracted a lot of attention into the field of Neural Architecture Search (NAS), leading to many interesting ideas for better, faster and more cost-efficient NAS methods.\nAs I started looking into NAS, I found this nice survey very helpful by Elsken, et al 2019. They characterize NAS as a system with three major components, which is clean \u0026amp; concise, and also commonly adopted in other NAS papers.\n Search space: The NAS search space defines a set of operations (e.g. convolution, fully-connected, pooling) and how operations can be connected to form valid network architectures. The design of search space usually involves human expertise, as well as unavoidably human biases. Search algorithm: A NAS search algorithm samples a population of network architecture candidates. It receives the child model performance metrics as rewards (e.g. high accuracy, low latency) and optimizes to generate high-performance architecture candidates. Evaluation strategy: We need to measure, estimate, or predict the performance of a large number of proposed child models in order to obtain feedback for the search algorithm to learn. The process of candidate evaluation could be very expensive and many new methods have been proposed to save time or computation resources. Fig. 1. Three main components of Neural Architecture Search (NAS) models. (Image source: Elsken, et al. 2019 with customized annotation in red) Search Space The NAS search space defines a set of basic network operations and how operations can be connected to construct valid network architectures.\nSequential Layer-wise Operations The most naive way to design the search space for neural network architectures is to depict network topologies, either CNN or RNN, with a list of sequential layer-wise operations, as seen in the early work of Zoph \u0026amp; Le 2017 \u0026amp; Baker et al. 2017. The serialization of network representation requires a decent amount of expert knowledge, since each operation is associated with different layer-specific parameters and such associations need to be hardcoded. For example, after predicting a conv op, the model should output kernel size, stride size, etc; or after predicting an FC op, we need to see the number of units as the next prediction.\nFig. 2. (Top) A sequential representation of CNN. (Bottom) A sequential representation of the tree structure of a recurrent cell. (Image source: Zoph \u0026 Le 2017) To make sure the generated architecture is valid, additional rules might be needed (Zoph \u0026amp; Le 2017):\n If a layer is not connected to any input layer then it is used as the input layer; At the final layer, take all layer outputs that have not been connected and concatenate them; If one layer has many input layers, then all input layers are concatenated in the depth dimension; If input layers to be concatenated have different sizes, we pad the small layers with zeros so that the concatenated layers have the same sizes. The skip connection can be predicted as well, using an attention-style mechanism. At layer $i$ , an anchor point is added with $i−1$ content-based sigmoids to indicate which of the previous layers to be connected. Each sigmoid takes as input the hidden states of the current node $h_i$ and $i-1$ previous nodes $h_j, j=1, \\dots, i-1$ .\n $$ P(\\text{Layer j is an input to layer i}) = \\text{sigmoid}(v^\\top \\tanh(\\mathbf{W}_\\text{prev} h_j + \\mathbf{W}_\\text{curr} h_i)) $$ The sequential search space has a lot of representation power, but it is very large and consumes a ton of computation resources to exhaustively cover the search space. In the experiments by Zoph \u0026amp; Le 2017, they were running 800 GPUs in parallel for 28 days and Baker et al. 2017 restricted the search space to contain at most 2 FC layers.\nCell-based Representation Inspired by the design of using repeated modules in successful vision model architectures (e.g. Inception, ResNet), the NASNet search space (Zoph et al. 2018) defines the architecture of a conv net as the same cell getting repeated multiple times and each cell contains several operations predicted by the NAS algorithm. A well-designed cell module enables transferability between datasets. It is also easy to scale down or up the model size by adjusting the number of cell repeats.\nPrecisely, the NASNet search space learns two types of cells for network construction:\n Normal Cell: The input and output feature maps have the same dimension. Reduction Cell: The output feature map has its width and height reduced by half. Fig. 3. The NASNet search space constrains the architecture as a repeated stack of cells. The cell architecture is optimized via NAS algorithms. (Image source: Zoph et al. 2018) The predictions for each cell are grouped into $B$ blocks ($B=5$ in the NASNet paper), where each block has 5 prediction steps made by 5 distinct softmax classifiers corresponding to discrete choices of the elements of a block. Note that the NASNet search space does not have residual connections between cells and the model only learns skip connections on their own within blocks.\nFig. 4. (a) Each cell consists of $B$ blocks and each block is predicted by 5 discrete decisions. (b) An concrete example of what operations can be chosen in each decision step. During the experiments, they discovered that a modified version of DropPath, named ScheduledDropPath, significantly improves the final performance of NASNet experiments. DropPath stochastically drops out paths (i.e. edges with operations attached in NASNet) with a fixed probability. ScheduledDropPath is DropPath with a linearly increasing probability of path dropping during training time.\nElsken, et al (2019) point out three major advantages of the NASNet search space:\n The search space size is reduced drastically; The motif-based architecture can be more easily transferred to different datasets. It demonstrates a strong proof of a useful design pattern of repeatedly stacking modules in architecture engineering. For example, we can build strong models by stacking residual blocks in CNN or stacking multi-headed attention blocks in Transformer. Hierarchical Structure To take advantage of already discovered well-designed network motifs, the NAS search space can be constrained as a hierarchical structure, as in Hierarchical NAS (HNAS; (Liu et al 2017)). It starts with a small set of primitives, including individual operations like convolution operation, pooling, identity, etc. Then small sub-graphs (or \u0026ldquo;motifs\u0026rdquo;) that consist of primitive operations are recursively used to form higher-level computation graphs.\nA computation motif at level $\\ell=1, \\dots, L$ can be represented by $(G^{(\\ell)}, \\mathcal{O}^{(\\ell)})$, where:\n $\\mathcal{O}^{(\\ell)}$ is a set of operations, $\\mathcal{O}^{(\\ell)} = \\{ o^{(\\ell)}_1, o^{(\\ell)}_2, \\dots \\}$ $G^{(\\ell)}$ is an adjacency matrix, where the entry $G_{ij}=k$ indicates that operation $o^{(\\ell)}_k$ is placed between node $i$ and $j$. The node indices follow topological ordering in DAG, where the index $1$ is the source and the maximal index is the sink node. Fig. 5. (Top) Three level-1 primitive operations are composed into a level-2 motif. (Bottom) Three level-2 motifs are plugged into a base network structure and assembled into a level-3 motif. (Image source: Liu et al 2017) To build a network according to the hierarchical structure, we start from the lowest level $\\ell=1$ and recursively define the $m$-th motif operation at level $\\ell$ as\n $$ o^{(\\ell)}_m = \\text{assemble}\\Big( G_m^{(\\ell)}, \\mathcal{O}^{(\\ell-1)} \\Big) $$ A hierarchical representation becomes $\\Big( \\big\\{ \\{ G_m^{(\\ell)} \\}_{m=1}^{M_\\ell} \\big\\}_{\\ell=2}^L, \\mathcal{O}^{(1)} \\Big), \\forall \\ell=2, \\dots, L$, where $\\mathcal{O}^{(1)}$ contains a set of primitive operations.\nThe $\\text{assemble}()$ process is equivalent to sequentially compute the feature map of node $i$ by aggregating all the feature maps of its predecessor node $j$ following the topological ordering:\n $$ x_i = \\text{merge} \\big[ \\{ o^{(\\ell)}_{G^{(\\ell)}_{ij}}(x_j) \\}_{j where $\\text{merge}[]$ is implemented as depth-wise concatenation in the paper.\nSame as NASNet, experiments in Liu et al (2017) focused on discovering good cell architecture within a predefined \u0026ldquo;macro\u0026rdquo; structure with repeated modules. They showed that the power of simple search methods (e.g. random search or evolutionary algorithms) can be substantially enhanced using well-designed search spaces.\nCai et al (2018b) propose a tree-structure search space using path-level network transformation. Each node in a tree structure defines an allocation scheme for splitting inputs for child nodes and a merge scheme for combining results from child nodes. The path-level network transformation allows replacing a single layer with a multi-branch motif if its corresponding merge scheme is add or concat.\nFig. 6. An illustration of transforming a single layer to a tree-structured motif via path-level transformation operations. (Image source: Cai et al. 2018b) Memory-bank Representation A memory-bank representation of feed-forward networks is proposed by Brock et al. (2017) in SMASH. Instead of a graph of operations, they view a neural network as a system with multiple memory blocks which can read and write. Each layer operation is designed to: (1) read from a subset of memory blocks; (2) computes results; finally (3) write the results into another subset of blocks. For example, in a sequential model, a single memory block would get read and overwritten consistently.\nFig. 7. Memory-bank representation of several popular network architecture blocks. (Image source: Brock et al. 2017) Search Algorithms NAS search algorithms sample a population of child networks. It receives the child models' performance metrics as rewards and learns to generate high-performance architecture candidates. You may a lot in common with the field of hyperparameter search.\nRandom Search Random search is the most naive baseline. It samples a valid architecture candidate from the search space at random and no learning model is involved. Random search has proved to be quite useful in hyperparameter search (Bergstra \u0026amp; Bengio 2012). With a well-designed search space, random search could be a very challenging baseline to beat.\nReinforcement Learning The initial design of NAS (Zoph \u0026amp; Le 2017) involves a RL-based controller for proposing child model architectures for evaluation. The controller is implemented as a RNN, outputting a variable-length sequence of tokens used for configuring a network architecture.\nFig. 8. A high level overview of NAS, containing a RNN controller and a pipeline for evaluating child models. (Image source: Zoph \u0026 Le 2017) The controller is trained as a RL task using REINFORCE.\n Action space: The action space is a list of tokens for defining a child network predicted by the controller (See more in the above section). The controller outputs action, $a_{1:T}$, where $T$ is the total number of tokens. Reward: The accuracy of a child network that can be achieved at convergence is the reward for training the controller, $R$. Loss: NAS optimizes the controller parameters $\\theta$ with a REINFORCE loss. We want to maximize the expected reward (high accuracy) with the gradient as follows. The nice thing here with policy gradient is that it works even when the reward is non-differentiable. $$ \\nabla_{\\theta} J(\\theta) = \\sum_{t=1}^T \\mathbb{E}[\\nabla_{\\theta} \\log P(a_t \\vert a_{1:(t-1)}; \\theta) R ] $$ MetaQNN (Baker et al. 2017) trains an agent to sequentially choose CNN layers using Q-learning with an $\\epsilon$-greedy exploration strategy and experience replay. The reward is the validation accuracy as well.\n $$ Q^{(t+1)}(s_t, a_t) = (1 - \\alpha)Q^{(t)}(s_t, a_t) + \\alpha (R_t + \\gamma \\max_{a \\in \\mathcal{A}} Q^{(t)}(s_{t+1}, a')) $$ where a state $s_t$ is a tuple of layer operation and related parameters. An action $a$ determines the connectivity between operations. The Q-value is proportional to how confident we are in two connected operations leading to high accuracy.\nFig. 9. Overview of MetaQNN - designing CNN models with Q-Learning. (Image source: Baker et al. 2017) Evolutionary Algorithms NEAT (short for NeuroEvolution of Augmenting Topologies) is an approach for evolving neural network topologies with genetic algorithm (GA), proposed by Stanley \u0026amp; Miikkulainen in 2002. NEAT evolves both connection weights and network topology together. Each gene encodes the full information for configuring a network, including node weights and edges. The population grows by applying mutation of both weights and connections, as well as crossover between two parent genes. For more in neuroevolution, please refer to the in-depth survey by Stanley et al. (2019).\nFig. 10. Mutations in the NEAT algorithm. (Image source: Fig 3 \u0026 4 in Stanley \u0026 Miikkulainen, 2002) Real et al. (2018) adopt the evolutionary algorithms (EA) as a way to search for high-performance network architectures, named AmoebaNet. They apply the tournament selection method, which at each iteration picks a best candidate out of a random set of samples and places its mutated offspring back into the population. When the tournament size is $1$, it is equivalent to random selection.\nAmoebaNet modified the tournament selection to favor younger genotypes and always discard the oldest models within each cycle. Such an approach, named aging evolution, allows AmoebaNet to cover and explore more search space, rather than to narrow down on good performance models too early.\nPrecisely, in every cycle of the tournament selection with aging regularization (See Figure 11):\n Sample $S$ models from the population and the one with highest accuracy is chosen as parent. A child model is produced by mutating parent. Then the child model is trained, evaluated and added back into the population. The oldest model is removed from the population. Fig. 11. The algorithm of aging evolution. (Image source: Real et al. 2018) Two types of mutations are applied:\n Hidden state mutation: randomly chooses a pairwise combination and rewires a random end such that there is no loop in the graph. Operation mutation: randomly replaces an existing operation with a random one. Fig. 12. Two types of mutations in AmoebaNet. (Image source: Real et al. 2018) In their experiments, EA and RL work equally well in terms of the final validation accuracy, but EA has better anytime performance and is able to find smaller models. Here using EA in NAS is still expensive in terms of computation, as each experiment took 7 days with 450 GPUs.\nHNAS (Liu et al 2017) also employs the evolutionary algorithms (the original tournament selection) as their search strategy. In the hierarchical structure search space, each edge is an operation. Thus genotype mutation in their experiments is applied by replacing a random edge with a different operation. The replacement set includes an none op, so it can alter, remove and add an edge. The initial set of genotypes is created by applying a large number of random mutations on \u0026ldquo;trivial\u0026rdquo; motifs (all identity mappings).\nProgressive Decision Process Constructing a model architecture is a sequential process. Every additional operator or layer brings extra complexity. If we guide the search model to start the investigation from simple models and gradually evolve to more complex architectures, it is like to introduce \u0026ldquo;curriculum\u0026rdquo; into the search model\u0026rsquo;s learning process.\nProgressive NAS (PNAS; Liu, et al 2018) frames the problem of NAS as a progressive procedure for searching models of increasing complexity. Instead of RL or EA, PNAS adopts a Sequential Model-based Bayesian Optimization (SMBO) as the search strategy. PNAS works similar to A* search, as it searches for models from simple to hard while simultaneously learning a surrogate function to guide the search.\n A* search algorithm (\u0026ldquo;best-first search\u0026rdquo;) is a popular algorithm for path finding. The problem is framed as finding a path of smallest cost from a specific starting node to a given target node in a weighted graph. At each iteration, A* finds a path to extend by minimizing: $f(n)=g(n)+h(n)$, where $n$ is the next node, $g(n)$ is the cost from start to $n$, and $h(n)$ is the heuristic function that estimates the minimum cost of going from node $n$ to the goal.\n PNAS uses the NASNet search space. Each block is specified as a 5-element tuple and PNAS only considers the element-wise addition as the step 5 combination operator, no concatenation. Differently, instead of setting the number of blocks $B$ at a fixed number, PNAS starts with $B=1$, a model with only one block in a cell, and gradually increases $B$.\nThe performance on a validation set is used as feedback to train a surrogate model for predicting the performance of novel architectures. With this predictor, we can thus decide which models should be prioritized to be evaluated next. Since the performance predictor should be able to handle various-sized inputs, accuracy, and sample-efficient, they ended up using an RNN model.\nFig. 13. The algorithm of Progressive NAS. (Image source: Liu, et al 2018) Gradient descent Using gradient descent to update the architecture search model requires an effort to make the process of choosing discrete operations differentiable. These approaches usually combine the learning of both architecture parameters and network weights together into one model. See more in the section on the \u0026ldquo;one-shot\u0026rdquo; approach.\nEvaluation Strategy We need to measure, estimate or predict the performance of every child model in order to obtain feedback for optimizing the search algorithm. The process of candidate evaluation could be very expensive and many new evaluation methods have been proposed to save time or computation. When evaluating a child model, we mostly care about its performance measured as accuracy on a validation set. Recent work has started looking into other factors of a model, such as model size and latency, as certain devices may have limitations on memory or demand fast response time.\nTraining from Scratch The most naive approach is to train every child network independently from scratch until convergence and then measure its accuracy on a validation set (Zoph \u0026amp; Le 2017). It provides solid performance numbers, but one complete train-converge-evaluate loop only generates a single data sample for training the RL controller (let alone RL is known to be sample-inefficient in general). Thus it is very expensive in terms of computation consumption.\nProxy Task Performance There are several approaches for using a proxy task performance as the performance estimator of a child network, which is generally cheaper and faster to calculate:\n Train on a smaller dataset. Train for fewer epochs. Train and evaluate a down-scaled model in the search stage. For example, once a cell structure is learned, we can play with the number of cell repeats or scale up the number of filters (Zoph et al. 2018). Predict the learning curve. Baker et al (2018) model the prediction of validation accuracies as a time-series regression problem. The features for the regression model ($\\nu$-support vector machine regressions; $\\nu$-SVR) include the early sequences of accuracy per epoch, architecture parameters, and hyperparameters. Parameter Sharing Instead of training every child model independently from scratch. You may ask, ok, what if we fabricate dependency between them and find a way to reuse weights? Some researchers succeeded to make such approaches work.\nInspired by Net2net transformation, Cai et al (2017) proposed Efficient Architecture Search (EAS). EAS sets up an RL agent, known as a meta-controller, to predict function-preserving network transformation so as to grow the network depth or layer width. Because the network is growing incrementally, the weights of previously validated networks can be reused for further exploration. With inherited weights, newly constructed networks only need some light-weighted training.\nA meta-controller learns to generate network transformation actions given the current network architecture, which is specified with a variable-length string. In order to handle architecture configuration of a variable length, the meta-controller is implemented as a bi-directional recurrent network. Multiple actor networks output different transformation decisions:\n Net2WiderNet operation allows to replace a layer with a wider layer, meaning more units for fully-connected layers, or more filters for convolutional layers, while preserving the functionality. Net2DeeperNet operation allows to insert a new layer that is initialized as adding an identity mapping between two layers so as to preserve the functionality. Fig. 14. Overview of the RL based meta-controller in Efficient Architecture Search (NAS). After encoding the architecture configuration, it outputs net2net transformation actions through two separate actor networks. (Image source: Cai et al 2017) With similar motivation, Efficient NAS (ENAS; Pham et al. 2018) speeds up NAS (i.e. 1000x less) by aggressively sharing parameters among child models. The core motivation behind ENAS is the observation that all of the sampled architecture graphs can be viewed as sub-graphs of a larger supergraph. All the child networks are sharing weights of this supergraph.\nFig. 15. (Left) The graph represents the entire search space for a 4-node recurrent cell, but only connections in red are active. (Middle) An example of how the left active sub-graph can be translated into a child model architecture. (Right) The network parameters produced by an RNN controller for the architecture in the middle. (Image source: Pham et al. 2018) ENAS alternates between training the shared model weights $\\omega$ and training the controller $\\theta$:\n The parameters of the controller LSTM $\\theta$ are trained with REINFORCE, where the reward $R(\\mathbf{m}, \\omega)$ is computed on the validation set. The shared parameters of the child models $\\omega$ are trained with standard supervised learning loss. Note that different operators associated with the same node in the supergraph would have their own distinct parameters. Prediction-Based A routine child model evaluation loop is to update model weights via standard gradient descent. SMASH (Brock et al. 2017) proposes a different and interesting idea: Can we predict the model weights directly based on the network architecture parameters?\nThey employ a HyperNet (Ha et al 2016) to directly generate the weights of a model conditioned on an encoding of its architecture configuration. Then the model with HyperNet-generated weights is validated directly. Note that we don\u0026rsquo;t need extra training for every child model but we do need to train the HyperNet.\nFig. 16. The algorithm of SMASH. (Image source: Brock et al. 2017) The correlation between model performance with SMASH-generated weights and true validation errors suggests that predicted weights can be used for model comparison, to some extent. We do need a HyperNet of large enough capacity, as the correlation would be corrupted if the HyperNet model is too small compared to the child model size.\nFig. 17. The algorithm of SMASH. (Image source: Brock et al. 2017) SMASH can be viewed as another way to implement the idea of parameter sharing. One problem of SMASH as pointed out by Pham et al. (2018) is: The usage of HyperNet restricts the weights of SMASH child models to a low-rank space, because weights are generated via tensor products. In comparison, ENAS has no such restrictions.\nOne-Shot Approach: Search + Evaluation Running search \u0026amp; evaluation independently for a large population of child models is expensive. We have seen promising approaches like Brock et al. (2017) or Pham et al. (2018), where training a single model is enough for emulating any child model in the search space.\nThe one-shot architecture search extends the idea of weight sharing and further combines the learning of architecture generation together with weight parameters. The following approaches all treat child architectures as different sub-graphs of a supergraph with shared weights between common edges in the supergraph.\nBender et al (2018) construct a single large over-parameterized network, known as the One-Shot model, such that it contains every possible operation in the search space. With ScheduledDropPath (the dropout rate is increased over time, which is $r^{1/k}$ at the end of training, where $0 \u0026lt; r \u0026lt; 1$ is a hyperparam and $k$ is the number of incoming paths) and some carefully designed tricks (e.g. ghost batch normalization, L2 regularization only on the active architecture), the training of such a giant model can be stabilized enough and used for evaluating any child model sampled from the supergraph.\nFig. 18. The architecture of the One-Shot model in Bender et al 2018. Each cell has $N$ choice blocks and each choice block can select up to 2 operations. Solid edges are used in every architecture, where dash lines are optional. (Image source: Bender et al 2018) Once the one-shot model is trained, it is used for evaluating the performance of many different architectures sampled at random by zeroing out or removing some operations. This sampling process can be replaced by RL or evolution.\nThey observed that the difference between the accuracy measured with the one-shot model and the accuracy of the same architecture after a small fine-tuning could be very large. Their hypothesis is that the one-shot model automatically learns to focus on the most useful operations in the network and comes to rely on these operations when they are available. Thus zeroing out useful operations lead to big reduction in model accuracy, while removing less important components only causes a small impact \u0026mdash; Therefore, we see a larger variance in scores when using the one-shot model for evaluation.\nFig. 19. A stratified sample of models with different one-shot model accuracy versus their true validation accuracy as stand-alone models. (Image source: Bender et al 2018) Clearly designing such a search graph is not a trivial task, but it demonstrates a strong potential with the one-shot approach. It works well with only gradient descent and no additional algorithm like RL or EA is a must.\nSome believe that one main cause for inefficiency in NAS is to treat the architecture search as a black-box optimization and thus we fall into methods like RL, evolution, SMBO, etc. If we shift to rely on standard gradient descent, we could potentially make the search process more effectively. As a result, Liu et al (2019) propose Differentiable Architecture Search (DARTS). DARTS introduces a continuous relaxation on each path in the search supergraph, making it possible to jointly train architecture parameters and weights via gradient descent.\nLet\u0026rsquo;s use the directed acyclic graph (DAG) representation here. A cell is a DAG consisting of a topologically ordered sequence of $N$ nodes. Each node has a latent representation $x_i$ to be learned. Each edge $(i, j)$ is tied to some operation $o^{(i,j)} \\in \\mathcal{O}$ that transforms $x_j$ to compose $x_i$:\n $$ x_i = \\sum_{j To make the search space continuous, DARTS relaxes the categorical choice of a particular operation as a softmax over all the operations and the task of architecture search is reduced to learn a set of mixing probabilities $\\alpha = \\{ \\alpha^{(i,j)} \\}$.\n $$ \\bar{o}^{(i,j)}(x) = \\sum_{o\\in\\mathcal{O}} \\frac{\\exp(\\alpha_{ij}^o)}{\\sum_{o'\\in\\mathcal{O}} \\exp(\\alpha^{o'}_{ij})} o(x) $$ where $\\alpha_{ij}$ is a vector of dimension $\\vert \\mathcal{O} \\vert$, containing weights between nodes $i$ and $j$ over different operations.\nThe bilevel optimization exists as we want to optimize both the network weights $w$ and the architecture representation $\\alpha$:\n $$ \\begin{aligned} \\min_\\alpha \u0026 \\mathcal{L}_\\text{validate} (w^*(\\alpha), \\alpha) \\\\ \\text{s.t.} \u0026 w^*(\\alpha) = \\arg\\min_w \\mathcal{L}_\\text{train} (w, \\alpha) \\end{aligned} $$ At step $k$, given the current architecture parameters $\\alpha_{k−1}$, we first optimize weights $w_k$ by moving $w_{k−1}$ in the direction of minimizing the training loss $\\mathcal{L}_\\text{train}(w_{k−1}, \\alpha_{k−1})$ with a learning rate $\\xi$. Next, while keeping the newly updated weights $w_k$ fixed, we update the mixing probabilities so as to minimize the validation loss after a single step of gradient descent w.r.t. the weights:\n $$ J_\\alpha = \\mathcal{L}_\\text{val}(w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}), \\alpha_{k-1}) $$ The motivation here is that we want to find an architecture with a low validation loss when its weights are optimized by gradient descent and the one-step unrolled weights serve as the surrogate for $w^∗(\\alpha)$.\n Side note: Earlier we have seen similar formulation in MAML where the two-step optimization happens between task losses and the meta-learner update, as well as framing Domain Randomization as a bilevel optimization for better transfer in the real environment.\n Fig. 20. An illustration of how DARTS applies continuous relaxation on edges in DAG supergraph and identifies the final model. (Image source: Liu et al 2019) $$ \\begin{aligned} \\text{Let }w'_k \u0026= w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}) \u0026 \\\\ J_\\alpha \u0026= \\mathcal{L}_\\text{val}(w_k - \\xi \\nabla_w \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1}), \\alpha_{k-1}) = \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \u0026 \\\\ \\nabla_\\alpha J_\\alpha \u0026= \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\nabla_\\alpha \\alpha_{k-1} + \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})\\nabla_\\alpha w'_k \u0026 \\\\\u0026 \\text{; multivariable chain rule}\\\\ \u0026= \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) + \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\big( - \\xi \\color{red}{\\nabla^2_{\\alpha, w} \\mathcal{L}_\\text{train}(w_k, \\alpha_{k-1})} \\big) \u0026 \\\\ \u0026\\approx \\nabla_{\\alpha_{k-1}} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) - \\xi \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1}) \\color{red}{\\frac{\\nabla_\\alpha \\mathcal{L}_\\text{train}(w_k^+, \\alpha_{k-1}) - \\nabla_\\alpha \\mathcal{L}_\\text{train}(w_k^-, \\alpha_{k-1}) }{2\\epsilon}} \u0026 \\\\ \u0026 \\text{; apply numerical differentiation approximation} \\end{aligned} $$ where the red part is using numerical differentiation approximation where $w_k^+ = w_k + \\epsilon \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})$ and $w_k^- = w_k - \\epsilon \\nabla_{w'_k} \\mathcal{L}_\\text{val}(w'_k, \\alpha_{k-1})$.\nFig. 21. The algorithm overview of DARTS. (Image source: Liu et al 2019) As another idea similar to DARTS, Stochastic NAS (Xie et al., 2019) applies a continuous relaxation by employing the concrete distribution (CONCRETE = CONtinuous relaxations of disCRETE random variables; Maddison et al 2017) and reparametrization tricks. The goal is same as DARTS, to make the discrete distribution differentiable and thus enable optimization by gradient descent.\nDARTS is able to greatly reduce the cost of GPU hours. Their experiments for searching for CNN cells have $N=7$ and only took 1.5 days with a single GPU. However, it suffers from the high GPU memory consumption issue due to its continuous representation of network architecture. In order to fit the model into the memory of a single GPU, they picked a small $N$.\nTo constrain the GPU memory consumption, ProxylessNAS (Cai et al., 2019) views NAS as a path-level pruning process in DAG and binarizes the architecture parameters to force only one path to be active between two nodes at a time. The probabilities for an edge being either masked out or not are then learned by sampling a few binarized architectures and using BinaryConnect (Courbariaux et al., 2015) to update the corresponding probabilities. ProxylessNAS demonstrates a strong connection between NAS and model compression. By using path-level compression, it is able to save memory consumption by one order of magnitude.\nLet\u0026rsquo;s continue with the graph representation. In a DAG adjacency matrix $G$ where $G_{ij}$ represents an edge between node $i$ and $j$ and its value can be chosen from the set of $\\vert \\mathcal{O} \\vert$ candidate primitive operations, $\\mathcal{O} = \\{ o_1, \\dots \\}$. The One-Shot model, DARTS and ProxylessNAS all consider each edge as a mixture of operations, $m_\\mathcal{O}$, but with different tweaks.\nIn One-Shot, $m_\\mathcal{O}(x)$ is the sum of all the operations. In DARTS, it is a weighted sum where weights are softmax over a real-valued architecture weighting vector $\\alpha$ of length $\\vert \\mathcal{O} \\vert$. ProxylessNAS transforms the softmax probabilities of $\\alpha$ into a binary gate and uses the binary gate to keep only one operation active at a time.\n $$ \\begin{aligned} m^\\text{one-shot}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} o_i(x) \\\\ m^\\text{DARTS}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} p_i o_i(x) = \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\exp(\\alpha_i)}{\\sum_j \\exp(\\alpha_j)} o_i(x) \\\\ m^\\text{binary}_\\mathcal{O}(x) \u0026= \\sum_{i=1}^{\\vert \\mathcal{O} \\vert} g_i o_i(x) = \\begin{cases} o_1(x) \u0026 \\text{with probability }p_1, \\\\ \\dots \u0026\\\\ o_{\\vert \\mathcal{O} \\vert}(x) \u0026 \\text{with probability }p_{\\vert \\mathcal{O} \\vert} \\end{cases} \\\\ \\text{ where } g \u0026= \\text{binarize}(p_1, \\dots, p_N) = \\begin{cases} [1, 0, \\dots, 0] \u0026 \\text{with probability }p_1, \\\\ \\dots \u0026 \\\\ [0, 0, \\dots, 1] \u0026 \\text{with probability }p_N. \\\\ \\end{cases} \\end{aligned} $$ Fig. 22. ProxylessNAS has two training steps running alternatively. (Image source: Cai et al., 2019) ProxylessNAS runs two training steps alternatively:\n When training weight parameters $w$, it freezes the architecture parameters $\\alpha$ and stochastically samples binary gates $g$ according to the above $m^\\text{binary}_\\mathcal{O}(x)$. The weight parameters can be updated with standard gradient descent. When training architecture parameters $\\alpha$, it freezes $w$, resets the binary gates and then updates $\\alpha$ on the validation set. Following the idea of BinaryConnect, the gradient w.r.t. architecture parameters can be approximately estimated using $\\partial \\mathcal{L} / \\partial g_i$ in replacement for $\\partial \\mathcal{L} / \\partial p_i$: $$ \\begin{aligned} \\frac{\\partial \\mathcal{L}}{\\partial \\alpha_i} \u0026= \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial p_j} \\frac{\\partial p_j}{\\partial \\alpha_i} \\approx \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\partial p_j}{\\partial \\alpha_i} = \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\partial \\frac{e^{\\alpha_j}}{\\sum_k e^{\\alpha_k}}}{\\partial \\alpha_i} \\\\ \u0026= \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} \\frac{\\sum_k e^{\\alpha_k} (\\mathbf{1}_{i=j} e^{\\alpha_j}) - e^{\\alpha_j} e^{\\alpha_i} }{(\\sum_k e^{\\alpha_k})^2} = \\sum_{j=1}^{\\vert \\mathcal{O} \\vert} \\frac{\\partial \\mathcal{L}}{\\partial g_j} p_j (\\mathbf{1}_{i=j} -p_i) \\end{aligned} $$ Instead of BinaryConnect, REINFORCE can also be used for parameter updates with the goal for maximizing the reward, while no RNN meta-controller is involved.\nComputing $\\partial \\mathcal{L} / \\partial g_i$ needs to calculate and store $o_i(x)$, which requires $\\vert \\mathcal{O} \\vert$ times GPU memory. To resolve this issue, they factorize the task of choosing one path out of $N$ into multiple binary selection tasks (Intuition: \u0026ldquo;if a path is the best choice, it should be better than any other path\u0026rdquo;). At every update step, only two paths are sampled while others are masked. These two selected paths are updated according to the above equation and then scaled properly so that other path weights are unchanged. After this process, one of the sampled paths is enhanced (path weight increases) and the other is attenuated (path weight decreases), while all other paths stay unaltered.\nBesides accuracy, ProxylessNAS also considers latency as an important metric to optimize, as different devices might have very different requirements on inference time latency (e.g. GPU, CPU, mobile). To make latency differentiable, they model latency as a continuous function of the network dimensions. The expected latency of a mixed operation can be written as $\\mathbb{E}[\\text{latency}] = \\sum_j p_j F(o_j)$, where $F(.)$ is a latency prediction model:\nFig. 23. Add a differentiable latency loss into the training of ProxylessNAS. (Image source: Cai et al., 2019) What\u0026rsquo;s the Future? So far we have seen many interesting new ideas on automating the network architecture engineering through neural architecture search and many have achieved very impressive performance. However, it is a bit hard to do inference on why some architecture work well and how we can develop modules generalizable across tasks rather than being very dataset-specific.\nAs also noted in Elsken, et al (2019):\n \u0026ldquo;\u0026hellip;, so far it provides little insights into why specific architectures work well and how similar the architectures derived in independent runs would be. Identifying common motifs, providing an understanding why those motifs are important for high performance, and investigating if these motifs generalize over different problems would be desirable.\u0026rdquo;\n In the meantime, purely focusing on improvement over validation accuracy might not be enough (Cai et al., 2019). Devices like mobile phones for daily usage in general have limited memory and computation power. While AI applications are on the way to affect our daily life, it is unavoidable to be more device-specific.\nAnother interesting investigation is to consider unlabelled dataset and self-supervised learning for NAS. The size of labelled dataset is always limited and it is not easy to tell whether such a dataset has biases or big deviation from the real world data distribution.\nLiu et al (2020) delve into the question \u0026ldquo;Can we find high-quality neural architecture without human-annotated labels?\u0026quot; and proposed a new setup called Unsupervised Neural Architecture Search (UnNAS). The quality of the architecture needs to be estimated in an unsupervised fashion during the search phase. The paper experimented with three unsupervised pretext tasks: image rotation prediction, colorization, and solving the jigsaw puzzle.\nThey observed in a set of UnNAS experiments that:\n High rank correlation between supervised accuracy and pretext accuracy on the same dataset. Typically the rank correlation is higher than 0.8, regardless of the dataset, the search space, and the pretext task. High rank correlation between supervised accuracy and pretext accuracy across datasets. Better pretext accuracy translates to better supervised accuracy. Performance of UnNAS architecture is comparable to supervised counterparts, though not better yet. One hypothesis is that the architecture quality is correlated with image statistics. Because CIFAR-10 and ImageNet are all on the natural images, they are comparable and the results are transferable. UnNAS could potentially enable a much larger amount of unlabelled data into the search phase which captures image statistics better.\nHyperparameter search is a long-standing topic in the ML community. And NAS automates architecture engineering. Gradually we are trying to automate processes in ML which usually demand a lot of human efforts. Taking even one more step further, is it possible to automatically discover ML algorithms? AutoML-Zero (Real et al 2020) investigates this idea. Using aging evolutionary algorithms, AutoML-Zero automatically searches for whole ML algorithms using little restriction on the form with only simple mathematical operations as building blocks.\nIt learns three component functions. Each function only adopts very basic operations.\n Setup: initialize memory variables (weights). Learn: modify memory variables Predict: make a prediction from an input $x$. Fig. 24. Algorithm evaluation on one task (Image source: Real et al 2020) Three types of operations are considered when mutating a parent genotype:\n Insert a random instruction or remove an instruction at a random location in a component function; Randomize all the instructions in a component function; Modify one of the arguments of an instruction by replacing it with a random choice (e.g. \u0026ldquo;swap the output address\u0026rdquo; or \u0026ldquo;change the value of a constant\u0026rdquo;) Fig. 25. An illustration of evolutionary progress on projected binary CIFAR-10 with example code. (Image source: Real et al 2020) Appendix: Summary of NAS Papers Model name Search space Search algorithms Child model evaluation NEAT (2002) - Evolution (Genetic algorithm) - NAS (2017) Sequential layer-wise ops RL (REINFORCE) Train from scratch until convergence MetaQNN (2017) Sequential layer-wise ops RL (Q-learning with $\\epsilon$-greedy) Train for 20 epochs HNAS (2017) Hierarchical structure Evolution (Tournament selection) Train for a fixed number of iterations NASNet (2018) Cell-based RL (PPO) Train for 20 epochs AmoebaNet (2018) NASNet search space Evolution (Tournament selection with aging regularization) Train for 25 epochs EAS (2018a) Network transformation RL (REINFORCE) 2-stage training PNAS (2018) Reduced version of NASNet search space SMBO; Progressive search for architectures of increasing complexity Train for 20 epochs ENAS (2018) Both sequential and cell-based search space RL (REINFORCE) Train one model with shared weights SMASH (2017) Memory-bank representation Random search HyperNet predicts weights of evaluated architectures. One-Shot (2018) An over-parameterized one-shot model Random search (zero out some paths at random) Train the one-shot model DARTS (2019) NASNet search space Gradient descent (Softmax weights over operations) ProxylessNAS (2019) Tree structure architecture Gradient descent (BinaryConnect) or REINFORCE SNAS (2019) NASNet search space Gradient descent (concrete distribution) Citation Cited as:\n Weng, Lilian. (Aug 2020). Neural architecture search. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-08-06-nas/.\n Or\n@article{weng2020nas, title = \u0026quot;Neural Architecture Search\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Aug\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-08-06-nas/\u0026quot; } Reference [1] Thomas Elsken, Jan Hendrik Metzen, Frank Hutter. \u0026ldquo;Neural Architecture Search: A Survey\u0026rdquo; JMLR 20 (2019) 1-21.\n[2] Kenneth O. Stanley, et al. \u0026ldquo;Designing neural networks through neuroevolution\u0026rdquo; Nature Machine Intelligence volume 1, pages 24–35 (2019).\n[3] Kenneth O. Stanley \u0026amp; Risto Miikkulainen. \u0026ldquo;Evolving Neural Networks through Augmenting Topologies\u0026rdquo; Evolutionary Computation 10(2): 99-127 (2002).\n[4] Barret Zoph, Quoc V. Le. \u0026ldquo;Neural architecture search with reinforcement learning\u0026rdquo; ICLR 2017.\n[5] Bowen Baker, et al. \u0026ldquo;Designing Neural Network Architectures using Reinforcement Learning\u0026rdquo; ICLR 2017.\n[6] Bowen Baker, et al. \u0026ldquo;Accelerating neural architecture search using performance prediction\u0026rdquo; ICLR Workshop 2018.\n[7] Barret Zoph, et al. \u0026ldquo;Learning transferable architectures for scalable image recognition\u0026rdquo; CVPR 2018.\n[8] Hanxiao Liu, et al. \u0026ldquo;Hierarchical representations for efficient architecture search.\u0026quot; ICLR 2018.\n[9] Esteban Real, et al. \u0026ldquo;Regularized Evolution for Image Classifier Architecture Search\u0026rdquo; arXiv:1802.01548 (2018).\n[10] Han Cai, et al. [\u0026ldquo;Efficient architecture search by network transformation\u0026rdquo;] AAAI 2018a.\n[11] Han Cai, et al. \u0026ldquo;Path-Level Network Transformation for Efficient Architecture Search\u0026rdquo; ICML 2018b.\n[12] Han Cai, Ligeng Zhu \u0026amp; Song Han. \u0026ldquo;ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware\u0026rdquo; ICLR 2019.\n[13] Chenxi Liu, et al. \u0026ldquo;Progressive neural architecture search\u0026rdquo; ECCV 2018.\n[14] Hieu Pham, et al. \u0026ldquo;Efficient neural architecture search via parameter sharing\u0026rdquo; ICML 2018.\n[15] Andrew Brock, et al. \u0026ldquo;SMASH: One-shot model architecture search through hypernetworks.\u0026quot; ICLR 2018.\n[16] Gabriel Bender, et al. \u0026ldquo;Understanding and simplifying one-shot architecture search.\u0026quot; ICML 2018.\n[17] Hanxiao Liu, Karen Simonyan, Yiming Yang. \u0026ldquo;DARTS: Differentiable Architecture Search\u0026rdquo; ICLR 2019.\n[18] Sirui Xie, Hehui Zheng, Chunxiao Liu, Liang Lin. \u0026ldquo;SNAS: Stochastic Neural Architecture Search\u0026rdquo; ICLR 2019.\n[19] Chenxi Liu et al. \u0026ldquo;Are Labels Necessary for Neural Architecture Search?\u0026quot; ECCV 2020.\n[20] Esteban Real, et al. \u0026ldquo;AutoML-Zero: Evolving Machine Learning Algorithms From Scratch\u0026rdquo; ICML 2020.\n","permalink":"https://lilianweng.github.io/posts/2020-08-06-nas/","summary":"Although most popular and successful model architectures are designed by human experts, it doesn\u0026rsquo;t mean we have explored the entire network architecture space and settled down with the best option. We would have a better chance to find the optimal solution if we adopt a systematic and automatic way of learning high-performance model architectures.\nAutomatically learning and evolving network topologies is not a new idea (Stanley \u0026amp; Miikkulainen, 2002). In recent years, the pioneering work by Zoph \u0026amp; Le 2017 and Baker et al.","title":"Neural Architecture Search"},{"content":"[Updated on 2020-06-17: Add \u0026ldquo;exploration via disagreement\u0026rdquo; in the \u0026ldquo;Forward Dynamics\u0026rdquo; section.\nExploitation versus exploration is a critical topic in Reinforcement Learning. We\u0026rsquo;d like the RL agent to find the best solution as fast as possible. However, in the meantime, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern RL algorithms that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more like an open topic.\nI would like to discuss several common exploration strategies in Deep RL here. As this is a very big topic, my post by no means can cover all the important subtopics. I plan to update it periodically and keep further enriching the content gradually in time.\nClassic Exploration Strategies As a quick recap, let\u0026rsquo;s first go through several classic exploration algorithms that work out pretty well in the multi-armed bandit problem or simple tabular RL.\n Epsilon-greedy: The agent does random exploration occasionally with probability $\\epsilon$ and takes the optimal action most of the time with probability $1-\\epsilon$. Upper confidence bounds: The agent selects the greediest action to maximize the upper confidence bound $\\hat{Q}_t(a) + \\hat{U}_t(a)$, where $\\hat{Q}_t(a)$ is the average rewards associated with action $a$ up to time $t$ and $\\hat{U}_t(a)$ is a function reversely proportional to how many times action $a$ has been taken. See here for more details. Boltzmann exploration: The agent draws actions from a boltzmann distribution (softmax) over the learned Q values, regulated by a temperature parameter $\\tau$. Thompson sampling: The agent keeps track of a belief over the probability of optimal actions and samples from this distribution. See here for more details. The following strategies could be used for better exploration in deep RL training when neural networks are used for function approximation:\n Entropy loss term: Add an entropy term $H(\\pi(a \\vert s))$ into the loss function, encouraging the policy to take diverse actions. Noise-based Exploration: Add noise into the observation, action or even parameter space (Fortunato, et al. 2017, Plappert, et al. 2017). Key Exploration Problems Good exploration becomes especially hard when the environment rarely provides rewards as feedback or the environment has distracting noise. Many exploration strategies are proposed to solve one or both of the following problems.\nThe Hard-Exploration Problem The \u0026ldquo;hard-exploration\u0026rdquo; problem refers to exploration in an environment with very sparse or even deceptive reward. It is difficult because random exploration in such scenarios can rarely discover successful states or obtain meaningful feedback.\nMontezuma\u0026rsquo;s Revenge is a concrete example for the hard-exploration problem. It remains as a few challenging games in Atari for DRL to solve. Many papers use Montezuma\u0026rsquo;s Revenge to benchmark their results.\nThe Noisy-TV Problem The \u0026ldquo;Noisy-TV\u0026rdquo; problem started as a thought experiment in Burda, et al (2018). Imagine that an RL agent is rewarded with seeking novel experience, a TV with uncontrollable \u0026amp; unpredictable random noise outputs would be able to attract the agent\u0026rsquo;s attention forever. The agent obtains new rewards from noisy TV consistently, but it fails to make any meaningful progress and becomes a \u0026ldquo;couch potato\u0026rdquo;.\nFig. 1. An agent is rewarded with novel experience in the experiment. If a maze has a noisy TC set up, the agent would be attracted and stop moving in the maze. (Image source: OpenAI Blog: \"Reinforcement Learning with Prediction-Based Rewards\") Intrinsic Rewards as Exploration Bonuses One common approach to better exploration, especially for solving the hard-exploration problem, is to augment the environment reward with an additional bonus signal to encourage extra exploration. The policy is thus trained with a reward composed of two terms, $r_t = r^e_t + \\beta r^i_t$, where $\\beta$ is a hyperparameter adjusting the balance between exploitation and exploration.\n $r^e_t$ is an extrinsic reward from the environment at time $t$, defined according to the task in hand. $r^i_t$ is an intrinsic exploration bonus at time $t$. This intrinsic reward is somewhat inspired by intrinsic motivation in psychology (Oudeyer \u0026amp; Kaplan, 2008). Exploration driven by curiosity might be an important way for children to grow and learn. In other words, exploratory activities should be rewarding intrinsically in the human mind to encourage such behavior. The intrinsic rewards could be correlated with curiosity, surprise, familiarity of the state, and many other factors.\nSame ideas can be applied to RL algorithms. In the following sections, methods of bonus-based exploration rewards are roughly grouped into two categories:\n Discovery of novel states Improvement of the agent\u0026rsquo;s knowledge about the environment. Count-based Exploration If we consider intrinsic rewards as rewarding conditions that surprise us, we need a way to measure whether a state is novel or appears often. One intuitive way is to count how many times a state has been encountered and to assign a bonus accordingly. The bonus guides the agent\u0026rsquo;s behavior to prefer rarely visited states to common states. This is known as the count-based exploration method.\nLet $N_n(s)$ be the empirical count function that tracks the real number of visits of a state $s$ in the sequence of $s_{1:n}$. Unfortunately, using $N_n(s)$ for exploration directly is not practical, because most of the states would have $N_n(s)=0$, especially considering that the state space is often continuous or high-dimensional. We need an non-zero count for most states, even when they haven\u0026rsquo;t been seen before.\nCounting by Density Model Bellemare, et al. (2016) used a density model to approximate the frequency of state visits and a novel algorithm for deriving a pseudo-count from this density model. Let\u0026rsquo;s first define a conditional probability over the state space, $\\rho_n(s) = \\rho(s \\vert s_{1:n})$ as the probability of the $(n+1)$-th state being $s$ given the first $n$ states are $s_{1:n}$. To measure this empirically, we can simply use $N_n(s)/n$.\nLet\u0026rsquo;s also define a recoding probability of a state $s$ as the probability assigned by the density model to $s$ after observing a new occurrence of $s$, $\\rho'_n(s) = \\rho(s \\vert s_{1:n}s)$.\nThe paper introduced two concepts to better regulate the density model, a pseudo-count function $\\hat{N}_n(s)$ and a pseudo-count total $\\hat{n}$. As they are designed to imitate an empirical count function, we would have:\n $$ \\rho_n(s) = \\frac{\\hat{N}_n(s)}{\\hat{n}} \\leq \\rho'_n(s) = \\frac{\\hat{N}_n(s) + 1}{\\hat{n} + 1} $$ The relationship between $\\rho_n(x)$ and $\\rho'_n(x)$ requires the density model to be learning-positive: for all $s_{1:n} \\in \\mathcal{S}^n$ and all $s \\in \\mathcal{S}$, $\\rho_n(s) \\leq \\rho'_n(s)$. In other words, After observing one instance of $s$, the density model\u0026rsquo;s prediction of that same $s$ should increase. Apart from being learning-positive, the density model should be trained completely online with non-randomized mini-batches of experienced states, so naturally we have $\\rho'_n = \\rho_{n+1}$.\nThe pseudo-count can be computed from $\\rho_n(s)$ and $\\rho'_n(s)$ after solving the above linear system:\n $$ \\hat{N}_n(s) = \\hat{n} \\rho_n(s) = \\frac{\\rho_n(s)(1 - \\rho'_n(s))}{\\rho'_n(s) - \\rho_n(s)} $$ Or estimated by the prediction gain (PG):\n $$ \\hat{N}_n(s) \\approx (e^{\\text{PG}_n(s)} - 1)^{-1} = (e^{\\log \\rho'_n(s) - \\log \\rho(s)} - 1)^{-1} $$ A common choice of a count-based intrinsic bonus is $r^i_t = N(s_t, a_t)^{-1/2}$ (as in MBIE-EB; Strehl \u0026amp; Littman, 2008). The pseudo-count-based exploration bonus is shaped in a similar form, $r^i_t = \\big(\\hat{N}_n(s_t, a_t) + 0.01 \\big)^{-1/2}$.\nExperiments in Bellemare et al., (2016) adopted a simple CTS (Context Tree Switching) density model to estimate pseudo-counts. The CTS model takes as input a 2D image and assigns to it a probability according to the product of location-dependent L-shaped filters, where the prediction of each filter is given by a CTS algorithm trained on past images. The CTS model is simple but limited in expressiveness, scalability, and data efficiency. In a following-up paper, Georg Ostrovski, et al. (2017) improved the approach by training a PixelCNN (van den Oord et al., 2016) as the density model.\nThe density model can also be a Gaussian Mixture Model as in Zhao \u0026amp; Tresp (2018). They used a variational GMM to estimate the density of trajectories (e.g. concatenation of a sequence of states) and its predicted probabilities to guide prioritization in experience replay in off-policy setting.\nCounting after Hashing Another idea to make it possible to count high-dimensional states is to map states into hash codes so that the occurrences of states become trackable (Tang et al. 2017). The state space is discretized with a hash function $\\phi: \\mathcal{S} \\mapsto \\mathbb{Z}^k$. An exploration bonus $r^{i}: \\mathcal{S} \\mapsto \\mathbb{R}$ is added to the reward function, defined as $r^{i}(s) = {N(\\phi(s))}^{-1/2}$, where $N(\\phi(s))$ is an empirical count of occurrences of $\\phi(s)$.\nTang et al. (2017) proposed to use Locality-Sensitive Hashing (LSH) to convert continuous, high-dimensional data to discrete hash codes. LSH is a popular class of hash functions for querying nearest neighbors based on certain similarity metrics. A hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. (See how LSH is used in Transformer improvement if interested.) SimHash is a type of computationally efficient LSH and it measures similarity by angular distance:\n $$ \\phi(s) = \\text{sgn}(A g(s)) \\in \\{-1, 1\\}^k $$ where $A \\in \\mathbb{R}^{k \\times D}$ is a matrix with each entry drawn i.i.d. from a standard Gaussian and $g: \\mathcal{S} \\mapsto \\mathbb{R}^D$ is an optional preprocessing function. The dimension of binary codes is $k$, controlling the granularity of the state space discretization. A higher $k$ leads to higher granularity and fewer collisions.\nFig. 2. Algorithm of count-based exploration through hashing high-dimensional states by SimHash. (Image source: Tang et al. 2017) For high-dimensional images, SimHash may not work well on the raw pixel level. Tang et al. (2017) designed an autoencoder (AE) which takes as input states $s$ to learn hash codes. It has one special dense layer composed of $k$ sigmoid functions as the latent state in the middle and then the sigmoid activation values $b(s)$ of this layer are binarized by rounding to their closest binary numbers $\\lfloor b(s)\\rceil \\in \\{0, 1\\}^D$ as the binary hash codes for state $s$. The AE loss over $n$ states includes two terms:\n $$ \\mathcal{L}(\\{s_n\\}_{n=1}^N) = \\underbrace{-\\frac{1}{N} \\sum_{n=1}^N \\log p(s_n)}_\\text{reconstruction loss} + \\underbrace{\\frac{1}{N} \\frac{\\lambda}{K} \\sum_{n=1}^N\\sum_{i=1}^k \\min \\big \\{ (1-b_i(s_n))^2, b_i(s_n)^2 \\big\\}}_\\text{sigmoid activation being closer to binary} $$ One problem with this approach is that dissimilar inputs $s_i, s_j$ may be mapped to identical hash codes but the AE still reconstructs them perfectly. One can imagine replacing the bottleneck layer $b(s)$ with the hash codes $\\lfloor b(s)\\rceil$, but then gradients cannot be back-propagated through the rounding function. Injecting uniform noise could mitigate this effect, as the AE has to learn to push the latent variable far apart to counteract the noise.\nPrediction-based Exploration The second category of intrinsic exploration bonuses are rewarded for improvement of the agent\u0026rsquo;s knowledge about the environment. The agent\u0026rsquo;s familiarity with the environment dynamics can be estimated through a prediction model. This idea of using a prediction model to measure curiosity was actually proposed quite a long time ago (Schmidhuber, 1991).\nForward Dynamics Learning a forward dynamics prediction model is a great way to approximate how much knowledge our model has obtained about the environment and the task MDPs. It captures an agent\u0026rsquo;s capability of predicting the consequence of its own behavior, $f: (s_t, a_t) \\mapsto s_{t+1}$. Such a model cannot be perfect (e.g. due to partial observation), the error $e(s_t, a_t) = | f(s_t, a_t) - s_{t+1} |^2_2$ can be used for providing intrinsic exploration rewards. The higher the prediction error, the less familiar we are with that state. The faster the error rate drops, the more learning progress signals we acquire.\nIntelligent Adaptive Curiosity (IAC; Oudeyer, et al. 2007) sketched an idea of using a forward dynamics prediction model to estimate learning progress and assigned intrinsic exploration reward accordingly.\nIAC relies on a memory which stores all the experiences encountered by the robot, $M=\\{(s_t, a_t, s_{t+1})\\}$ and a forward dynamics model $f$. IAC incrementally splits the state space (i.e. sensorimotor space in the context of robotics, as discussed in the paper) into separate regions based on the transition samples, using a process similar to how a decision tree is split: The split happens when the number of samples is larger than a threshold, and the variance of states in each leaf should be minimal. Each tree node is characterized by its exclusive set of samples and has its own forward dynamics predictor $f$, named \u0026ldquo;expert\u0026rdquo;.\nThe prediction error $e_t$ of an expert is pushed into a list associated with each region. The learning progress is then measured as the difference between the mean error rate of a moving window with offset $\\tau$ and the current moving window. The intrinsic reward is defined for tracking the learning progress: $r^i_t = \\frac{1}{k}\\sum_{i=0}^{k-1}(e_{t-i-\\tau} - e_{t-i})$, where $k$ is the moving window size. So the larger prediction error rate decrease we can achieve, the higher intrinsic reward we would assign to the agent. In other words, the agent is encouraged to take actions to quickly learn about the environment.\nFig. 3. Architecture of the IAC (Intelligent Adaptive Curiosity) module: the intrinsic reward is assigned w.r.t the learning progress in reducing prediction error of the dynamics model. (Image source: Oudeyer, et al. 2007) Stadie et al. (2015) trained a forward dynamics model in the encoding space defined by $\\phi$, $f_\\phi: (\\phi(s_t), a_t) \\mapsto \\phi(s_{t+1})$. The model\u0026rsquo;s prediction error at time $T$ is normalized by the maximum error up to time $t$, $\\bar{e}_t = \\frac{e_t}{\\max_{i \\leq t} e_i}$, so it is always between 0 and 1. The intrinsic reward is defined accordingly: $r^i_t = (\\frac{\\bar{e}_t(s_t, a_t)}{t \\cdot C})$, where $C \u0026gt; 0$ is a decay constant.\nEncoding the state space via $\\phi(.)$ is necessary, as experiments in the paper have shown that a dynamics model trained directly on raw pixels has very poor behavior \u0026mdash; assigning same exploration bonuses to all the states. In Stadie et al. (2015), the encoding function $\\phi$ is learned via an autocoder (AE) and $\\phi(.)$ is one of the output layers in AE. The AE can be statically trained using a set of images collected by a random agent, or dynamically trained together with the policy where the early frames are gathered using $\\epsilon$-greedy exploration.\nInstead of autoencoder, Intrinsic Curiosity Module (ICM; Pathak, et al., 2017) learns the state space encoding $\\phi(.)$ with a self-supervised inverse dynamics model. Predicting the next state given the agent\u0026rsquo;s own action is not easy, especially considering that some factors in the environment cannot be controlled by the agent or do not affect the agent. ICM believes that a good state feature space should exclude such factors because they cannot influence the agent\u0026rsquo;s behavior and thus the agent has no incentive for learning them. By learning an inverse dynamics model $g: (\\phi(s_t), \\phi(s_{t+1})) \\mapsto a_t$, the feature space only captures those changes in the environment related to the actions of our agent, and ignores the rest.\nGiven a forward model $f$, an inverse dynamics model $g$ and an observation $(s_t, a_t, s_{t+1})$:\n $$ g_{\\psi_I}(\\phi(s_t), \\phi(s_{t+1})) = \\hat{a}_t \\quad f_{\\psi_F}(\\phi(s_t), a_t) = \\hat{\\phi}(s_{t+1}) \\quad r_t^i = \\| \\hat{\\phi}(s_{t+1}) - \\phi(s_{t+1}) \\|_2^2 $$ Such $\\phi(.)$ is expected to be robust to uncontrollable aspects of the environment.\nFig. 4. ICM (Intrinsic Curiosity Module) assigns the forward dynamics prediction error to the agent as the intrinsic reward. This dynamics model operates in a state encoding space learned through an inverse dynamics model to exclude environmental factors that do not affect the agent's behavior. (Image source: Pathak, et al. 2017) Burda, Edwards \u0026amp; Pathak, et al. (2018) did a set of large-scale comparison experiments on purely curiosity-driven learning, meaning that only intrinsic rewards are provided to the agent. In this study, the reward is $r_t = r^i_t = | f(s_t, a_t) - \\phi(s_{t+1})|_2^2$. A good choice of $\\phi$ is crucial to learning forward dynamics, which is expected to be compact, sufficient and stable, making the prediction task more tractable and filtering out irrelevant observation.\nIn comparison of 4 encoding functions:\n Raw image pixels: No encoding, $\\phi(x) = x$. Random features (RF): Each state is compressed through a fixed random neural network. VAE: The probabilistic encoder is used for encoding, $\\phi(x) = q(z \\vert x)$. Inverse dynamic features (IDF): The same feature space as used in ICM. All the experiments have the reward signals normalized by a running estimation of standard deviation of the cumulative returns. And all the experiments are running in an infinite horizon setting to avoid \u0026ldquo;done\u0026rdquo; flag leaking information.\nFig. 5. The mean reward in different games when training with only curiosity signals, generated by different state encoding functions. (Image source: Burda, Edwards \u0026 Pathak, et al. 2018) Interestingly random features turn out to be quite competitive, but in feature transfer experiments (i.e. train an agent in Super Mario Bros level 1-1 and then test it in another level), learned IDF features can generalize better.\nThey also compared RF and IDF in an environment with a noisy TV on. Unsurprisingly the noisy TV drastically slows down the learning and extrinsic rewards are much lower in time.\nFig. 6. Experiments using RF and IDF feature encoding in an environment with noisy TV on or off. The plot tracks extrinsic reward per episode as the training progresses. (Image source: Burda, Edwards \u0026 Pathak, et al. 2018) The forward dynamics optimization can be modeled via variational inference as well. VIME (short for \u0026ldquo;Variational information maximizing exploration\u0026rdquo;; Houthooft, et al. 2017) is an exploration strategy based on maximization of information gain about the agent\u0026rsquo;s belief of environment dynamics. How much additional information has been obtained about the forward dynamics can be measured as the reduction in entropy.\nLet $\\mathcal{P}$ be the environment transition function, $p(s_{t+1}\\vert s_t, a_t; \\theta)$ be the forward prediction model, parameterized by $\\theta \\in \\Theta$, and $\\xi_t = \\{s_1, a_1, \\dots, s_t\\}$ be the trajectory history. We would like to reduce the entropy after taking a new action and observing the next state, which is to maximize the following:\n $$ \\begin{aligned} \u0026\\sum_t H(\\Theta \\vert \\xi_t, a_t) - H(\\Theta \\vert S_{t+1}, \\xi_t, a_t) \\\\ =\u0026 I(\\Theta; S_{t+1} \\vert \\xi_t, a_t) \\quad \\scriptstyle{\\text{; because } I(X; Y) = I(X) - I(X \\vert Y)} \\\\ =\u0026 \\mathbb{E}_{s_{t+1} \\sim \\mathcal{P}(.\\vert\\xi_t,a_t)} [D_\\text{KL}(p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \\| p(\\theta \\vert \\xi_t, a_t))] \\quad \\scriptstyle{\\text{; because } I(X; Y) = \\mathbb{E}_Y [D_\\text{KL} (p_{X \\vert Y} \\| p_X)]} \\\\ =\u0026 \\mathbb{E}_{s_{t+1} \\sim \\mathcal{P}(.\\vert\\xi_t,a_t)} [D_\\text{KL}(p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \\| p(\\theta \\vert \\xi_t))] \\quad \\scriptstyle{\\text{; because } \\theta \\text{ does not depend on } a_t} \\end{aligned} $$ While taking expectation over the new possible states, the agent is expected to take a new action to increase the KL divergence (\u0026ldquo;information gain\u0026rdquo;) between its new belief over the prediction model to the old one. This term can be added into the reward function as an intrinsic reward: $r^i_t = D_\\text{KL} [p(\\theta \\vert \\xi_t, a_t, s_{t+1}) | p(\\theta \\vert \\xi_t))]$.\nHowever, computing the posterior $p(\\theta \\vert \\xi_t, a_t, s_{t+1})$ is generally intractable.\n $$ \\begin{aligned} p(\\theta \\vert \\xi_t, a_t, s_{t+1}) \u0026= \\frac{p(\\theta \\vert \\xi_t, a_t) p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{p(s_{t+1}\\vert\\xi_t, a_t)} \\\\ \u0026= \\frac{p(\\theta \\vert \\xi_t) p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{p(s_{t+1}\\vert\\xi_t, a_t)} \u0026 \\scriptstyle{\\text{; because action doesn't affect the belief.}} \\\\ \u0026= \\frac{\\color{red}{p(\\theta \\vert \\xi_t)} p(s_{t+1} \\vert \\xi_t, a_t; \\theta)}{\\int_\\Theta p(s_{t+1}\\vert\\xi_t, a_t; \\theta) \\color{red}{p(\\theta \\vert \\xi_t)} d\\theta} \u0026 \\scriptstyle{\\text{; red part is hard to compute directly.}} \\end{aligned} $$ Since it is difficult to compute $p(\\theta\\vert\\xi_t)$ directly, a natural choice is to approximate it with an alternative distribution $q_\\phi(\\theta)$. With variational lower bound, we know the maximization of $q_\\phi(\\theta)$ is equivalent to maximizing $p(\\xi_t\\vert\\theta)$ and minimizing $D_\\text{KL}[q_\\phi(\\theta) | p(\\theta)]$.\nUsing the approximation distribution $q$, the intrinsic reward becomes:\n $$ r^i_t = D_\\text{KL} [q_{\\phi_{t+1}}(\\theta) \\| q_{\\phi_t}(\\theta))] $$ where $\\phi_{t+1}$ represents $q$\u0026rsquo;s parameters associated with the new relief after seeing $a_t$ and $s_{t+1}$. When used as an exploration bonus, it is normalized by division by the moving median of this KL divergence value.\nHere the dynamics model is parameterized as a Bayesian neural network (BNN), as it maintains a distribution over its weights. The BNN weight distribution $q_\\phi(\\theta)$ is modeled as a fully factorized Gaussian with $\\phi = \\{\\mu, \\sigma\\}$ and we can easily sample $\\theta \\sim q_\\phi(.)$. After applying a second-order Taylor expansion, the KL term $D_\\text{KL}[q_{\\phi + \\lambda \\Delta\\phi}(\\theta) | q_{\\phi}(\\theta)]$ can be estimated using Fisher Information Matrix $\\mathbf{F}_\\phi$, which is easy to compute, because $q_\\phi$ is factorized Gaussian and thus the covariance matrix is only a diagonal matrix. See more details in the paper, especially section 2.3-2.5.\nAll the methods above depend on a single prediction model. If we have multiple such models, we could use the disagreement among models to set the exploration bonus (Pathak, et al. 2019). High disagreement indicates low confidence in prediction and thus requires more exploration. Pathak, et al. (2019) proposed to train a set of forward dynamics models and to use the variance over the ensemble of model outputs as $r_t^i$. Precisely, they encode the state space with random feature and learn 5 models in the ensemble.\nFig. 7. Illustration of training architecture for self-supervised exploration via disagreement. (Image source: Pathak, et al. 2019) Because $r^i_t$ is differentiable, the intrinsic reward in the model could be directly optimized through gradient descent so as to inform the policy agent to change actions. This differentiable exploration approach is very efficient but limited by having a short exploration horizon.\nRandom Networks But, what if the prediction task is not about the environment dynamics at all? It turns out when the prediction is for a random task, it still can help exploration.\nDORA (short for \u0026ldquo;Directed Outreaching Reinforcement Action-Selection\u0026rdquo;; Fox \u0026amp; Choshen, et al. 2018) is a novel framework that injects exploration signals based on a newly introduced, task-independent MDP. The idea of DORA depends on two parallel MDPs:\n One is the original task MDP; The other is an identical MDP but with no reward attached: Rather, every state-action pair is designed to have value 0. The Q-value learned for the second MDP is called E-value. If the model cannot perfectly predict E-value to be zero, it is still missing information. Initially E-value is assigned with value 1. Such positive initialization can encourage directed exploration for better E-value prediction. State-action pairs with high E-value estimation don\u0026rsquo;t have enough information gathered yet, at least not enough to exclude their high E-values. To some extent, the logarithm of E-values can be considered as a generalization of visit counters.\nWhen using a neural network to do function approximation for E-value, another value head is added to predict E-value and it is simply expected to predict zero. Given a predicted E-value $E(s_t, a_t)$, the exploration bonus is $r^i_t = \\frac{1}{\\sqrt{-\\log E(s_t, a_t)}}$.\nSimilar to DORA, Random Network Distillation (RND; Burda, et al. 2018) introduces a prediction task independent of the main task. The RND exploration bonus is defined as the error of a neural network $\\hat{f}(s_t)$ predicting features of the observations given by a fixed randomly initialized neural network $f(s_t)$. The motivation is that given a new state, if similar states have been visited many times in the past, the prediction should be easier and thus has lower error. The exploration bonus is $r^i(s_t) = |\\hat{f}(s_t; \\theta) - f(s_t) |_2^2$.\nFig. 8. How RND (Random Network Distillation) works for providing an intrinsic reward. The features $O_{i+1} \\mapsto f_{i+1}$ are generated by a fixed random neural network. (Image source: OpenAI Blog: \"Reinforcement Learning with Prediction-Based Rewards\") Two factors are important in RND experiments:\n Non-episodic setting results in better exploration, especially when not using any extrinsic rewards. It means that the return is not truncated at \u0026ldquo;Game over\u0026rdquo; and intrinsic return can spread across multiple episodes. Normalization is important since the scale of the reward is tricky to adjust given a random neural network as a prediction target. The intrinsic reward is normalized by division by a running estimate of the standard deviations of the intrinsic return. The RND setup works well for resolving the hard-exploration problem. For example, maximizing the RND exploration bonus consistently finds more than half of the rooms in Montezuma\u0026rsquo;s Revenge.\nPhysical Properties Different from games in simulators, some RL applications like Robotics need to understand objects and intuitive reasoning in the physical world. Some prediction tasks require the agent to perform a sequence of interactions with the environment and to observe the corresponding consequences, such as estimating some hidden properties in physics (e.g. mass, friction, etc).\nMotivated by such ideas, Denil, et al. (2017) found that DRL agents can learn to perform necessary exploration to discover such hidden properties. Precisely they considered two experiments:\n \u0026ldquo;Which is heavier?\u0026quot; \u0026mdash; The agent has to interact with the blocks and infer which one is heavier. \u0026ldquo;Towers\u0026rdquo; \u0026mdash; The agent needs to infer how many rigid bodies a tower is composed of by knocking it down. The agent in the experiments first goes through an exploration phase to interact with the environment and to collect information. Once the exploration phase ends, the agent is asked to output a labeling action to answer the question. Then a positive reward is assigned to the agent if the answer is correct; otherwise a negative one is assigned. Because the answer requires a decent amount of interactions with items in the scene, the agent has to learn to efficiently play around so as to figure out the physics and the correct answer. The exploration naturally happens.\nIn their experiments, the agent is able to learn in both tasks with performance varied by the difficulty of the task. Although the paper didn\u0026rsquo;t use the physics prediction task to provide intrinsic reward bonus along with extrinsic reward associated with another learning task, rather it focused on the exploration tasks themselves. I do enjoy the idea of encouraging sophisticated exploration behavior by predicting hidden physics properties in the environment.\nMemory-based Exploration Reward-based exploration suffers from several drawbacks:\n Function approximation is slow to catch up. Exploration bonus is non-stationary. Knowledge fading, meaning that states cease to be novel and cannot provide intrinsic reward signals in time. Methods in this section rely on external memory to resolve disadvantages of reward bonus-based exploration.\nEpisodic Memory As mentioned above, RND is better running in an non-episodic setting, meaning the prediction knowledge is accumulated across multiple episodes. The exploration strategy, Never Give Up (NGU; Badia, et al. 2020a), combines an episodic novelty module that can rapidly adapt within one episode with RND as a lifelong novelty module.\nPrecisely, the intrinsic reward in NGU consists of two exploration bonuses from two modules, within one episode and across multiple episodes, respectively.\nThe short-term per-episode reward is provided by an episodic novelty module. It contains an episodic memory $M$, a dynamically-sized slot-based memory, and an IDF (inverse dynamics features) embedding function $\\phi$, same as the feature encoding in ICM\n At every step the current state embedding $\\phi(s_t)$ is added into $M$.\n The intrinsic bonus is determined by comparing how similar the current observation is to the content of $M$. A larger difference results in a larger bonus.\n $$ r^\\text{episodic}_t \\approx \\frac{1}{\\sqrt{\\sum_{\\phi_i \\in N_k} K(\\phi(x_t), \\phi_i)} + c} $$ where $K(x, y)$ is a kernel function for measuring the distance between two samples. $N_k$ is a set of $k$ nearest neighbors in $M$ according to $K(., .)$. $c$ is a small constant to keep the denominator non-zero. In the paper, $K(x, y)$ is configured to be the inverse kernel:\n $$ K(x, y) = \\frac{\\epsilon}{\\frac{d^2(x, y)}{d^2_m} + \\epsilon} $$ where $d(.,.)$ is Euclidean distance between two samples and $d_m$ is a running average of the squared Euclidean distance of the k-th nearest neighbors for better robustness. $\\epsilon$ is a small constant.\n Fig. 9. The architecture of NGU's embedding function (left) and reward generator (right). (Image source: Badia, et al. 2020a) The long-term across-episode novelty relies on RND prediction error in life-long novelty module. The exploration bonus is $\\alpha_t = 1 + \\frac{e^\\text{RND}(s_t) - \\mu_e}{\\sigma_e}$ where $\\mu_e$ and $\\sigma_e$ are running mean and std dev for RND error $e^\\text{RND}(s_t)$.\n However in the conclusion section of the RND paper, I noticed the following statement:\n\u0026ldquo;We find that the RND exploration bonus is sufficient to deal with local exploration, i.e. exploring the consequences of short-term decisions, like whether to interact with a particular object, or avoid it. However global exploration that involves coordinated decisions over long time horizons is beyond the reach of our method. \u0026quot;\nAnd this confuses me a bit how RND can be used as a good life-long novelty bonus provider. If you know why, feel free to leave a comment below.\n The final combined intrinsic reward is $r^i_t = r^\\text{episodic}_t \\cdot \\text{clip}(\\alpha_t, 1, L)$, where $L$ is a constant maximum reward scalar.\nThe design of NGU enables it to have two nice properties:\n Rapidly discourages revisiting the same state within the same episode; Slowly discourages revisiting states that have been visited many times across episodes. Later, built on top of NGU, DeepMind proposed \u0026ldquo;Agent57\u0026rdquo; (Badia, et al. 2020b), the first deep RL agent that outperforms the standard human benchmark on all 57 Atari games. Two major improvements in Agent57 over NGU are:\n A population of policies are trained in Agent57, each equipped with a different exploration parameter pair $\\{(\\beta_j, \\gamma_j)\\}_{j=1}^N$. Recall that given $\\beta_j$, the reward is constructed as $r_{j,t} = r_t^e + \\beta_j r^i_t$ and $\\gamma_j$ is the reward discounting factor. It is natural to expect policies with higher $\\beta_j$ and lower $\\gamma_j$ to make more progress early in training, while the opposite would be expected as training progresses. A meta-controller (sliding-window UCB bandit algorithm) is trained to select which policies should be prioritized. The second improvement is a new parameterization of Q-value function that decomposes the contributions of the intrinsic and extrinsic rewards in a similar form as the bundled reward: $Q(s, a; \\theta_j) = Q(s, a; \\theta_j^e) + \\beta_j Q(s, a; \\theta_j^i)$. During training, $Q(s, a; \\theta_j^e)$ and $Q(s, a; \\theta_j^i)$ are optimized separately with rewards $r_j^e$ and $r_j^i$, respectively. Fig. 10. A pretty cool illustration of techniques developed in time since DQN in 2015, eventually leading to Agent57. (Image source: DeepMind Blog: \"Agent57: Outperforming the human Atari benchmark\") Instead of using the Euclidean distance to measure closeness of states in episodic memory, Savinov, et al. (2019) took the transition between states into consideration and proposed a method to measure the number of steps needed to visit one state from other states in memory, named Episodic Curiosity (EC) module. The novelty bonus depends on reachability between states.\n At the beginning of each episode, the agent starts with an empty episodic memory $M$. At every step, the agent compares the current state with saved states in memory to determine novelty bonus: If the current state is novel (i.e., takes more steps to reach from observations in memory than a threshold), the agent gets a bonus. The current state is added into the episodic memory if the novelty bonus is high enough. (Imagine that if all the states were added into memory, any new state could be added within 1 step.) Repeat 1-3 until the end of this episode. Fig. 11. The nodes in the graph are states, the edges are possible transitions. The blue nodes are states in memory. The green nodes are reachable from the memory within $k = 2$ steps (not novel). The orange nodes are further away, so they are considered as novel states. (Image source: Savinov, et al. 2019) In order to estimate reachability between states, we need to access the transition graph, which is unfortunately not entirely known. Thus, Savinov, et al. (2019) trained a siamese neural network to predict how many steps separate two states. It contains one embedding network $\\phi: \\mathcal{S} \\mapsto \\mathbb{R}^n$ to first encode the states to feature vectors and then one comparator network $C: \\mathbb{R}^n \\times \\mathbb{R}^n \\mapsto [0, 1]$ to output a binary label on whether two states are close enough (i.e., reachable within $k$ steps) in the transition graph, $C(\\phi(s_i), \\phi(s_j)) \\mapsto [0, 1]$.\nAn episodic memory buffer $M$ stores embeddings of some past observations within the same episode. A new observation will be compared with existing state embeddings via $C$ and the results are aggregated (e.g. max, 90th percentile) to provide a reachability score $C^M(\\phi(s_t))$. The exploration bonus is $r^i_t = \\big(C' - C^M(f(s_t))\\big)$, where $C'$ is a predefined threshold for determining the sign of the reward (e.g. $C'=0.5$ works well for fixed-duration episodes). High bonus is awarded to new states when they are not easily reachable from states in the memory buffer.\nThey claimed that the EC module can overcome the noisy-TV problem.\nFig. 12. The architecture of episodic curiosity (EC) module for intrinsic reward generation. (Image source: Savinov, et al. 2019) Direct Exploration Go-Explore (Ecoffet, et al., 2019) is an algorithm aiming to solve the \u0026ldquo;hard-exploration\u0026rdquo; problem. It is composed of the following two phases.\nPhase 1 (\u0026ldquo;Explore until solved\u0026rdquo;) feels quite like Dijkstra\u0026rsquo;s algorithm for finding shortest paths in a graph. Indeed, no neural network is involved in phase 1. By maintaining a memory of interesting states as well as trajectories leading to them, the agent can go back (given a simulator is deterministic) to promising states and continue doing random exploration from there. The state is mapped into a short discretized code (named \u0026ldquo;cell\u0026rdquo;) in order to be memorized. The memory is updated if a new state appears or a better/shorter trajectory is found. When selecting which past states to return to, the agent might select one in the memory uniformly or according to heuristics like recency, visit count, count of neighbors in the memory, etc. This process is repeated until the task is solved and at least one solution trajectory is found.\nThe above found high-performance trajectories would not work well on evaluation envs with any stochasticity. Thus, Phase 2 (\u0026ldquo;Robustification\u0026rdquo;) is needed to robustify the solution via imitation learning. They adopted Backward Algorithm, in which the agent is started near the last state in the trajectory and then runs RL optimization from there.\nOne important note in phase 1 is: In order to go back to a state deterministically without exploration, Go-Explore depends on a resettable and deterministic simulator, which is a big disadvantage.\nTo make the algorithm more generally useful to environments with stochasticity, an enhanced version of Go-Explore (Ecoffet, et al., 2020), named policy-based Go-Explore was proposed later.\n Instead of resetting the simulator state effortlessly, the policy-based Go-Explore learns a goal-conditioned policy and uses that to access a known state in memory repeatedly. The goal-conditioned policy is trained to follow the best trajectory that previously led to the selected states in memory. They include a Self-Imitation Learning (SIL; Oh, et al. 2018) loss to help extract as much information as possible from successful trajectories. Also, they found sampling from policy works better than random actions when the agent returns to promising states to continue exploration. Another improvement in policy-based Go-Explore is to make the downscaling function of images to cells adjustable. It is optimized so that there would be neither too many nor too few cells in the memory. Fig. 13. An overview of the Go-Explore algorithm. (Image source: Ecoffet, et al., 2020) After vanilla Go-Explore, Yijie Guo, et al. (2019) proposed DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning), which shared a similar idea as policy-based Go-Explore above. DTSIL maintains a memory of diverse demonstrations collected during training and uses them to train a trajectory-conditioned policy via SIL. They prioritize trajectories that end with a rare state during sampling.\nFig. 14. Algorithm of DTSIL (Diverse Trajectory-conditioned Self-Imitation Learning). (Image source: Yijie Guo, et al. 2019) The similar approach is also seen in Guo, et al. (2019). The main idea is to store goals with high uncertainty in memory so that later the agent can revisit these goal states with a goal-conditioned policy repeatedly. In each episode, the agent flips a coin (probability 0.5) to decide whether it will act greedily w.r.t. the policy or do directed exploration by sampling goals from the memory.\nFig. 15. Different components in directed exploration with function approximation. (Image source: Guo, et al. 2019) The uncertainty measure of a state can be something simple like count-based bonuses or something complex like density or bayesian models. The paper trained a forward dynamics model and took its prediction error as the uncertainty metric.\nQ-Value Exploration Inspired by Thompson sampling, Bootstrapped DQN (Osband, et al. 2016) introduces a notion of uncertainty in Q-value approximation in classic DQN by using the bootstrapping method. Bootstrapping is to approximate a distribution by sampling with replacement from the same population multiple times and then aggregate the results.\nMultiple Q-value heads are trained in parallel but each only consumes a bootstrapped sub-sampled set of data and each has its own corresponding target network. All the Q-value heads share the same backbone network.\nFig. 16. The algorithm of Bootstrapped DQN. (Image source: Osband, et al. 2016) At the beginning of one episode, one Q-value head is sampled uniformly and acts for collecting experience data in this episode. Then a binary mask is sampled from the masking distribution $m \\sim \\mathcal{M}$ and decides which heads can use this data for training. The choice of masking distribution $\\mathcal{M}$ determines how bootstrapped samples are generated; For example,\n If $\\mathcal{M}$ is an independent Bernoulli distribution with $p=0.5$, this corresponds to the double-or-nothing bootstrap. If $\\mathcal{M}$ always returns an all-one mask, the algorithm reduces to an ensemble method. However, this kind of exploration is still restricted, because uncertainty introduced by bootstrapping fully relies on the training data. It is better to inject some prior information independent of the data. This \u0026ldquo;noisy\u0026rdquo; prior is expected to drive the agent to keep exploring when the reward is sparse. The algorithm of adding random prior into bootstrapped DQN for better exploration (Osband, et al. 2018) depends on Bayesian linear regression. The core idea of Bayesian regression is: We can \u0026ldquo;generate posterior samples by training on noisy versions of the data, together with some random regularization\u0026rdquo;.\nLet $\\theta$ be the Q function parameter and $\\theta^-$ for the target Q, the loss function using a randomized prior function $p$ is:\n $$ \\mathcal{L}(\\theta, \\theta^{-}, p, \\mathcal{D}; \\gamma) = \\sum_{t\\in\\mathcal{D}}\\Big( r_t + \\gamma \\max_{a'\\in\\mathcal{A}} (\\underbrace{Q_{\\theta^-} + p)}_\\text{target Q}(s'_t, a') - \\underbrace{(Q_\\theta + p)}_\\text{Q to optimize}(s_t, a_t) \\Big)^2 $$ Varitional Options Options are policies with termination conditions. There are a large set of options available in the search space and they are independent of an agent\u0026rsquo;s intentions. By explicitly including intrinsic options into modeling, the agent can obtain intrinsic rewards for exploration.\nVIC (short for \u0026ldquo;Variational Intrinsic Control\u0026rdquo;; Gregor, et al. 2017) is such a framework for providing the agent with intrinsic exploration bonuses based on modeling options and learning policies conditioned on options. Let $\\Omega$ represent an option which starts from $s_0$ and ends at $s_f$. An environment probability distribution $p^J(s_f \\vert s_0, \\Omega)$ defines where an option $\\Omega$ terminates given a starting state $s_0$. A controllability distribution $p^C(\\Omega \\vert s_0)$ defines the probability distribution of options we can sample from. And by definition we have $p(s_f, \\Omega \\vert s_0) = p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0)$.\nWhile choosing options, we would like to achieve two goals:\n Achieve a diverse set of the final states from $s_0$ ⇨ Maximization of $H(s_f \\vert s_0)$. Know precisely which state a given option $\\Omega$ can end with ⇨ Minimization of $H(s_f \\vert s_0, \\Omega)$. Combining them, we get mutual information $I(\\Omega; s_f \\vert s_0)$ to maximize:\n $$ \\begin{aligned} I(\\Omega; s_f \\vert s_0) \u0026= H(s_f \\vert s_0) - H(s_f \\vert s_0, \\Omega) \\\\ \u0026= - \\sum_{s_f} p(s_f \\vert s_0) \\log p(s_f \\vert s_0) + \\sum_{s_f, \\Omega} p(s_f, \\Omega \\vert s_0) \\log \\frac{p(s_f, \\Omega \\vert s_0)}{p^C(\\Omega \\vert s_0)} \\\\ \u0026= - \\sum_{s_f} p(s_f \\vert s_0) \\log p(s_f \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log p^J(s_f \\vert s_0, \\Omega) \\\\ \\end{aligned} $$ Because mutual information is symmetric, we can switch $s_f$ and $\\Omega$ in several places without breaking the equivalence. Also because $p(\\Omega \\vert s_0, s_f)$ is difficult to observe, let us replace it with an approximation distribution $q$. According to the variational lower bound, we would have $I(\\Omega; s_f \\vert s_0) \\geq I^{VB}(\\Omega; s_f \\vert s_0)$.\n $$ \\begin{aligned} I(\\Omega; s_f \\vert s_0) \u0026= I(s_f; \\Omega \\vert s_0) \\\\ \u0026= - \\sum_{\\Omega} p(\\Omega \\vert s_0) \\log p(\\Omega \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log \\color{red}{p(\\Omega \\vert s_0, s_f)}\\\\ I^{VB}(\\Omega; s_f \\vert s_0) \u0026= - \\sum_{\\Omega} p(\\Omega \\vert s_0) \\log p(\\Omega \\vert s_0) + \\sum_{s_f, \\Omega} p^J(s_f \\vert s_0, \\Omega) p^C(\\Omega \\vert s_0) \\log \\color{red}{q(\\Omega \\vert s_0, s_f)} \\\\ I(\\Omega; s_f \\vert s_0) \u0026\\geq I^{VB}(\\Omega; s_f \\vert s_0) \\end{aligned} $$ Fig. 17. The algorithm for VIC (Variational Intrinsic Control). (Image source: Gregor, et al. 2017) Here $\\pi(a \\vert \\Omega, s)$ can be optimized with any RL algorithm. The option inference function $q(\\Omega \\vert s_0, s_f)$ is doing supervised learning. The prior $p^C$ is updated so that it tends to choose $\\Omega$ with higher rewards. Note that $p^C$ can also be fixed (e.g. a Gaussian). Various $\\Omega$ will result in different behavior through learning. Additionally, Gregor, et al. (2017) observed that it is difficult to make VIC with explicit options work in practice with function approximation and therefore they also proposed another version of VIC with implicit options.\nDifferent from VIC which models $\\Omega$ conditioned only on the start and end states, VALOR (short for \u0026ldquo;Variational Auto-encoding Learning of Options by Reinforcement\u0026rdquo;; Achiam, et al. 2018) relies on the whole trajectory to extract the option context $c$, which is sampled from a fixed Gaussian distribution. In VALOR:\n A policy acts as an encoder, translating contexts from a noise distribution into trajectories A decoder attempts to recover the contexts from the trajectories, and rewards the policies for making contexts easier to distinguish. The decoder never sees the actions during training, so the agent has to interact with the environment in a way that facilitates communication with the decoder for better prediction. Also, the decoder recurrently takes in a sequence of steps in one trajectory to better model the correlation between timesteps. Fig. 18. The decoder of VALOR is a biLSTM which takes $N = 11$ equally spaced observations from one trajectory as inputs. (Image source: Achiam, et al. 2018) DIAYN (\u0026ldquo;Diversity is all you need\u0026rdquo;; Eysenbach, et al. 2018) has the idea lying in the same direction, although with a different name \u0026mdash; DIAYN models the policies conditioned on a latent skill variable. See my previous post for more details.\nCitation Cited as:\n Weng, Lilian. (Jun 2020). Exploration strategies in deep reinforcement learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-06-07-exploration-drl/.\n Or\n@article{weng2020exploration, title = \u0026quot;Exploration Strategies in Deep Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Jun\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-06-07-exploration-drl/\u0026quot; } Reference [1] Pierre-Yves Oudeyer \u0026amp; Frederic Kaplan. \u0026ldquo;How can we define intrinsic motivation?\u0026quot; Conf. on Epigenetic Robotics, 2008.\n[2] Marc G. Bellemare, et al. \u0026ldquo;Unifying Count-Based Exploration and Intrinsic Motivation\u0026rdquo;. NIPS 2016.\n[3] Georg Ostrovski, et al. \u0026ldquo;Count-Based Exploration with Neural Density Models\u0026rdquo;. PMLR 2017.\n[4] Rui Zhao \u0026amp; Volker Tresp. \u0026ldquo;Curiosity-Driven Experience Prioritization via Density Estimation\u0026rdquo;. NIPS 2018.\n[5] Haoran Tang, et al. \u0026quot;#Exploration: A Study of Count-Based Exploration for Deep Reinforcement Learning\u0026rdquo;. NIPS 2017.\n[6] Jürgen Schmidhuber. \u0026ldquo;A possibility for implementing curiosity and boredom in model-building neural controllers\u0026rdquo; 1991.\n[7] Pierre-Yves Oudeyer, et al. \u0026ldquo;Intrinsic Motivation Systems for Autonomous Mental Development\u0026rdquo; IEEE Transactions on Evolutionary Computation, 2007.\n[8] Bradly C. Stadie, et al. \u0026ldquo;Incentivizing Exploration In Reinforcement Learning With Deep Predictive Models\u0026rdquo;. ICLR 2016.\n[9] Deepak Pathak, et al. \u0026ldquo;Curiosity-driven Exploration by Self-supervised Prediction\u0026rdquo;. CVPR 2017.\n[10] Yuri Burda, Harri Edwards \u0026amp; Deepak Pathak, et al. \u0026ldquo;Large-Scale Study of Curiosity-Driven Learning\u0026rdquo;. arXiv 1808.04355 (2018).\n[11] Joshua Achiam \u0026amp; Shankar Sastry. \u0026ldquo;Surprise-Based Intrinsic Motivation for Deep Reinforcement Learning\u0026rdquo; NIPS 2016 Deep RL Workshop.\n[12] Rein Houthooft, et al. \u0026ldquo;VIME: Variational information maximizing exploration\u0026rdquo;. NIPS 2016.\n[13] Leshem Choshen, Lior Fox \u0026amp; Yonatan Loewenstein. \u0026ldquo;DORA the explorer: Directed outreaching reinforcement action-selection\u0026rdquo;. ICLR 2018\n[14] Yuri Burda, et al. \u0026ldquo;Exploration by Random Network Distillation\u0026rdquo; ICLR 2019.\n[15] OpenAI Blog: \u0026ldquo;Reinforcement Learning with Prediction-Based Rewards\u0026rdquo; Oct, 2018.\n[16] Misha Denil, et al. \u0026ldquo;Learning to Perform Physics Experiments via Deep Reinforcement Learning\u0026rdquo;. ICLR 2017.\n[17] Ian Osband, et al. \u0026ldquo;Deep Exploration via Bootstrapped DQN\u0026rdquo;. NIPS 2016.\n[18] Ian Osband, John Aslanides \u0026amp; Albin Cassirer. \u0026ldquo;Randomized Prior Functions for Deep Reinforcement Learning\u0026rdquo;. NIPS 2018.\n[19] Karol Gregor, Danilo Jimenez Rezende \u0026amp; Daan Wierstra. \u0026ldquo;Variational Intrinsic Control\u0026rdquo;. ICLR 2017.\n[20] Joshua Achiam, et al. \u0026ldquo;Variational Option Discovery Algorithms\u0026rdquo;. arXiv 1807.10299 (2018).\n[21] Benjamin Eysenbach, et al. \u0026ldquo;Diversity is all you need: Learning skills without a reward function.\u0026quot;. ICLR 2019.\n[22] Adrià Puigdomènech Badia, et al. \u0026ldquo;Never Give Up (NGU): Learning Directed Exploration Strategies\u0026rdquo; ICLR 2020.\n[23] Adrià Puigdomènech Badia, et al. \u0026ldquo;Agent57: Outperforming the Atari Human Benchmark\u0026rdquo;. arXiv 2003.13350 (2020).\n[24] DeepMind Blog: \u0026ldquo;Agent57: Outperforming the human Atari benchmark\u0026rdquo; Mar 2020.\n[25] Nikolay Savinov, et al. \u0026ldquo;Episodic Curiosity through Reachability\u0026rdquo; ICLR 2019.\n[26] Adrien Ecoffet, et al. \u0026ldquo;Go-Explore: a New Approach for Hard-Exploration Problems\u0026rdquo;. arXiv 1901.10995 (2019).\n[27] Adrien Ecoffet, et al. \u0026ldquo;First return then explore\u0026rdquo;. arXiv 2004.12919 (2020).\n[28] Junhyuk Oh, et al. \u0026ldquo;Self-Imitation Learning\u0026rdquo;. ICML 2018.\n[29] Yijie Guo, et al. \u0026ldquo;Self-Imitation Learning via Trajectory-Conditioned Policy for Hard-Exploration Tasks\u0026rdquo;. arXiv 1907.10247 (2019).\n[30] Zhaohan Daniel Guo \u0026amp; Emma Brunskill. \u0026ldquo;Directed Exploration for Reinforcement Learning\u0026rdquo;. arXiv 1906.07805 (2019).\n[31] Deepak Pathak, et al. “Self-Supervised Exploration via Disagreement.” ICML 2019.\n","permalink":"https://lilianweng.github.io/posts/2020-06-07-exploration-drl/","summary":"[Updated on 2020-06-17: Add \u0026ldquo;exploration via disagreement\u0026rdquo; in the \u0026ldquo;Forward Dynamics\u0026rdquo; section.\nExploitation versus exploration is a critical topic in Reinforcement Learning. We\u0026rsquo;d like the RL agent to find the best solution as fast as possible. However, in the meantime, committing to solutions too quickly without enough exploration sounds pretty bad, as it could lead to local minima or total failure. Modern RL algorithms that optimize for the best returns can achieve good exploitation quite efficiently, while exploration remains more like an open topic.","title":"Exploration Strategies in Deep Reinforcement Learning"},{"content":"[Updated on 2023-01-27: After almost three years, I did a big refactoring update of this post to incorporate a bunch of new Transformer models since 2020. The enhanced version of this post is here: The Transformer Family Version 2.0. Please refer to that post on this topic.] \nIt has been almost two years since my last post on attention. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.\nNotations Symbol Meaning $d$ The model size / hidden state dimension / positional encoding size. $h$ The number of heads in multi-head attention layer. $L$ The segment length of input sequence. $\\mathbf{X} \\in \\mathbb{R}^{L \\times d}$ The input sequence where each element has been mapped into an embedding vector of shape $d$, same as the model size. $\\mathbf{W}^k \\in \\mathbb{R}^{d \\times d_k}$ The key weight matrix. $\\mathbf{W}^q \\in \\mathbb{R}^{d \\times d_k}$ The query weight matrix. $\\mathbf{W}^v \\in \\mathbb{R}^{d \\times d_v}$ The value weight matrix. Often we have $d_k = d_v = d$. $\\mathbf{W}^k_i, \\mathbf{W}^q_i \\in \\mathbb{R}^{d \\times d_k/h}; \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ The weight matrices per head. $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ The output weight matrix. $\\mathbf{Q} = \\mathbf{X}\\mathbf{W}^q \\in \\mathbb{R}^{L \\times d_k}$ The query embedding inputs. $\\mathbf{K} = \\mathbf{X}\\mathbf{W}^k \\in \\mathbb{R}^{L \\times d_k}$ The key embedding inputs. $\\mathbf{V} = \\mathbf{X}\\mathbf{W}^v \\in \\mathbb{R}^{L \\times d_v}$ The value embedding inputs. $S_i$ A collection of key positions for the $i$-th query $\\mathbf{q}_i$ to attend to. $\\mathbf{A} \\in \\mathbb{R}^{L \\times L}$ The self-attention matrix between a input sequence of lenght $L$ and itself. $\\mathbf{A} = \\text{softmax}(\\mathbf{Q}\\mathbf{K}^\\top / \\sqrt{d_k})$. $a_{ij} \\in \\mathbf{A}$ The scalar attention score between query $\\mathbf{q}_i$ and key $\\mathbf{k}_j$. $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ position encoding matrix, where the $i$-th row $\\mathbf{p}_i$ is the positional encoding for input $\\mathbf{x}_i$. Attention and Self-Attention Attention is a mechanism in the neural network that a model can learn to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights and thus the output is usually formed as a weighted average.\nSelf-attention is a type of attention mechanism where the model makes prediction for one part of a data sample using other parts of the observation about the same sample. Conceptually, it feels quite similar to non-local means. Also note that self-attention is permutation-invariant; in other words, it is an operation on sets.\nThere are various forms of attention / self-attention, Transformer (Vaswani et al., 2017) relies on the scaled dot-product attention: given a query matrix $\\mathbf{Q}$, a key matrix $\\mathbf{K}$ and a value matrix $\\mathbf{V}$, the output is a weighted sum of the value vectors, where the weight assigned to each value slot is determined by the dot-product of the query with the corresponding key:\n $$ \\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q} {\\mathbf{K}}^\\top}{\\sqrt{d_k}})\\mathbf{V} $$ And for a query and a key vector $\\mathbf{q}_i, \\mathbf{k}_j \\in \\mathbb{R}^d$ (row vectors in query and key matrices), we have a scalar score:\n $$ a_{ij} = \\text{softmax}(\\frac{\\mathbf{q}_i {\\mathbf{k}_j}^\\top}{\\sqrt{d_k}}) = \\frac{\\exp(\\mathbf{q}_i {\\mathbf{k}_j}^\\top)}{ \\sqrt{d_k} \\sum_{r \\in S_i} \\exp(\\mathbf{q}_i {\\mathbf{k}_r}^\\top) } $$ where $S_i$ is a collection of key positions for the $i$-th query to attend to.\nSee my old post for other types of attention if interested.\nMulti-Head Self-Attention The multi-head self-attention module is a key component in Transformer. Rather than only computing the attention once, the multi-head mechanism splits the inputs into smaller chunks and then computes the scaled dot-product attention over each subspace in parallel. The independent attention outputs are simply concatenated and linearly transformed into expected dimensions.\n $$ \\begin{aligned} \\text{MultiHeadAttention}(\\mathbf{X}_q, \\mathbf{X}_k, \\mathbf{X}_v) \u0026= [\\text{head}_1; \\dots; \\text{head}_h] \\mathbf{W}^o \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{X}_q\\mathbf{W}^q_i, \\mathbf{X}_k\\mathbf{W}^k_i, \\mathbf{X}_v\\mathbf{W}^v_i) \\end{aligned} $$ where $[.;.]$ is a concatenation operation. $\\mathbf{W}^q_i, \\mathbf{W}^k_i \\in \\mathbb{R}^{d \\times d_k/h}, \\mathbf{W}^v_i \\in \\mathbb{R}^{d \\times d_v/h}$ are weight matrices to map input embeddings of size $L \\times d$ into query, key and value matrices. And $\\mathbf{W}^o \\in \\mathbb{R}^{d_v \\times d}$ is the output linear transformation. All the weights should be learned during training.\nFig. 1. Illustration of the multi-head scaled dot-product attention mechanism. (Image source: Figure 2 in Vaswani, et al., 2017) Transformer The Transformer (which will be referred to as \u0026ldquo;vanilla Transformer\u0026rdquo; to distinguish it from other enhanced versions; Vaswani, et al., 2017) model has an encoder-decoder architecture, as commonly used in many NMT models. Later simplified Transformer was shown to achieve great performance in language modeling tasks, like in encoder-only BERT or decoder-only GPT.\nEncoder-Decoder Architecture\nThe encoder generates an attention-based representation with capability to locate a specific piece of information from a large context. It consists of a stack of 6 identity modules, each containing two submodules, a multi-head self-attention layer and a point-wise fully connected feed-forward network. By point-wise, it means that it applies the same linear transformation (with same weights) to each element in the sequence. This can also be viewed as a convolutional layer with filter size 1. Each submodule has a residual connection and layer normalization. All the submodules output data of the same dimension $d$.\nThe function of Transformer decoder is to retrieve information from the encoded representation. The architecture is quite similar to the encoder, except that the decoder contains two multi-head attention submodules instead of one in each identical repeating module. The first multi-head attention submodule is masked to prevent positions from attending to the future.\nFig. 2. The architecture of the vanilla Transformer model. (Image source: Figure 17) Positional Encoding\nBecause self-attention operation is permutation invariant, it is important to use proper positional encodingto provide order information to the model. The positional encoding $\\mathbf{P} \\in \\mathbb{R}^{L \\times d}$ has the same dimension as the input embedding, so it can be added on the input directly. The vanilla Transformer considered two types of encodings:\n(1) Sinusoidal positional encoding is defined as follows, given the token position $i=1,\\dots,L$ and the dimension $\\delta=1,\\dots,d$:\n $$ \\text{PE}(i,\\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ In this way each dimension of the positional encoding corresponds to a sinusoid of different wavelengths in different dimensions, from $2\\pi$ to $10000 \\cdot 2\\pi$.\nFig. 3. Sinusoidal positional encoding with $L=32$ and $d=128$. The value is between -1 (black) and 1 (white) and the value 0 is in gray. (2) Learned positional encoding, as its name suggested, assigns each element with a learned column vector which encodes its absolute position (Gehring, et al. 2017).\nQuick Follow-ups\nFollowing the vanilla Transformer, Al-Rfou et al. (2018) added a set of auxiliary losses to enable training a deep Transformer model on character-level language modeling which outperformed LSTMs. Several types of auxiliary tasks are used:\n Instead of producing only one prediction at the sequence end, every immediate position is also asked to make a correct prediction, forcing the model to predict given smaller contexts (e.g. first couple tokens at the beginning of a context window). Each intermediate Transformer layer is used for making predictions as well. Lower layers are weighted to contribute less and less to the total loss as training progresses. Each position in the sequence can predict multiple targets, i.e. two or more predictions of the future tokens. Fig. 4. Auxiliary prediction tasks used in deep Transformer for character-level language modeling. (Image source: Al-Rfou et al. (2018)) Adaptive Computation Time (ACT) Adaptive Computation Time (short for ACT; Graves, 2016) is a mechanism for dynamically deciding how many computational steps are needed in a recurrent neural network. Here is a cool tutorial on ACT from distill.pub.\nLet\u0026rsquo;s say, we have a RNN model $\\mathcal{R}$ composed of input weights $W_x$, a parametric state transition function $\\mathcal{S}(.)$, a set of output weights $W_y$ and an output bias $b_y$. Given an input sequence $(x_1, \\dots, x_L)$, the output sequence $(y_1, \\dots, y_L)$ is computed by:\n $$ s_t = \\mathcal{S}(s_{t-1}, W_x x_t), \\quad y_t = W_y s_t + b_y\\quad\\text{for }t=1, \\dots, L $$ ACT enables the above RNN setup to perform a variable number of steps at each input element. Multiple computational steps lead to a sequence of intermediate states $(s_t^1, \\dots, s_t^{N(t)})$ and outputs $(y_t^1, \\dots, y_t^{N(t)})$ \u0026mdash; they all share the same state transition function $\\mathcal{S}(.)$, as well as the same output weights $W_y$ and bias $b_y$:\n $$ \\begin{aligned} s_t^0 \u0026= s_{t-1} \\\\ s_t^n \u0026= \\mathcal{S}(s_{t}^{n-1}, x_t^n) = \\mathcal{S}(s_{t}^{n-1}, x_t + \\delta_{n,1}) \\text{ for } n=1, \\dots, N(t)\\\\ y_t^n \u0026= W_y s_t^n + b_y \\end{aligned} $$ where $\\delta_{n,1}$ is a binary flag indicating whether the input step has been incremented.\nThe number of steps $N(t)$ is determined by an extra sigmoidal halting unit $h$, with associated weight matrix $W_h$ and bias $b_h$, outputting a halting probability $p_t^n$ at immediate step $n$ for $t$-th input element:\n $$ h_t^n = \\sigma(W_h s_t^n + b_h) $$ In order to allow the computation to halt after a single step, ACT introduces a small constant $\\epsilon$ (e.g. 0.01), so that whenever the cumulative probability goes above $1-\\epsilon$, the computation stops.\n $$ \\begin{aligned} N(t) \u0026= \\min(\\min\\{n': \\sum_{n=1}^{n'} h_t^n \\geq 1 -\\epsilon\\}, M) \\\\ p_t^n \u0026= \\begin{cases} h_t^n \u0026 \\text{if }n where $M$ is an upper limit for the number of immediate steps allowed.\nThe final state and output are mean-field updates:\n $$ s_t = \\sum_{n=1}^{N(t)} p_t^n s_t^n,\\quad y_t = \\sum_{n=1}^{N(t)} p_t^n y_t^n $$ Fig. 5. The computation graph of a RNN with ACT mechanism. (Image source: Graves, 2016) To avoid unnecessary pondering over each input, ACT adds a ponder cost $\\mathcal{P}(x) = \\sum_{t=1}^L N(t) + R(t) $ in the loss function to encourage a smaller number of intermediate computational steps.\nImproved Attention Span The goal of improving attention span is to make the context that can be used in self-attention longer, more efficient and flexible.\nLonger Attention Span (Transformer-XL) The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments.\nThis context segmentation causes several issues:\n The model cannot capture very long term dependencies. It is hard to predict the first few tokens in each segment given no or thin context. The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens. Transformer-XL (Dai et al., 2019; \u0026ldquo;XL\u0026rdquo; means \u0026ldquo;extra long\u0026rdquo;) solves the context segmentation problem with two main modifications:\n Reusing hidden states between segments. Adopting a new positional encoding that is suitable for reused states. Hidden State Reuse\nThe recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.\nFig. 6. A comparison between the training phrase of vanilla Transformer \u0026 Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019). Let\u0026rsquo;s label the hidden state of the $n$-th layer for the $(\\tau + 1)$-th segment in the model as $\\mathbf{h}_{\\tau+1}^{(n)} \\in \\mathbb{R}^{L \\times d}$. In addition to the hidden state of the last layer for the same segment $\\mathbf{h}_{\\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment $\\mathbf{h}_{\\tau}^{(n)}$. By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments.\n $$ \\begin{aligned} \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \u0026= [\\text{stop-gradient}(\\mathbf{h}_{\\tau}^{(n-1)}) \\circ \\mathbf{h}_{\\tau+1}^{(n-1)}] \\\\ \\mathbf{Q}_{\\tau+1}^{(n)} \u0026= \\mathbf{h}_{\\tau+1}^{(n-1)}\\mathbf{W}^q \\\\ \\mathbf{K}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^k \\\\ \\mathbf{V}_{\\tau+1}^{(n)} \u0026= \\color{red}{\\widetilde{\\mathbf{h}}_{\\tau+1}^{(n-1)}} \\mathbf{W}^v \\\\ \\mathbf{h}_{\\tau+1}^{(n)} \u0026= \\text{transformer-layer}(\\mathbf{Q}_{\\tau+1}^{(n)}, \\mathbf{K}_{\\tau+1}^{(n)}, \\mathbf{V}_{\\tau+1}^{(n)}) \\end{aligned} $$ Note that both key and value rely on the extended hidden state, while the query only consumes hidden state at current step. The concatenation operation $[. \\circ .]$ is along the sequence length dimension.\nRelative Positional Encoding\nIn order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding. If using the same approach by vanilla Transformer and encoding the absolute position, the previous and current segments will be assigned with the same encoding, which is undesired.\nTo keep the positional information flow coherently across segments, Transformer-XL encodes the relative position instead, as it could be sufficient enough to know the position offset for making good predictions, i.e. $i-j$, between one key vector $\\mathbf{k}_{\\tau, j}$ and its query $\\mathbf{q}_{\\tau, i}$.\nIf omitting the scalar $1/\\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:\n $$ \\begin{aligned} a_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top = (\\mathbf{x}_i + \\mathbf{p}_i)\\mathbf{W}^q ((\\mathbf{x}_j + \\mathbf{p}_j)\\mathbf{W}^k)^\\top \\\\ \u0026= \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{x}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{x}_j^\\top + \\mathbf{p}_i\\mathbf{W}^q {\\mathbf{W}^k}^\\top\\mathbf{p}_j^\\top \\end{aligned} $$ Transformer-XL reparameterizes the above four terms as follows:\n $$ a_{ij}^\\text{rel} = \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{content-based addressing} + \\underbrace{ \\mathbf{x}_i\\mathbf{W}^q \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{content-dependent positional bias} + \\underbrace{ \\color{red}{\\mathbf{u}} \\color{blue}{ {\\mathbf{W}_E^k}^\\top } \\mathbf{x}_j^\\top }_\\text{global content bias} + \\underbrace{ \\color{red}{\\mathbf{v}} \\color{blue}{ {\\mathbf{W}_R^k}^\\top } \\color{green}{\\mathbf{r}_{i-j}^\\top} }_\\text{global positional bias} $$ Replace $\\mathbf{p}_j$ with relative positional encoding $\\mathbf{r}_{i-j} \\in \\mathbf{R}^{d}$; Replace $\\mathbf{p}_i\\mathbf{W}^q$ with two trainable parameters $\\mathbf{u}$ (for content) and $\\mathbf{v}$ (for location) in two different terms; Split $\\mathbf{W}^k$ into two matrices, $\\mathbf{W}^k_E$ for content information and $\\mathbf{W}^k_R$ for location information. Adaptive Attention Span One key advantage of Transformer is the capability of capturing long-term dependencies. Depending on the context, the model may prefer to attend further sometime than others; or one attention head may had different attention pattern from the other. If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.\nThis is the motivation for Adaptive Attention Span. Sukhbaatar, et al., (2019) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window (See Fig. 7) and thus the optimal span would be trained separately per head.\nFig. 7. Two attention heads in the same model, A \u0026 B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B look further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019) Given the $i$-th token, we need to compute the attention weights between this token and other keys at positions $j \\in S_i$, where $S_i$ defineds the $i$-th token\u0026rsquo;s context window.\n $$ \\begin{aligned} e_{ij} \u0026= \\mathbf{q}_i {\\mathbf{k}_j}^\\top \\\\ a_{ij} \u0026= \\text{softmax}(e_{ij}) = \\frac{\\exp(e_{ij})}{\\sum_{r=i-s}^{i-1} \\exp(e_{ir})} \\\\ \\mathbf{y}_i \u0026= \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{v}_r = \\sum_{r=i-s}^{i-1}a_{ir}\\mathbf{x}_r\\mathbf{W}^v \\end{aligned} $$ A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a [0, 1] value. $m_z$ is parameterized by $z \\in [0, s]$ and $z$ is to be learned:\n $$ m_z(x) = \\text{clamp}(\\frac{1}{R}(R+z-x), 0, 1) $$ where $R$ is a hyper-parameter which defines the softness of $m_z$.\nFig. 8. The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.) The soft mask function is applied to the softmax elements in the attention weights:\n $$ a_{ij} = \\frac{m_z(i-j)\\exp(s_{ij})}{\\sum_{r=i-s}^{i-1}m_z(i-r) \\exp(s_{ir})} $$ In the above equation, $z$ is differentiable so it is trained jointly with other parts of the model. Parameters $z^{(i)}, i=1, \\dots, h$ are learned separately per head. Moreover, the loss function has an extra L1 penalty on $\\sum_{i=1}^h z^{(i)}$.\nUsing Adaptive Computation Time, the approach can be further enhanced to have flexible attention span length, adaptive to the current input dynamically. The span parameter $z_t$ of an attention head at time $t$ is a sigmoidal function, $z_t = S \\sigma(\\mathbf{v} \\cdot \\mathbf{x}_t +b)$, where the vector $\\mathbf{v}$ and the bias scalar $b$ are learned jointly with other parameters.\nIn the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans. Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.\nLocalized Attention Span (Image Transformer) The original, also the most popular, use case for Transformer is to do language modeling. The text sequence is one-dimensional in a clearly defined chronological order and thus the attention span grows linearly with increased context size.\nHowever, if we want to use Transformer on images, it is unclear how to define the scope of context or the order. Image Transformer (Parmer, et al 2018) embraces a formulation of image generation similar to sequence modeling within the Transformer framework. Additionally, Image Transformer restricts the self-attention span to only local neighborhoods, so that the model can scale up to process more images in parallel and keep the likelihood loss tractable.\nThe encoder-decoder architecture remains for image-conditioned generation:\n The encoder generates a contextualized, per-pixel-channel representation of the source image; The decoder autoregressively generates an output image, one channel per pixel at each time step. Let\u0026rsquo;s label the representation of the current pixel to be generated as the query $\\mathbf{q}$. Other positions whose representations will be used for computing $\\mathbf{q}$ are key vector $\\mathbf{k}_1, \\mathbf{k}_2, \\dots$ and they together form a memory matrix $\\mathbf{M}$. The scope of $\\mathbf{M}$ defines the context window for pixel query $\\mathbf{q}$.\nImage Transformer introduced two types of localized $\\mathbf{M}$, as illustrated below.\nFig. 9. Illustration of 1D and 2D attention span for visual inputs in Image Transformer. The black line marks a query block and the cyan outlines the actual attention span for pixel q. (Image source: Figure 2 in Parmer et al, 2018) (1) 1D Local Attention: The input image is flattened in the raster scanning order, that is, from left to right and top to bottom. The linearized image is then partitioned into non-overlapping query blocks. The context window consists of pixels in the same query block as $\\mathbf{q}$ and a fixed number of additional pixels generated before this query block.\n(2) 2D Local Attention: The image is partitioned into multiple non-overlapping rectangular query blocks. The query pixel can attend to all others in the same memory blocks. To make sure the pixel at the top-left corner can also have a valid context window, the memory block is extended to the top, left and right by a fixed amount, respectively.\nLess Time and Memory Cost This section introduces several improvements made on Transformer to reduce the computation time and memory consumption.\nSparse Attention Matrix Factorization (Sparse Transformers) The compute and memory cost of the vanilla Transformer grows quadratically with sequence length and thus it is hard to be applied on very long sequences.\nSparse Transformer (Child et al., 2019) introduced factorized self-attention, through sparse matrix factorization, making it possible to train dense attention networks with hundreds of layers on sequence length up to 16,384, which would be infeasible on modern hardware otherwise.\nGiven a set of attention connectivity pattern $\\mathcal{S} = \\{S_1, \\dots, S_n\\}$, where each $S_i$ records a set of key positions that the $i$-th query vector attends to.\n $$ \\begin{aligned} \\text{Attend}(\\mathbf{X}, \\mathcal{S}) \u0026= \\Big( a(\\mathbf{x}_i, S_i) \\Big)_{i \\in \\{1, \\dots, L\\}} \\\\ \\text{ where } a(\\mathbf{x}_i, S_i) \u0026= \\text{softmax}\\Big(\\frac{(\\mathbf{x}_i \\mathbf{W}^q)(\\mathbf{x}_j \\mathbf{W}^k)_{j \\in S_i}^\\top}{\\sqrt{d_k}}\\Big) (\\mathbf{x}_j \\mathbf{W}^v)_{j \\in S_i} \\end{aligned} $$ Note that although the size of $S_i$ is not fixed, $a(\\mathbf{x}_i, S_i)$ is always of size $d_v$ and thus $\\text{Attend}(\\mathbf{X}, \\mathcal{S}) \\in \\mathbb{R}^{L \\times d_v}$.\nIn anto-regressive models, one attention span is defined as $S_i = \\{j: j \\leq i\\}$ as it allows each token to attend to all the positions in the past.\nIn factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \\leq i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.\nPrecisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A^{(m)}_i \\subset S_i, m = 1,\\dots, p$. Therefore the path between the output position $i$ and any $j$ has a maximum length $p + 1$. For example, if $(j, a, b, c, \\dots, i)$ is a path of indices between $i$ and $j$, we would have $j \\in A_a^{(1)}, a \\in A_b^{(2)}, b \\in A_c^{(3)}, \\dots$, so on and so forth.\nSparse Factorized Attention\nSparse Transformer proposed two types of fractorized attention. It is easier to understand the concepts as illustrated in Fig. 10 with 2D image inputs as examples.\nFig. 10. The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few of extra annotations.) (1) Strided attention with stride $\\ell \\sim \\sqrt{n}$. This works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\\ell$ pixels in the raster scanning order (naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{ t, t+1, \\dots, i\\} \\text{, where } t = \\max(0, i - \\ell) \\\\ A_i^{(2)} \u0026= \\{j: (i-j) \\mod \\ell = 0\\} \\end{aligned} $$ (2) Fixed attention. A small set of tokens summarize previous locations and propagate that information to all future locations.\n $$ \\begin{aligned} A_i^{(1)} \u0026= \\{j: \\lfloor \\frac{j}{\\ell} \\rfloor = \\lfloor \\frac{i}{\\ell} \\rfloor \\} \\\\ A_i^{(2)} \u0026= \\{j: j \\mod \\ell \\in \\{\\ell-c, \\dots, \\ell-1\\} \\} \\end{aligned} $$ where $c$ is a hyperparameter. If $c=1$, it restricts the representation whereas many depend on a few positions. The paper chose $c\\in \\{ 8, 16, 32 \\}$ for $\\ell \\in \\{ 128, 256 \\}$.\nUse Factorized Self-Attention in Transformer\nThere are three ways to use sparse factorized attention patterns in Transformer architecture:\n One attention type per residual block and then interleave them, $\\text{attention}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, A^{(n \\mod p)}) \\mathbf{W}^o$, where $n$ is the index of the current residual block. Set up a single head which attends to locations that all the factorized heads attend to, $\\text{attention}(\\mathbf{X}) = \\text{Attend}(\\mathbf{X}, \\cup_{m=1}^p A^{(m)}) \\mathbf{W}^o $. Use a multi-head attention mechanism, but different from vanilla Transformer, each head might adopt a pattern presented above, 1 or 2. =\u0026gt; This option often performs the best. Sparse Transformer also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention \u0026amp; FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the paper for more details.\nLocality-Sensitive Hashing (Reformer) The improvements proposed by the Reformer model (Kitaev, et al. 2020) aim to solve the following pain points in Transformer:\n Memory in a model with $N$ layers is $N$-times larger than in a single-layer model because we need to store activations for back-propagation. The intermediate FF layers are often quite large. The attention matrix on sequences of length $L$ often requires $O(L^2)$ in both memory and time. Reformer proposed two main changes:\n Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $O(L^2)$ to $O(L\\log L)$. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers). Locality-Sensitive Hashing Attention\nIn $\\mathbf{Q} \\mathbf{K}^\\top$ part of the attention formula, we are only interested in the largest elements as only large elements contribute a lot after softmax. For each query $\\mathbf{q}_i \\in \\mathbf{Q}$, we are looking for row vectors in $\\mathbf{K}$ closest to $\\mathbf{q}_i$. In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.\nA hashing scheme $x \\mapsto h(x)$ is locality-sensitive if it preserves the distancing information between data points, such that close vectors obtain similar hashes while distant vectors have very different ones. The Reformer adopts a hashing scheme as such, given a fixed random matrix $\\mathbf{R} \\in \\mathbb{R}^{d \\times b/2}$ (where $b$ is a hyperparam), the hash function is $h(x) = \\arg\\max([xR; −xR])$.\n$$ \\mathbf{o}_i = \\sum_{j \\in S_i} \\exp(\\mathbf{q}_i \\cdot \\mathbf{k}_j - Z(i, S_i)) \\mathbf{v}_j \\text{, where } S_i = \\{j: j \\leq i\\} $$ -- Fig. 11. Illustration of Locality-Sensitive Hashing (LSH) attention. (Image source: right part of Figure 1 in Kitaev, et al. 2020). In LSH attention, a query can only attend to positions in the same hashing bucket, $S_i = \\{j: h(\\mathbf{q}_i) = h(\\mathbf{k}_j)\\}$. It is carried out in the following process, as illustrated in Fig. 11:\n (a) The attention matrix for full attention is often sparse. (b) Using LSH, we can sort the keys and queries to be aligned according to their hash buckets. (c) Set $\\mathbf{Q} = \\mathbf{K}$ (precisely $\\mathbf{k}_j = \\mathbf{q}_j / |\\mathbf{q}_j|$), so that there are equal numbers of keys and queries in one bucket, easier for batching. Interestingly, this \u0026ldquo;shared-QK\u0026rdquo; config does not affect the performance of the Transformer. (d) Apply batching where chunks of $m$ consecutive queries are grouped together. Fig. 12. The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020). Reversible Residual Network\nAnother improvement by Reformer is to use reversible residual layers (Gomez et al. 2017). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.\nGiven a layer $x \\mapsto y$, the normal residual layer does $y = x + F(x)$, but the reversible layer splits both input and output into pairs $(x_1, x_2) \\mapsto (y_1, y_2)$ and then executes the following:\n $$ y_1 = x_1 + F(x_2),\\; y_2 = x_2 + G(y_1) $$ and reversing is easy:\n $$ x_2 = y_2 - G(y_1), \\; x_1 = y_1 − F(x_2) $$ Reformer applies the same idea to Transformer by combination attention ($F$) and feed-forward layers ($G$) within a reversible net block:\n $$ Y_1 = X_1 + \\text{Attention}(X_2), \\; Y_2 = X_2 + \\text{FeedForward}(Y_1) $$ The memory can be further reduced by chunking the feed-forward computation:\n $$ Y_2 = [Y_2^{(1)}; \\dots; Y_2^{(c)}] = [X_2^{(1)} + \\text{FeedForward}(Y_1^{(1)}); \\dots; X_2^{(c)} + \\text{FeedForward}(Y_1^{(c)})] $$ The resulting reversible Transformer does not need to store activation in every layer.\nMake it Recurrent (Universal Transformer) The Universal Transformer (Dehghani, et al. 2019) combines self-attention in Transformer with the recurrent mechanism in RNN, aiming to benefit from both a long-term global receptive field of Transformer and learned inductive biases of RNN.\nRather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, an Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.\nOn a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.\nFig. 13. How the Universal Transformer refines a set of hidden state representations repeatedly for every position in parallel. (Image source: Figure 1 in Dehghani, et al. 2019). Given an input sequence of length $L$, Universal Transformer iteratively updates the representation $\\mathbf{H}^t \\in \\mathbb{R}^{L \\times d}$ at step $t$ for an adjustable number of steps. At step 0, $\\mathbf{H}^0$ is initialized to be same as the input embedding matrix. All the positions are processed in parallel in the multi-head self-attention mechanism and then go through a recurrent transition function.\n $$ \\begin{aligned} \\mathbf{A}^t \u0026= \\text{LayerNorm}(\\mathbf{H}^{t-1} + \\text{MultiHeadAttention}(\\mathbf{H}^{t-1} + \\mathbf{P}^t) \\\\ \\mathbf{H}^t \u0026= \\text{LayerNorm}(\\mathbf{A}^{t-1} + \\text{Transition}(\\mathbf{A}^t)) \\end{aligned} $$ where $\\text{Transition}(.)$ is either a separable convolution or a fully-connected neural network that consists of two position-wise (i.e. applied to each row of $\\mathbf{A}^t$ individually) affine transformation + one ReLU.\nThe positional encoding $\\mathbf{P}^t$ uses sinusoidal position signal but with an additional time dimension:\n $$ \\text{PE}(i, t, \\delta) = \\begin{cases} \\sin(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\sin(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta'\\\\ \\cos(\\frac{i}{10000^{2\\delta'/d}}) \\oplus \\cos(\\frac{t}{10000^{2\\delta'/d}}) \u0026 \\text{if } \\delta = 2\\delta' + 1\\\\ \\end{cases} $$ Fig. 14. A simplified illustration of Universal Transformer. The encoder and decoder share the same basic recurrent structure. But the decoder also attends to final encoder representation $\\mathbf{H}^T$. (Image source: Figure 2 in Dehghani, et al. 2019) In the adaptive version of Universal Transformer, the number of recurrent steps $T$ is dynamically determined by ACT. Each position is equipped with a dynamic ACT halting mechanism. Once a per-token recurrent block halts, it stops taking more recurrent updates but simply copies the current value to the next step until all the blocks halt or until the model reaches a maximum step limit.\nStabilization for RL (GTrXL) The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can for sure benefit from these traits. However, it is quite difficult to train Transformer even in supervised learning, let alone in the RL context. It could be quite challenging to stabilize and train a LSTM agent by itself, after all.\nThe Gated Transformer-XL (GTrXL; Parisotto, et al. 2019) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:\n The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism. $$ \\begin{aligned} r \u0026= \\sigma(W_r^{(l)} y + U_r^{(l)} x) \\\\ z \u0026= \\sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\\\ \\hat{h} \u0026= \\tanh(W_g^{(l)} y + U_g^{(l)} (r \\odot x)) \\\\ g^{(l)}(x, y) \u0026= (1-z)\\odot x + z\\odot \\hat{h} \\end{aligned} $$ The gating function parameters are explicitly initialized to be close to an identity map - this is why there is a $b_g$ term. A $b_g \u0026gt; 0$ greatly helps with the learning speedup.\nFig. 15. Comparison of the model architecture of Transformer-XL, Transformer-XL with the layer norm reordered, and Gated Transformer-XL. (Image source: Figure 1 in Parisotto, et al. 2019) Citation Cited as:\n Weng, Lilian. (Apr 2020). The transformer family. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/.\n Or\n@article{weng2020transformer, title = \u0026quot;The Transformer Family\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Apr\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/\u0026quot; } Reference [1] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[2] Rami Al-Rfou, et al. \u0026ldquo;Character-level language modeling with deeper self-attention.\u0026quot; AAAI 2019.\n[3] Olah \u0026amp; Carter, \u0026ldquo;Attention and Augmented Recurrent Neural Networks\u0026rdquo;, Distill, 2016.\n[4] Sainbayar Sukhbaatar, et al. \u0026ldquo;Adaptive Attention Span in Transformers\u0026rdquo;. ACL 2019.\n[5] Rewon Child, et al. \u0026ldquo;Generating Long Sequences with Sparse Transformers\u0026rdquo; arXiv:1904.10509 (2019).\n[6] Nikita Kitaev, et al. \u0026ldquo;Reformer: The Efficient Transformer\u0026rdquo; ICLR 2020.\n[7] Alex Graves. (\u0026ldquo;Adaptive Computation Time for Recurrent Neural Networks\u0026rdquo;)[https://arxiv.org/abs/1603.08983]\n[8] Niki Parmar, et al. \u0026ldquo;Image Transformer\u0026rdquo; ICML 2018.\n[9] Zihang Dai, et al. \u0026ldquo;Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.\u0026quot; ACL 2019.\n[10] Aidan N. Gomez, et al. \u0026ldquo;The Reversible Residual Network: Backpropagation Without Storing Activations\u0026rdquo; NIPS 2017.\n[11] Mostafa Dehghani, et al. \u0026ldquo;Universal Transformers\u0026rdquo; ICLR 2019.\n[12] Emilio Parisotto, et al. \u0026ldquo;Stabilizing Transformers for Reinforcement Learning\u0026rdquo; arXiv:1910.06764 (2019).\n","permalink":"https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/","summary":"[Updated on 2023-01-27: After almost three years, I did a big refactoring update of this post to incorporate a bunch of new Transformer models since 2020. The enhanced version of this post is here: The Transformer Family Version 2.0. Please refer to that post on this topic.] \nIt has been almost two years since my last post on attention. Recent progress on new and enhanced versions of Transformer motivates me to write another post on this specific topic, focusing on how the vanilla Transformer can be improved for longer-term attention span, less memory and computation consumption, RL task solving and more.","title":"The Transformer Family"},{"content":"[Updated on 2020-02-03: mentioning PCG in the \u0026ldquo;Task-Specific Curriculum\u0026rdquo; section. [Updated on 2020-02-04: Add a new \u0026ldquo;curriculum through distillation\u0026rdquo; section.\nIt sounds like an impossible task if we want to teach integral or derivative to a 3-year-old who does not even know basic arithmetics. That\u0026rsquo;s why education is important, as it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard. A curriculum makes learning difficult things easier and approachable for us humans. But, how about machine learning models? Can we train our models more efficiently with a curriculum? Can we design a curriculum to speed up learning?\nBack in 1993, Jeffrey Elman has proposed the idea of training neural networks with a curriculum. His early work on learning simple language grammar demonstrated the importance of such a strategy: starting with a restricted set of simple data and gradually increasing the complexity of training samples; otherwise the model was not able to learn at all.\nCompared to training without a curriculum, we would expect the adoption of the curriculum to expedite the speed of convergence and may or may not improve the final model performance. To design an efficient and effective curriculum is not easy. Keep in mind that, a bad curriculum may even hamper learning.\nNext, we will look into several categories of curriculum learning, as illustrated in Fig. 1. Most cases are applied to Reinforcement Learning, with a few exceptions on Supervised Learning.\nFig. 1. Five types of curriculum for reinforcement learning. In \u0026ldquo;The importance of starting small\u0026rdquo; paper (Elman 1993), I especially like the starting sentences and find them both inspiring and affecting:\n \u0026ldquo;Humans differ from other species along many dimensions, but two are particularly noteworthy. Humans display an exceptional capacity to learn; and humans are remarkable for the unusually long time it takes to reach maturity. The adaptive advantage of learning is clear, and it may be argued that, through culture, learning has created the basis for a non-genetically based transmission of behaviors which may accelerate the evolution of our species.\u0026rdquo;\n Indeed, learning is probably the best superpower we humans have.\nTask-Specific Curriculum Bengio, et al. (2009) provided a good overview of curriculum learning in the old days. The paper presented two ideas with toy experiments using a manually designed task-specific curriculum:\n Cleaner Examples may yield better generalization faster. Introducing gradually more difficult examples speeds up online training. It is plausible that some curriculum strategies could be useless or even harmful. A good question to answer in the field is: What could be the general principles that make some curriculum strategies work better than others? The Bengio 2009 paper hypothesized it would be beneficial to make learning focus on \u0026ldquo;interesting\u0026rdquo; examples that are neither too hard or too easy.\nIf our naive curriculum is to train the model on samples with a gradually increasing level of complexity, we need a way to quantify the difficulty of a task first. One idea is to use its minimal loss with respect to another model while this model is pretrained on other tasks (Weinshall, et al. 2018). In this way, the knowledge of the pretrained model can be transferred to the new model by suggesting a rank of training samples. Fig. 2 shows the effectiveness of the curriculum group (green), compared to control (random order; yellow) and anti (reverse the order; red) groups.\nFig. 2. Image classification accuracy on test image set (5 member classes of \"small mammals\" in CIFAR100). There are 4 experimental groups, (a) `curriculum`: sort the labels by the confidence of another trained classifier (e.g. the margin of an SVM); (b) `control-curriculum`: sort the labels randomly; (c) `anti-curriculum`: sort the labels reversely; (d) `None`: no curriculum. (Image source: Weinshall, et al. 2018) Zaremba \u0026amp; Sutskever (2014) did an interesting experiment on training LSTM to predict the output of a short Python program for mathematical ops without actually executing the code. They found curriculum is necessary for learning. The program\u0026rsquo;s complexity is controlled by two parameters, length ∈ [1, a] and nesting∈ [1, b]. Three strategies are considered:\n Naive curriculum: increase length first until reaching a; then increase nesting and reset length to 1; repeat this process until both reach maximum. Mix curriculum: sample length ~ [1, a] and nesting ~ [1, b] Combined: naive + mix. They noticed that combined strategy always outperformed the naive curriculum and would generally (but not always) outperform the mix strategy \u0026mdash; indicating that it is quite important to mix in easy tasks during training to avoid forgetting.\nProcedural content generation (PCG) is a popular approach for creating video games of various levels of difficulty. PCG involves algorithmic randomness and a heavy dose of human expertise in designing game elements and dependencies among them. Procedurally generated levels have been introduced into several benchmark environments for evaluating whether an RL agent can generalize to a new level that it is not trained on (meta-RL!), such as GVGAI, OpenAI CoinRun and Procgen benchmark. Using GVGAI, Justesen, et al. (2018) demonstrated that an RL policy can easily overfit to a specific game but training over a simple curriculum that grows the task difficulty together with the model performance helps its generalization to new human-designed levels. Similar results are also found in CoinRun (Cobbe, et al. 2018). POET (Wang et al, 2019) is another example for leveraging evolutionary algorithm and procedural generated game levels to improve RL generalization, which I\u0026rsquo;ve described in details in my meta-RL post.\nTo follow the curriculum learning approaches described above, generally we need to figure out two problems in the training procedure:\n Design a metric to quantify how hard a task is so that we can sort tasks accordingly. Provide a sequence of tasks with an increasing level of difficulty to the model during training. However, the order of tasks does not have to be sequential. In our Rubik\u0026rsquo;s cube paper (OpenAI et al, 2019), we depended on Automatic domain randomization (ADR) to generate a curriculum by growing a distribution of environments with increasing complexity. The difficulty of each task (i.e. solving a Rubik\u0026rsquo;s cube in a set of environments) depends on the randomization ranges of various environmental parameters. Even with a simplified assumption that all the environmental parameters are uncorrelated, we were able to create a decent curriculum for our robot hand to learn the task.\nTeacher-Guided Curriculum The idea of Automatic Curriculum Learning was proposed by Graves, et al. 2017 slightly earlier. It considers a $N$-task curriculum as an $N$-armed bandit problem and an adaptive policy which learns to optimize the returns from this bandit.\nTwo categories of learning signals have been considered in the paper:\n Loss-driven progress: the loss function change before and after one gradient update. This type of reward signals tracks the speed of the learning process, because the greatest task loss decrease is equivalent to the fastest learning. Complex-driven progress: the KL divergence between posterior and prior distribution over network weights. This type of learning signals are inspired by the MDL principle, \u0026ldquo;increasing the model complexity by a certain amount is only worthwhile if it compresses the data by a greater amount\u0026rdquo;. The model complexity is therefore expected to increase most in response to the model nicely generalizing to training examples. This framework of proposing curriculum automatically through another RL agent was formalized as Teacher-Student Curriculum Learning (TSCL; Matiisen, et al. 2017). In TSCL, a student is an RL agent working on actual tasks while a teacher agent is a policy for selecting tasks. The student aims to master a complex task that might be hard to learn directly. To make this task easier to learn, we set up the teacher agent to guide the student\u0026rsquo;s training process by picking proper sub-tasks.\nFig. 3. The setup of teacher-student curriculum learning. (Image source: Matiisen, et al. 2017 + my annotation in red.) In the process, the student should learn tasks which:\n can help the student make fastest learning progress, or are at risk of being forgotten. Note: The setup of framing the teacher model as an RL problem feels quite similar to Neural Architecture Search (NAS), but differently the RL model in TSCL operates on the task space and NAS operates on the main model architecture space.\n Training the teacher model is to solve a POMDP problem:\n The unobserved $s_t$ is the full state of the student model. The observed $o = (x_t^{(1)}, \\dots, x_t^{(N)})$ are a list of scores for $N$ tasks. The action $a$ is to pick on subtask. The reward per step is the score delta.$r_t = \\sum_{i=1}^N x_t^{(i)} - x_{t-1}^{(i)}$ (i.e., equivalent to maximizing the score of all tasks at the end of the episode). The method of estimating learning progress from noisy task scores while balancing exploration vs exploitation can be borrowed from the non-stationary multi-armed bandit problem \u0026mdash; use ε-greedy, or Thompson sampling.\nThe core idea, in summary, is to use one policy to propose tasks for another policy to learn better. Interestingly, both works above (in the discrete task space) found that uniformly sampling from all tasks is a surprisingly strong benchmark.\nWhat if the task space is continuous? Portelas, et al. (2019) studied a continuous teacher-student framework, where the teacher has to sample parameters from continuous task space to generate a learning curriculum. Given a newly sampled parameter $p$, the absolute learning progress (short for ALP) is measured as $\\text{ALP}_p = \\vert r - r_\\text{old} \\vert$, where $r$ is the episodic reward associated with $p$ and $r_\\text{old}$ is the reward associated with $p_\\text{old}$. Here, $p_\\text{old}$ is a previous sampled parameter closest to $p$ in the task space, which can be retrieved by nearest neighbor. Note that how this ALP score is different from learning signals in TSCL or Grave, et al. 2017 above: ALP score measures the reward difference between two tasks rather than performance at two time steps of the same task.\nOn top of the task parameter space, a Gaussian mixture model is trained to fit the distribution of $\\text{ALP}_p$ over $p$. ε-greedy is used when sampling the tasks: with some probability, sampling a random task; otherwise sampling proportionally to ALP score from the GMM model.\nFig. 4. The algorithm of ALP-GMM (absolute learning progress Gaussian mixture model). (Image source: Portelas, et al., 2019) Curriculum through Self-Play Different from the teacher-student framework, two agents are doing very different things. The teacher learns to pick a task for the student without any knowledge of the actual task content. What if we want to make both train on the main task directly? How about even make them compete with each other?\nSukhbaatar, et al. (2017) proposed a framework for automatic curriculum learning through asymmetric self-play. Two agents, Alice and Bob, play the same task with different goals: Alice challenges Bob to achieve the same state and Bob attempts to complete it as fast as he can.\nFig. 5. Illustration of the self-play setup when training two agents. The example task is MazeBase: An agent is asked to reach a goal flag in a maze with a light switch, a key and a wall with a door. Toggling the key switch can open or close the door and Turning off the light makes only the glowing light switch available to the agent. (Image source: Sukhbaatar, et al. 2017) Let us consider Alice and Bob as two separate copies for one RL agent trained in the same environment but with different brains. Each of them has independent parameters and loss objective. The self-play-driven training consists of two types of episodes:\n In the self-play episode, Alice alters the state from $s_0$ to $s_t$ and then Bob is asked to return the environment to its original state $s_0$ to get an internal reward. In the target task episode, Bob receives an external reward if he visits the target flag. Note that since B has to repeat the actions between the same pair of $(s_0, s_t)$ of A, this framework only works in reversible or resettable environments.\nAlice should learn to push Bob out of his comfort zone, but not give him impossible tasks. Bob\u0026rsquo;s reward is set as $R_B = -\\gamma t_B$ and Alice\u0026rsquo;s reward is $R_A = \\gamma \\max(0, t_B - t_A)$, where $t_B$ is the total time for B to complete the task, $t_A$ is the time until Alice performs the STOP action and $\\gamma$ is a scalar constant to rescale the reward to be comparable with the external task reward. If B fails a task, $t_B = t_\\max - t_A$. Both policies are goal-conditioned. The losses imply:\n B wants to finish a task asap. A prefers tasks that take more time of B. A does not want to take too many steps when B is failing. In this way, the interaction between Alice and Bob automatically builds a curriculum of increasingly challenging tasks. Meanwhile, as A has done the task herself before proposing the task to B, the task is guaranteed to be solvable.\nThe paradigm of A suggesting tasks and then B solving them does sound similar to the Teacher-Student framework. However, in asymmetric self-play, Alice, who plays a teacher role, also works on the same task to find challenging cases for Bob, rather than optimizes B\u0026rsquo;s learning process explicitly.\nAutomatic Goal Generation Often RL policy needs to be able to perform over a set of tasks. The goal should be carefully chosen so that at every training stage, it would not be too hard or too easy for the current policy. A goal $g \\in \\mathcal{G}$ can be defined as a set of states $S^g$ and a goal is considered as achieved whenever an agent arrives at any of those states.\nThe approach of Generative Goal Learning (Florensa, et al. 2018) relies on a Goal GAN to generate desired goals automatically. In their experiment, the reward is very sparse, just a binary flag for whether a goal is achieved or not and the policy is conditioned on goal,\n $$ \\begin{aligned} \\pi^{*}(a_t\\vert s_t, g) \u0026= \\arg\\max_\\pi \\mathbb{E}_{g\\sim p_g(.)} R^g(\\pi) \\\\ \\text{where }R^g(\\pi) \u0026= \\mathbb{E}_\\pi(.\\mid s_t, g) \\mathbf{1}[\\exists t \\in [1,\\dots, T]: s_t \\in S^g] \\end{aligned} $$ Here $R^g(\\pi)$ is the expected return, also equivalent to the success probability. Given sampled trajectories from the current policy, as long as any state belongs to the goal set, the return will be positive.\nTheir approach iterates through 3 steps until the policy converges:\n Label a set of goals based on whether they are at the appropriate level of difficulty for the current policy. The set of goals at the appropriate level of difficulty are named GOID (short for \u0026ldquo;Goals of Intermediate Difficulty\u0026rdquo;).$\\text{GOID}_i := \\{g : R_\\text{min} \\leq R^g(\\pi_i) \\leq R_\\text{max} \\} \\subseteq G$ Here $R_\\text{min}$ and $R_\\text{max}$ can be interpreted as a minimum and maximum probability of reaching a goal over T time-steps. Train a Goal GAN model using labelled goals from step 1 to produce new goals Use these new goals to train the policy, improving its coverage objective. The Goal GAN generates a curriculum automatically:\n Generator $G(z)$: produces a new goal. =\u0026gt; expected to be a goal uniformly sampled from $GOID$ set. Discriminator $D(g)$: evaluates whether a goal can be achieved. =\u0026gt; expected to tell whether a goal is from $GOID$ set. The Goal GAN is constructed similar to LSGAN (Least-Squared GAN; Mao et al., (2017)), which has better stability of learning compared to vanilla GAN. According to LSGAN, we should minimize the following losses for $D$ and $G$ respectively:\n $$ \\begin{aligned} \\mathcal{L}_\\text{LSGAN}(D) \u0026= \\frac{1}{2} \\mathbb{E}_{g \\sim p_\\text{data}(g)} [ (D(g) - b)^2] + \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - a)^2] \\\\ \\mathcal{L}_\\text{LSGAN}(G) \u0026= \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - c)^2] \\end{aligned} $$ where $a$ is the label for fake data, $b$ for real data, and $c$ is the value that $G$ wants $D$ to believe for fake data. In LSGAN paper\u0026rsquo;s experiments, they used $a=-1, b=1, c=0$.\nThe Goal GAN introduces an extra binary flag $y_b$ indicating whether a goal $g$ is real ($y_g = 1$) or fake ($y_g = 0$) so that the model can use negative samples for training:\n $$ \\begin{aligned} \\mathcal{L}_\\text{GoalGAN}(D) \u0026= \\frac{1}{2} \\mathbb{E}_{g \\sim p_\\text{data}(g)} [ (D(g) - b)^2 + (1-y_g) (D(g) - a)^2] + \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - a)^2] \\\\ \\mathcal{L}_\\text{GoalGAN}(G) \u0026= \\frac{1}{2} \\mathbb{E}_{z \\sim p_z(z)} [ (D(G(z)) - c)^2] \\end{aligned} $$ Fig. 6. The algorithm of Generative Goal Learning. (Image source: (Florensa, et al. 2018) Following the same idea, Racaniere \u0026amp; Lampinen, et al. (2019) designs a method to make the objectives of goal generator more sophisticated. Their method contains three components, same as generative goal learning above:\n Solver/Policy $\\pi$: In each episode, the solver gets a goal $g$ at the beginning and get a single binary reward $R^g$ at the end. Judge/Discriminator $D(.)$: A classifier to predict the binary reward (whether goal can be achieved or not); precisely it outputs the logit of a probability of achieving the given goal, $\\sigma(D(g)) = p(R^g=1\\vert g)$, where $\\sigma$ is the sigmoid function. Setter/Generator $G(.)$: The goal setter takes as input a desired feasibility score $f \\in \\text{Unif}(0, 1)$ and generates $g = G(z, f)$, where the latent variable $z$ is sampled by $z \\sim \\mathcal{N}(0, I)$. The goal generator is designed to reversible, so $G^{-1}$ can map backwards from a goal $g$ to a latent $z = G^{-1}(g, f)$ The generator is optimized with three objectives:\n Goal validity: The proposed goal should be achievable by an expert policy. The corresponding generative loss is designed to increase the likelihood of generating goals that the solver policy has achieved before (like in HER). $\\mathcal{L}_\\text{val}$ is the negative log-likelihood of generated goals that have been solved by the solver in the past. $$ \\begin{align*} \\mathcal{L}_\\text{val} = \\mathbb{E}_{\\substack{ g \\sim \\text{ achieved by solver}, \\\\ \\xi \\in \\text{Uniform}(0, \\delta), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ -\\log p(G^{-1}(g + \\xi, f)) \\big] \\end{align*} $$ Goal feasibility: The proposed goal should be achievable by the current policy; that is, the level of difficulty should be appropriate. $\\mathcal{L}_\\text{feas}$ is the output probability by the judge model $D$ on the generated goal $G(z, f)$ should match the desired $f$. $$ \\begin{align*} \\mathcal{L}_\\text{feas} = \\mathbb{E}_{\\substack{ z \\in \\mathcal{N}(0, 1), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ D(G(z, f)) - \\sigma^{-1}(f)^2 \\big] \\end{align*} $$ Goal coverage: We should maximize the entropy of generated goals to encourage diverse goal and to improve the coverage over the goal space. $$ \\begin{align*} \\mathcal{L}_\\text{cov} = \\mathbb{E}_{\\substack{ z \\in \\mathcal{N}(0, 1), \\\\ f \\in \\text{Uniform}(0, 1) }} \\big[ \\log p(G(z, f)) \\big] \\end{align*} $$ Their experiments showed complex environments require all three losses above. When the environment is changing between episodes, both the goal generator and the discriminator need to be conditioned on environmental observation to produce better results. If there is a desired goal distribution, an additional loss can be added to match a desired goal distribution using Wasserstein distance. Using this loss, the generator can push the solver toward mastering the desired tasks more efficiently.\nFig. 7. Training schematic for the (a) solver/policy, (b) judge/discriminator, and (c) setter/goal generator models. (Image source: Racaniere \u0026 Lampinen, et al., 2019) Skill-Based Curriculum Another view is to decompose what an agent is able to complete into a variety of skills and each skill set could be mapped into a task. Let\u0026rsquo;s imagine when an agent interacts with the environment in an unsupervised manner, is there a way to discover useful skills from such interaction and further build into the solutions for more complicated tasks through a curriculum?\nJabri, et al. (2019) developed an automatic curriculum, CARML (short for \u0026ldquo;Curricula for Unsupervised Meta-Reinforcement Learning\u0026rdquo;), by modeling unsupervised trajectories into a latent skill space, with a focus on training meta-RL policies (i.e. can transfer to unseen tasks). The setting of training environments in CARML is similar to DIAYN. Differently, CARML is trained on pixel-level observations but DIAYN operates on the true state space. An RL algorithm $\\pi_\\theta$, parameterized by $\\theta$, is trained via unsupervised interaction formulated as a CMP combined with a learned reward function $r$. This setting naturally works for the meta-learning purpose, since a customized reward function can be given only at the test time.\nFig. 8. An illustration of CARML, containing two steps: (1) organizing experiential data into the latent skill space; (2) meta-training the policy with the reward function constructed from the learned skills. (Image source: Jabri, et al 2019) CARML is framed as a variational Expectation-Maximization (EM).\n(1) E-Step: This is the stage for organizing experiential data. Collected trajectories are modeled with a mixture of latent components forming the basis of skills.\nLet $z$ be a latent task variable and $q_\\phi$ be a variational distribution of $z$, which could be a mixture model with discrete $z$ or a VAE with continuous $z$. A variational posterior $q_\\phi(z \\vert s)$ works like a classifier, predicting a skill given a state, and we would like to maximize $q_\\phi(z \\vert s)$ to discriminate between data produced by different skills as much as possible. In E-step, $q_\\phi$ is fitted to a set of trajectories produced by $\\pi_\\theta$.\nPrecisely, given a trajectory $\\tau = (s_1,\\dots,s_T)$, we would like to find $\\phi$ such that\n $$ \\max_\\phi \\mathbb{E}_{z\\sim q_\\phi(z)} \\big[ \\log q_\\phi(\\tau \\vert z) \\big] = \\max_\\phi \\mathbb{E}_{z\\sim q_\\phi(z)} \\big[ \\sum_{s_i \\in \\tau} \\log q_\\phi(s_i \\vert z) \\big] $$ A simplifying assumption is made here to ignore the order of states in one trajectory.\n(2) M-Step: This is the stage for doing meta-RL training with $\\pi_\\theta$. The learned skill space is considered as a training task distribution. CARML is agnostic to the type of meta-RL algorithm for policy parameter updates.\nGiven a trajectory $\\tau$, it makes sense for the policy to maximize the mutual information between $\\tau$ and $z$, $I(\\tau;z) = H(\\tau) - H(\\tau \\vert z)$, because:\n maximizing $H(\\tau)$ =\u0026gt; diversity in the policy data space; expected to be large. minimizing $H(\\tau \\vert z)$ =\u0026gt; given a certain skill, the behavior should be restricted; expected to be small. Then we have,\n $$ \\begin{aligned} I(\\tau; z) \u0026= \\mathcal{H}(z) - \\mathcal{H}(z \\vert s_1,\\dots, s_T) \\\\ \u0026\\geq \\mathbb{E}_{s \\in \\tau} [\\mathcal{H}(z) - \\mathcal{H}(z\\vert s)] \u0026 \\scriptstyle{\\text{; discard the order of states.}} \\\\ \u0026= \\mathbb{E}_{s \\in \\tau} [\\mathcal{H}(s_t) - \\mathcal{H}(s\\vert z)] \u0026 \\scriptstyle{\\text{; by definition of MI.}} \\\\ \u0026= \\mathbb{E}_{z\\sim q_\\phi(z), s\\sim \\pi_\\theta(s|z)} [\\log q_\\phi(s|z) - \\log \\pi_\\theta(s)] \\\\ \u0026\\approx \\mathbb{E}_{z\\sim q_\\phi(z), s\\sim \\pi_\\theta(s|z)} [\\color{green}{\\log q_\\phi(s|z) - \\log q_\\phi(s)}] \u0026 \\scriptstyle{\\text{; assume learned marginal distr. matches policy.}} \\end{aligned} $$ We can set the reward as $\\log q_\\phi(s \\vert z) - \\log q_\\phi(s)$, as shown in the red part in the equation above. In order to balance between task-specific exploration (as in red below) and latent skill matching (as in blue below) , a parameter $\\lambda \\in [0, 1]$ is added. Each realization of $z \\sim q_\\phi(z)$ induces a reward function $r_z(s)$ (remember that reward + CMP =\u0026gt; MDP) as follows:\n $$ \\begin{aligned} r_z(s) \u0026= \\lambda \\log q_\\phi(s|z) - \\log q_\\phi(s) \\\\ \u0026= \\lambda \\log q_\\phi(s|z) - \\log \\frac{q_\\phi(s|z) q_\\phi(z)}{q_\\phi(z|s)} \\\\ \u0026= \\lambda \\log q_\\phi(s|z) - \\log q_\\phi(s|z) - \\log q_\\phi(z) + \\log q_\\phi(z|s) \\\\ \u0026= (\\lambda - 1) \\log \\color{red}{q_\\phi(s|z)} + \\color{blue}{\\log q_\\phi(z|s)} + C \\end{aligned} $$ Fig. 9. The algorithm of CARML. (Image source: Jabri, et al 2019) Learning a latent skill space can be done in different ways, such as in Hausman, et al. 2018. The goal of their approach is to learn a task-conditioned policy, $\\pi(a \\vert s, t^{(i)})$, where $t^{(i)}$ is from a discrete list of $N$ tasks, $\\mathcal{T} = [t^{(1)}, \\dots, t^{(N)}]$. However, rather than learning $N$ separate solutions, one per task, it would be nice to learn a latent skill space so that each task could be represented in a distribution over skills and thus skills are reused between tasks. The policy is defined as $\\pi_\\theta(a \\vert s,t) = \\int \\pi_\\theta(a \\vert z,s,t) p_\\phi(z \\vert t)\\mathrm{d}z$, where $\\pi_\\theta$ and $p_\\phi$ are policy and embedding networks to learn, respectively. If $z$ is discrete, i.e. drawn from a set of $K$ skills, then the policy becomes a mixture of $K$ sub-policies. The policy training uses SAC and the dependency on $z$ is introduced in the entropy term.\nCurriculum through Distillation [I was thinking of the name of this section for a while, deciding between cloning, inheritance, and distillation. Eventually, I picked distillation because it sounds the coolest B-)]\nThe motivation for the progressive neural network (Rusu et al. 2016) architecture is to efficiently transfer learned skills between different tasks and in the meantime avoid catastrophic forgetting. The curriculum is realized through a set of progressively stacked neural network towers (or \u0026ldquo;columns\u0026rdquo;, as in the paper).\nA progressive network has the following structure:\n It starts with a single column containing $L$ layers of neurons, in which the corresponding activation layers are labelled as $h^{(1)}_i, i=1, \\dots, L$. We first train this single-column network for one task to convergence, achieving parameter config $\\theta^{(1)}$.\n Once switch to the next task, we need to add a new column to adapt to the new context while freezing $\\theta^{(1)}$ to lock down the learned skills from the previous task. The new column has activation layers labelled as $h^{(2)}_i, i=1, \\dots, L$, and parameters $\\theta^{(2)}$.\n Step 2 can be repeated with every new task. The $i$-th layer activation in the $k$-th column depends on the previous activation layers in all the existing columns:\n $$ h^{(k)}_i = f(W^{(k)}_i h^{(k)}_{i-1} + \\sum_{j where $W^{(k)}_i$ is the weight matrix of the layer $i$ in the column $k$; $U_i^{(k:j)}, j \u0026lt; k$ are the weight matrices for projecting the layer $i-1$ of the column $j$ to the layer $i$ of column $k$ ($ j \u0026lt; k $). The above weights matrices should be learned. $f(.)$ is a non-linear activation function by choice.\n Fig. 10. The progressive neural network architecture. (Image source: Rusu, et al. 2017) The paper experimented with Atari games by training a progressive network on multiple games to check whether features learned in one game can transfer to another. That is indeed the case. Though interestingly, learning a high dependency on features in the previous columns does not always indicate good transfer performance on the new task. One hypothesis is that features learned from the old task might introduce biases into the new task, leading to policy getting trapped in a sub-optimal solution. Overall, the progressive network works better than only fine-tuning the top layer and can achieve similar transfer performance as fine-tuning the entire network.\nOne use case for the progressive network is to do sim2real transfer (Rusu, et al. 2017), in which the first column is trained in simulator with a lot of samples and then the additional columns (could be for different real-world tasks) are added and trained with a few real data samples.\nCzarnecki, et al. (2018) proposed another RL training framework, Mix \u0026amp; Match (short for M\u0026amp;M) to provide curriculum through coping knowledge between agents. Given a sequence of agents from simple to complex, $\\pi_1, \\dots, \\pi_K$, each parameterized with some shared weights (e.g. by shared some lower common layers). M\u0026amp;M trains a mixture of agents, but only the final performance of the most complex one $\\pi_K$ matters.\nIn the meantime, M\u0026amp;M learns a categorical distribution $c \\sim \\text{Categorical}(1, \\dots, K \\vert \\alpha)$ with pmf $p(c=i) = \\alpha_i$ probability to pick which policy to use at a given time. The mixed M\u0026amp;M policy is a simple weighted sum: $\\pi_\\text{mm}(a \\vert s) = \\sum_{i=1}^K \\alpha_i \\pi_i(a \\vert s)$. Curriculum learning is realized by dynamically adjusting $\\alpha_i$, from $\\alpha_K=0$ to $\\alpha_K=1$. The tuning of $\\alpha$ can be manual or through population-based training.\nTo encourage cooperation rather than competition among policies, besides the RL loss $\\mathcal{L}_\\text{RL}$, another distillation-like loss $\\mathcal{L}_\\text{mm}(\\theta)$ is added. The knowledge transfer loss $\\mathcal{L}_\\text{mm}(\\theta)$ measures the KL divergence between two policies, $\\propto D_\\text{KL}(\\pi_{i}(. \\vert s) | \\pi_j(. \\vert s))$ for $i \u0026lt; j$. It encourages complex agents to match the simpler ones early on. The final loss is $\\mathcal{L} = \\mathcal{L}_\\text{RL}(\\theta \\vert \\pi_\\text{mm}) + \\lambda \\mathcal{L}_\\text{mm}(\\theta)$.\nFig. 11. The Mix \u0026 Match architecture for training a mixture of policies. (Image source: Czarnecki, et al., 2018) Citation Cited as:\n Weng, Lilian. (Jan 2020). Curriculum for reinforcement learning. Lil\u0026rsquo;Log. https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/.\n Or\n@article{weng2020curriculum, title = \u0026quot;Curriculum for Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2020\u0026quot;, month = \u0026quot;Jan\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/\u0026quot; } References [1] Jeffrey L. Elman. \u0026ldquo;Learning and development in neural networks: The importance of starting small.\u0026quot; Cognition 48.1 (1993): 71-99.\n[2] Yoshua Bengio, et al. \u0026ldquo;Curriculum learning.\u0026quot; ICML 2009.\n[3] Daphna Weinshall, Gad Cohen, and Dan Amir. \u0026ldquo;Curriculum learning by transfer learning: Theory and experiments with deep networks.\u0026quot; ICML 2018.\n[4] Wojciech Zaremba and Ilya Sutskever. \u0026ldquo;Learning to execute.\u0026quot; arXiv preprint arXiv:1410.4615 (2014).\n[5] Tambet Matiisen, et al. \u0026ldquo;Teacher-student curriculum learning.\u0026quot; IEEE Trans. on neural networks and learning systems (2017).\n[6] Alex Graves, et al. \u0026ldquo;Automated curriculum learning for neural networks.\u0026quot; ICML 2017.\n[7] Remy Portelas, et al. Teacher algorithms for curriculum learning of Deep RL in continuously parameterized environments. CoRL 2019.\n[8] Sainbayar Sukhbaatar, et al. \u0026ldquo;Intrinsic Motivation and Automatic Curricula via Asymmetric Self-Play.\u0026quot; ICLR 2018.\n[9] Carlos Florensa, et al. \u0026ldquo;Automatic Goal Generation for Reinforcement Learning Agents\u0026rdquo; ICML 2019.\n[10] Sebastien Racaniere \u0026amp; Andrew K. Lampinen, et al. \u0026ldquo;Automated Curriculum through Setter-Solver Interactions\u0026rdquo; ICLR 2020.\n[11] Allan Jabri, et al. \u0026ldquo;Unsupervised Curricula for Visual Meta-Reinforcement Learning\u0026rdquo; NeuriPS 2019.\n[12] Karol Hausman, et al. \u0026ldquo;Learning an Embedding Space for Transferable Robot Skills \u0026ldquo; ICLR 2018.\n[13] Josh Merel, et al. \u0026ldquo;Reusable neural skill embeddings for vision-guided whole body movement and object manipulation\u0026rdquo; arXiv preprint arXiv:1911.06636 (2019).\n[14] OpenAI, et al. \u0026ldquo;Solving Rubik\u0026rsquo;s Cube with a Robot Hand.\u0026quot; arXiv preprint arXiv:1910.07113 (2019).\n[15] Niels Justesen, et al. \u0026ldquo;Illuminating Generalization in Deep Reinforcement Learning through Procedural Level Generation\u0026rdquo; NeurIPS 2018 Deep RL Workshop.\n[16] Karl Cobbe, et al. \u0026ldquo;Quantifying Generalization in Reinforcement Learning\u0026rdquo; arXiv preprint arXiv:1812.02341 (2018).\n[17] Andrei A. Rusu et al. \u0026ldquo;Progressive Neural Networks\u0026rdquo; arXiv preprint arXiv:1606.04671 (2016).\n[18] Andrei A. Rusu et al. \u0026ldquo;Sim-to-Real Robot Learning from Pixels with Progressive Nets.\u0026quot; CoRL 2017.\n[19] Wojciech Marian Czarnecki, et al. \u0026ldquo;Mix \u0026amp; Match – Agent Curricula for Reinforcement Learning.\u0026quot; ICML 2018.\n","permalink":"https://lilianweng.github.io/posts/2020-01-29-curriculum-rl/","summary":"[Updated on 2020-02-03: mentioning PCG in the \u0026ldquo;Task-Specific Curriculum\u0026rdquo; section. [Updated on 2020-02-04: Add a new \u0026ldquo;curriculum through distillation\u0026rdquo; section.\nIt sounds like an impossible task if we want to teach integral or derivative to a 3-year-old who does not even know basic arithmetics. That\u0026rsquo;s why education is important, as it provides a systematic way to break down complex knowledge and a nice curriculum for teaching concepts from simple to hard.","title":"Curriculum for Reinforcement Learning"},{"content":"[Updated on 2020-01-09: add a new section on Contrastive Predictive Coding]. [Updated on 2020-04-13: add a \u0026ldquo;Momentum Contrast\u0026rdquo; section on MoCo, SimCLR and CURL.] [Updated on 2020-07-08: add a \u0026ldquo;Bisimulation\u0026rdquo; section on DeepMDP and DBC.] [Updated on 2020-09-12: add MoCo V2 and BYOL in the \u0026ldquo;Momentum Contrast\u0026rdquo; section.] [Updated on 2021-05-31: remove section on \u0026ldquo;Momentum Contrast\u0026rdquo; and add a pointer to a full post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;]\nGiven a task and enough labels, supervised learning can solve it really well. Good performance usually requires a decent amount of labels, but collecting manual labels is expensive (i.e. ImageNet) and hard to be scaled up. Considering the amount of unlabelled data (e.g. free text, all the images on the Internet) is substantially more than a limited number of human curated labelled datasets, it is kinda wasteful not to use them. However, unsupervised learning is not easy and usually works much less efficiently than supervised learning.\nWhat if we can get labels for free for unlabelled data and train unsupervised dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form to predict only a subset of information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as self-supervised learning.\nThis idea has been widely used in language modeling. The default task for a language model is to predict the next word given the past sequence. BERT adds two other auxiliary tasks and both rely on self-generated labels.\nFig. 1. A great summary of how self-supervised learning tasks can be constructed (Image source: LeCun’s talk) Here is a nicely curated list of papers in self-supervised learning. Please check it out if you are interested in reading more in depth.\nNote that this post does not focus on either NLP / language modeling or generative modeling.\nWhy Self-Supervised Learning? Self-supervised learning empowers us to exploit a variety of labels that come with the data for free. The motivation is quite straightforward. Producing a dataset with clean labels is expensive but unlabeled data is being generated all the time. To make use of this much larger amount of unlabeled data, one way is to set the learning objectives properly so as to get supervision from the data itself.\nThe self-supervised task, also known as pretext task, guides us to a supervised loss function. However, we usually don’t care about the final performance of this invented task. Rather we are interested in the learned intermediate representation with the expectation that this representation can carry good semantic or structural meanings and can be beneficial to a variety of practical downstream tasks.\nFor example, we might rotate images at random and train a model to predict how each input image is rotated. The rotation prediction task is made-up, so the actual accuracy is unimportant, like how we treat auxiliary tasks. But we expect the model to learn high-quality latent variables for real-world tasks, such as constructing an object recognition classifier with very few labeled samples.\nBroadly speaking, all the generative models can be considered as self-supervised, but with different goals: Generative models focus on creating diverse and realistic images, while self-supervised representation learning care about producing good features generally helpful for many tasks. Generative modeling is not the focus of this post, but feel free to check my previous posts.\nImages-Based Many ideas have been proposed for self-supervised representation learning on images. A common workflow is to train a model on one or multiple pretext tasks with unlabelled images and then use one intermediate feature layer of this model to feed a multinomial logistic regression classifier on ImageNet classification. The final classification accuracy quantifies how good the learned representation is.\nRecently, some researchers proposed to train supervised learning on labelled data and self-supervised pretext tasks on unlabelled data simultaneously with shared weights, like in Zhai et al, 2019 and Sun et al, 2019.\nDistortion We expect small distortion on an image does not modify its original semantic meaning or geometric forms. Slightly distorted images are considered the same as original and thus the learned features are expected to be invariant to distortion.\nExemplar-CNN (Dosovitskiy et al., 2015) create surrogate training datasets with unlabeled image patches:\n Sample $N$ patches of size 32 × 32 pixels from different images at varying positions and scales, only from regions containing considerable gradients as those areas cover edges and tend to contain objects or parts of objects. They are \u0026ldquo;exemplary\u0026rdquo; patches. Each patch is distorted by applying a variety of random transformations (i.e., translation, rotation, scaling, etc.). All the resulting distorted patches are considered to belong to the same surrogate class. The pretext task is to discriminate between a set of surrogate classes. We can arbitrarily create as many surrogate classes as we want. Fig. 2. The original patch of a cute deer is in the top left corner. Random transformations are applied, resulting in a variety of distorted patches. All of them should be classified into the same class in the pretext task. (Image source: Dosovitskiy et al., 2015) Rotation of an entire image (Gidaris et al. 2018 is another interesting and cheap way to modify an input image while the semantic content stays unchanged. Each input image is first rotated by a multiple of $90^\\circ$ at random, corresponding to $[0^\\circ, 90^\\circ, 180^\\circ, 270^\\circ]$. The model is trained to predict which rotation has been applied, thus a 4-class classification problem.\nIn order to identify the same image with different rotations, the model has to learn to recognize high level object parts, such as heads, noses, and eyes, and the relative positions of these parts, rather than local patterns. This pretext task drives the model to learn semantic concepts of objects in this way.\nFig. 3. Illustration of self-supervised learning by rotating the entire input images. The model learns to predict which rotation is applied. (Image source: Gidaris et al. 2018) Patches The second category of self-supervised learning tasks extract multiple patches from one image and ask the model to predict the relationship between these patches.\nDoersch et al. (2015) formulates the pretext task as predicting the relative position between two random patches from one image. A model needs to understand the spatial context of objects in order to tell the relative position between parts.\nThe training patches are sampled in the following way:\n Randomly sample the first patch without any reference to image content. Considering that the first patch is placed in the middle of a 3x3 grid, and the second patch is sampled from its 8 neighboring locations around it. To avoid the model only catching low-level trivial signals, such as connecting a straight line across boundary or matching local patterns, additional noise is introduced by: Add gaps between patches Small jitters Randomly downsample some patches to as little as 100 total pixels, and then upsampling it, to build robustness to pixelation. Shift green and magenta toward gray or randomly drop 2 of 3 color channels (See \u0026ldquo;chromatic aberration\u0026rdquo; below) The model is trained to predict which one of 8 neighboring locations the second patch is selected from, a classification problem over 8 classes. Fig. 4. Illustration of self-supervised learning by predicting the relative position of two random patches. (Image source: Doersch et al., 2015) Other than trivial signals like boundary patterns or textures continuing, another interesting and a bit surprising trivial solution was found, called \u0026ldquo;chromatic aberration\u0026rdquo;. It is triggered by different focal lengths of lights at different wavelengths passing through the lens. In the process, there might exist small offsets between color channels. Hence, the model can learn to tell the relative position by simply comparing how green and magenta are separated differently in two patches. This is a trivial solution and has nothing to do with the image content. Pre-processing images by shifting green and magenta toward gray or randomly dropping 2 of 3 color channels can avoid this trivial solution.\nFig. 5. Illustration of how chromatic aberration happens. (Image source: wikipedia) Since we have already set up a 3x3 grid in each image in the above task, why not use all of 9 patches rather than only 2 to make the task more difficult? Following this idea, Noroozi \u0026amp; Favaro (2016) designed a jigsaw puzzle game as pretext task: The model is trained to place 9 shuffled patches back to the original locations.\nA convolutional network processes each patch independently with shared weights and outputs a probability vector per patch index out of a predefined set of permutations. To control the difficulty of jigsaw puzzles, the paper proposed to shuffle patches according to a predefined permutation set and configured the model to predict a probability vector over all the indices in the set.\nBecause how the input patches are shuffled does not alter the correct order to predict. A potential improvement to speed up training is to use permutation-invariant graph convolutional network (GCN) so that we don’t have to shuffle the same set of patches multiple times, same idea as in this paper.\nFig. 6. Illustration of self-supervised learning by solving jigsaw puzzle. (Image source: Noroozi \u0026 Favaro, 2016) Another idea is to consider \u0026ldquo;feature\u0026rdquo; or \u0026ldquo;visual primitives\u0026rdquo; as a scalar-value attribute that can be summed up over multiple patches and compared across different patches. Then the relationship between patches can be defined by counting features and simple arithmetic (Noroozi, et al, 2017).\nThe paper considers two transformations:\n Scaling: If an image is scaled up by 2x, the number of visual primitives should stay the same. Tiling: If an image is tiled into a 2x2 grid, the number of visual primitives is expected to be the sum, 4 times the original feature counts. The model learns a feature encoder $\\phi(.)$ using the above feature counting relationship. Given an input image $\\mathbf{x} \\in \\mathbb{R}^{m \\times n \\times 3}$, considering two types of transformation operators:\n Downsampling operator, $D: \\mathbb{R}^{m \\times n \\times 3} \\mapsto \\mathbb{R}^{\\frac{m}{2} \\times \\frac{n}{2} \\times 3}$: downsample by a factor of 2 Tiling operator $T_i: \\mathbb{R}^{m \\times n \\times 3} \\mapsto \\mathbb{R}^{\\frac{m}{2} \\times \\frac{n}{2} \\times 3}$: extract the $i$-th tile from a 2x2 grid of the image. We expect to learn:\n $$ \\phi(\\mathbf{x}) = \\phi(D \\circ \\mathbf{x}) = \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x}) $$ Thus the MSE loss is: $\\mathcal{L}_\\text{feat} = |\\phi(D \\circ \\mathbf{x}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})|^2_2$. To avoid trivial solution $\\phi(\\mathbf{x}) = \\mathbf{0}, \\forall{\\mathbf{x}}$, another loss term is added to encourage the difference between features of two different images: $\\mathcal{L}_\\text{diff} = \\max(0, c -|\\phi(D \\circ \\mathbf{y}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})|^2_2)$, where $\\mathbf{y}$ is another input image different from $\\mathbf{x}$ and $c$ is a scalar constant. The final loss is:\n $$ \\mathcal{L} = \\mathcal{L}_\\text{feat} + \\mathcal{L}_\\text{diff} = \\|\\phi(D \\circ \\mathbf{x}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})\\|^2_2 + \\max(0, M -\\|\\phi(D \\circ \\mathbf{y}) - \\sum_{i=1}^4 \\phi(T_i \\circ \\mathbf{x})\\|^2_2) $$ Fig. 7. Self-supervised representation learning by counting features. (Image source: Noroozi, et al, 2017) Colorization Colorization can be used as a powerful self-supervised task: a model is trained to color a grayscale input image; precisely the task is to map this image to a distribution over quantized color value outputs (Zhang et al. 2016).\nThe model outputs colors in the the CIE Lab* color space. The Lab* color is designed to approximate human vision, while, in contrast, RGB or CMYK models the color output of physical devices.\n L* component matches human perception of lightness; L* = 0 is black and L* = 100 indicates white. a* component represents green (negative) / magenta (positive) value. b* component models blue (negative) /yellow (positive) value. Due to the multimodal nature of the colorization problem, cross-entropy loss of predicted probability distribution over binned color values works better than L2 loss of the raw color values. The ab color space is quantized with bucket size 10.\nTo balance between common colors (usually low ab values, of common backgrounds like clouds, walls, and dirt) and rare colors (which are likely associated with key objects in the image), the loss function is rebalanced with a weighting term that boosts the loss of infrequent color buckets. This is just like why we need both tf and idf for scoring words in information retrieval model. The weighting term is constructed as: (1-λ) * Gaussian-kernel-smoothed empirical probability distribution + λ * a uniform distribution, where both distributions are over the quantized ab color space.\nGenerative Modeling The pretext task in generative modeling is to reconstruct the original input while learning meaningful latent representation.\nThe denoising autoencoder (Vincent, et al, 2008) learns to recover an image from a version that is partially corrupted or has random noise. The design is inspired by the fact that humans can easily recognize objects in pictures even with noise, indicating that key visual features can be extracted and separated from noise. See my old post.\nThe context encoder (Pathak, et al., 2016) is trained to fill in a missing piece in the image. Let $\\hat{M}$ be a binary mask, 0 for dropped pixels and 1 for remaining input pixels. The model is trained with a combination of the reconstruction (L2) loss and the adversarial loss. The removed regions defined by the mask could be of any shape.\n $$ \\begin{aligned} \\mathcal{L}(\\mathbf{x}) \u0026= \\mathcal{L}_\\text{recon}(\\mathbf{x}) + \\mathcal{L}_\\text{adv}(\\mathbf{x})\\\\ \\mathcal{L}_\\text{recon}(\\mathbf{x}) \u0026= \\|(1 - \\hat{M}) \\odot (\\mathbf{x} - E(\\hat{M} \\odot \\mathbf{x})) \\|_2^2 \\\\ \\mathcal{L}_\\text{adv}(\\mathbf{x}) \u0026= \\max_D \\mathbb{E}_{\\mathbf{x}} [\\log D(\\mathbf{x}) + \\log(1 - D(E(\\hat{M} \\odot \\mathbf{x})))] \\end{aligned} $$ where $E(.)$ is the encoder and $D(.)$ is the decoder.\nFig. 8. Illustration of context encoder. (Image source: Pathak, et al., 2016) When applying a mask on an image, the context encoder removes information of all the color channels in partial regions. How about only hiding a subset of channels? The split-brain autoencoder (Zhang et al., 2017) does this by predicting a subset of color channels from the rest of channels. Let the data tensor $\\mathbf{x} \\in \\mathbb{R}^{h \\times w \\times \\vert C \\vert }$ with $C$ color channels be the input for the $l$-th layer of the network. It is split into two disjoint parts, $\\mathbf{x}_1 \\in \\mathbb{R}^{h \\times w \\times \\vert C_1 \\vert}$ and $\\mathbf{x}_2 \\in \\mathbb{R}^{h \\times w \\times \\vert C_2 \\vert}$, where $C_1 , C_2 \\subseteq C$. Then two sub-networks are trained to do two complementary predictions: one network $f_1$ predicts $\\mathbf{x}_2$ from $\\mathbf{x}_1$ and the other network $f_1$ predicts $\\mathbf{x}_1$ from $\\mathbf{x}_2$. The loss is either L1 loss or cross entropy if color values are quantized.\nThe split can happen once on the RGB-D or Lab* colorspace, or happen even in every layer of a CNN network in which the number of channels can be arbitrary.\nFig. 9. Illustration of split-brain autoencoder. (Image source: Zhang et al., 2017) The generative adversarial networks (GANs) are able to learn to map from simple latent variables to arbitrarily complex data distributions. Studies have shown that the latent space of such generative models captures semantic variation in the data; e.g. when training GAN models on human faces, some latent variables are associated with facial expression, glasses, gender, etc (Radford et al., 2016).\nBidirectional GANs (Donahue, et al, 2017) introduces an additional encoder $E(.)$ to learn the mappings from the input to the latent variable $\\mathbf{z}$. The discriminator $D(.)$ predicts in the joint space of the input data and latent representation, $(\\mathbf{x}, \\mathbf{z})$, to tell apart the generated pair $(\\mathbf{x}, E(\\mathbf{x}))$ from the real one $(G(\\mathbf{z}), \\mathbf{z})$. The model is trained to optimize the objective: $\\min_{G, E} \\max_D V(D, E, G)$, where the generator $G$ and the encoder $E$ learn to generate data and latent variables that are realistic enough to confuse the discriminator and at the same time the discriminator $D$ tries to differentiate real and generated data.\n $$ V(D, E, G) = \\mathbb{E}_{\\mathbf{x} \\sim p_\\mathbf{x}} [ \\underbrace{\\mathbb{E}_{\\mathbf{z} \\sim p_E(.\\vert\\mathbf{x})}[\\log D(\\mathbf{x}, \\mathbf{z})]}_{\\log D(\\text{real})} ] + \\mathbb{E}_{\\mathbf{z} \\sim p_\\mathbf{z}} [ \\underbrace{\\mathbb{E}_{\\mathbf{x} \\sim p_G(.\\vert\\mathbf{z})}[\\log 1 - D(\\mathbf{x}, \\mathbf{z})]}_{\\log(1- D(\\text{fake}))}) ] $$ Fig. 10. Illustration of how Bidirectional GAN works. (Image source: Donahue, et al, 2017) Contrastive Learning The Contrastive Predictive Coding (CPC) (van den Oord, et al. 2018) is an approach for unsupervised learning from high-dimensional data by translating a generative modeling problem to a classification problem. The contrastive loss or InfoNCE loss in CPC, inspired by Noise Contrastive Estimation (NCE), uses cross-entropy loss to measure how well the model can classify the \u0026ldquo;future\u0026rdquo; representation amongst a set of unrelated \u0026ldquo;negative\u0026rdquo; samples. Such design is partially motivated by the fact that the unimodal loss like MSE has no enough capacity but learning a full generative model could be too expensive.\nFig. 11. Illustration of applying Contrastive Predictive Coding on the audio input. (Image source: van den Oord, et al. 2018) CPC uses an encoder to compress the input data $z_t = g_\\text{enc}(x_t)$ and an autoregressive decoder to learn the high-level context that is potentially shared across future predictions, $c_t = g_\\text{ar}(z_{\\leq t})$. The end-to-end training relies on the NCE-inspired contrastive loss.\nWhile predicting future information, CPC is optimized to maximize the the mutual information between input $x$ and context vector $c$:\n $$ I(x; c) = \\sum_{x, c} p(x, c) \\log\\frac{p(x, c)}{p(x)p(c)} = \\sum_{x, c} p(x, c)\\log\\frac{p(x|c)}{p(x)} $$ Rather than modeling the future observations $p_k(x_{t+k} \\vert c_t)$ directly (which could be fairly expensive), CPC models a density function to preserve the mutual information between $x_{t+k}$ and $c_t$:\n $$ f_k(x_{t+k}, c_t) = \\exp(z_{t+k}^\\top W_k c_t) \\propto \\frac{p(x_{t+k}|c_t)}{p(x_{t+k})} $$ where $f_k$ can be unnormalized and a linear transformation $W_k^\\top c_t$ is used for the prediction with a different $W_k$ matrix for every step $k$.\nGiven a set of $N$ random samples $X = \\{x_1, \\dots, x_N\\}$ containing only one positive sample $x_t \\sim p(x_{t+k} \\vert c_t)$ and $N-1$ negative samples $x_{i \\neq t} \\sim p(x_{t+k})$, the cross-entropy loss for classifying the positive sample (where $\\frac{f_k}{\\sum f_k}$ is the prediction) correctly is:\n $$ \\mathcal{L}_N = - \\mathbb{E}_X \\Big[\\log \\frac{f_k(x_{t+k}, c_t)}{\\sum_{i=1}^N f_k (x_i, c_t)}\\Big] $$ Fig. 12. Illustration of applying Contrastive Predictive Coding on images. (Image source: van den Oord, et al. 2018) When using CPC on images (Henaff, et al. 2019), the predictor network should only access a masked feature set to avoid a trivial prediction. Precisely:\n Each input image is divided into a set of overlapped patches and each patch is encoded by a resnet encoder, resulting in compressed feature vector $z_{i,j}$. A masked conv net makes prediction with a mask such that the receptive field of a given output neuron can only see things above it in the image. Otherwise, the prediction problem would be trivial. The prediction can be made in both directions (top-down and bottom-up). The prediction is made for $z_{i+k, j}$ from context $c_{i,j}$: $\\hat{z}_{i+k, j} = W_k c_{i,j}$. A contrastive loss quantifies this prediction with a goal to correctly identify the target among a set of negative representation $\\{z_l\\}$ sampled from other patches in the same image and other images in the same batch:\n $$ \\mathcal{L}_\\text{CPC} = -\\sum_{i,j,k} \\log p(z_{i+k, j} \\vert \\hat{z}_{i+k, j}, \\{z_l\\}) = -\\sum_{i,j,k} \\log \\frac{\\exp(\\hat{z}_{i+k, j}^\\top z_{i+k, j})}{\\exp(\\hat{z}_{i+k, j}^\\top z_{i+k, j}) + \\sum_l \\exp(\\hat{z}_{i+k, j}^\\top z_l)} $$ For more content on contrastive learning, check out the post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;.\nVideo-Based A video contains a sequence of semantically related frames. Nearby frames are close in time and more correlated than frames further away. The order of frames describes certain rules of reasonings and physical logics; such as that object motion should be smooth and gravity is pointing down.\nA common workflow is to train a model on one or multiple pretext tasks with unlabelled videos and then feed one intermediate feature layer of this model to fine-tune a simple model on downstream tasks of action classification, segmentation or object tracking.\nTracking The movement of an object is traced by a sequence of video frames. The difference between how the same object is captured on the screen in close frames is usually not big, commonly triggered by small motion of the object or the camera. Therefore any visual representation learned for the same object across close frames should be close in the latent feature space. Motivated by this idea, Wang \u0026amp; Gupta, 2015 proposed a way of unsupervised learning of visual representation by tracking moving objects in videos.\nPrecisely patches with motion are tracked over a small time window (e.g. 30 frames). The first patch $\\mathbf{x}$ and the last patch $\\mathbf{x}^+$ are selected and used as training data points. If we train the model directly to minimize the difference between feature vectors of two patches, the model may only learn to map everything to the same value. To avoid such a trivial solution, same as above, a random third patch $\\mathbf{x}^-$ is added. The model learns the representation by enforcing the distance between two tracked patches to be closer than the distance between the first patch and a random one in the feature space, $D(\\mathbf{x}, \\mathbf{x}^-)) \u0026gt; D(\\mathbf{x}, \\mathbf{x}^+)$, where $D(.)$ is the cosine distance,\n $$ D(\\mathbf{x}_1, \\mathbf{x}_2) = 1 - \\frac{f(\\mathbf{x}_1) f(\\mathbf{x}_2)}{\\|f(\\mathbf{x}_1)\\| \\|f(\\mathbf{x}_2\\|)} $$ The loss function is:\n $$ \\mathcal{L}(\\mathbf{x}, \\mathbf{x}^+, \\mathbf{x}^-) = \\max\\big(0, D(\\mathbf{x}, \\mathbf{x}^+) - D(\\mathbf{x}, \\mathbf{x}^-) + M\\big) + \\text{weight decay regularization term} $$ where $M$ is a scalar constant controlling for the minimum gap between two distances; $M=0.5$ in the paper. The loss enforces $D(\\mathbf{x}, \\mathbf{x}^-) \u0026gt;= D(\\mathbf{x}, \\mathbf{x}^+) + M$ at the optimal case.\nThis form of loss function is also known as triplet loss in the face recognition task, in which the dataset contains images of multiple people from multiple camera angles. Let $\\mathbf{x}^a$ be an anchor image of a specific person, $\\mathbf{x}^p$ be a positive image of this same person from a different angle and $\\mathbf{x}^n$ be a negative image of a different person. In the embedding space, $\\mathbf{x}^a$ should be closer to $\\mathbf{x}^p$ than $\\mathbf{x}^n$:\n $$ \\mathcal{L}_\\text{triplet}(\\mathbf{x}^a, \\mathbf{x}^p, \\mathbf{x}^n) = \\max(0, \\|\\phi(\\mathbf{x}^a) - \\phi(\\mathbf{x}^p) \\|_2^2 - \\|\\phi(\\mathbf{x}^a) - \\phi(\\mathbf{x}^n) \\|_2^2 + M) $$ A slightly different form of the triplet loss, named n-pair loss is also commonly used for learning observation embedding in robotics tasks. See a later section for more related content.\nFig. 13. Overview of learning representation by tracking objects in videos. (a) Identify moving patches in short traces; (b) Feed two related patched and one random patch into a conv network with shared weights. (c) The loss function enforces the distance between related patches to be closer than the distance between random patches. (Image source: Wang \u0026 Gupta, 2015) Relevant patches are tracked and extracted through a two-step unsupervised optical flow approach:\n Obtain SURF interest points and use IDT to obtain motion of each SURF point. Given the trajectories of SURF interest points, classify these points as moving if the flow magnitude is more than 0.5 pixels. During training, given a pair of correlated patches $\\mathbf{x}$ and $\\mathbf{x}^+$, $K$ random patches $\\{\\mathbf{x}^-\\}$ are sampled in this same batch to form $K$ training triplets. After a couple of epochs, hard negative mining is applied to make the training harder and more efficient, that is, to search for random patches that maximize the loss and use them to do gradient updates.\nFrame Sequence Video frames are naturally positioned in chronological order. Researchers have proposed several self-supervised tasks, motivated by the expectation that good representation should learn the correct sequence of frames.\nOne idea is to validate frame order (Misra, et al 2016). The pretext task is to determine whether a sequence of frames from a video is placed in the correct temporal order (\u0026ldquo;temporal valid\u0026rdquo;). The model needs to track and reason about small motion of an object across frames to complete such a task.\nThe training frames are sampled from high-motion windows. Every time 5 frames are sampled $(f_a, f_b, f_c, f_d, f_e)$ and the timestamps are in order $a \u0026lt; b \u0026lt; c \u0026lt; d \u0026lt; e$. Out of 5 frames, one positive tuple $(f_b, f_c, f_d)$ and two negative tuples, $(f_b, f_a, f_d)$ and $(f_b, f_e, f_d)$ are created. The parameter $\\tau_\\max = \\vert b-d \\vert$ controls the difficulty of positive training instances (i.e. higher → harder) and the parameter $\\tau_\\min = \\min(\\vert a-b \\vert, \\vert d-e \\vert)$ controls the difficulty of negatives (i.e. lower → harder).\nThe pretext task of video frame order validation is shown to improve the performance on the downstream task of action recognition when used as a pretraining step.\nFig. 14. Overview of learning representation by validating the order of video frames. (a) the data sample process; (b) the model is a triplet siamese network, where all input frames have shared weights. (Image source: Misra, et al 2016) The task in O3N (Odd-One-Out Network; Fernando et al. 2017) is based on video frame sequence validation too. One step further from above, the task is to pick the incorrect sequence from multiple video clips.\nGiven $N+1$ input video clips, one of them has frames shuffled, thus in the wrong order, and the rest $N$ of them remain in the correct temporal order. O3N learns to predict the location of the odd video clip. In their experiments, there are 6 input clips and each contain 6 frames.\nThe arrow of time in a video contains very informative messages, on both low-level physics (e.g. gravity pulls objects down to the ground; smoke rises up; water flows downward.) and high-level event reasoning (e.g. fish swim forward; you can break an egg but cannot revert it.). Thus another idea is inspired by this to learn latent representation by predicting the arrow of time (AoT) \u0026mdash; whether video playing forwards or backwards (Wei et al., 2018).\nA classifier should capture both low-level physics and high-level semantics in order to predict the arrow of time. The proposed T-CAM (Temporal Class-Activation-Map) network accepts $T$ groups, each containing a number of frames of optical flow. The conv layer outputs from each group are concatenated and fed into binary logistic regression for predicting the arrow of time.\nFig. 15. Overview of learning representation by predicting the arrow of time. (a) Conv features of multiple groups of frame sequences are concatenated. (b) The top level contains 3 conv layers and average pooling. (Image source: Wei et al, 2018) Interestingly, there exist a couple of artificial cues in the dataset. If not handled properly, they could lead to a trivial classifier without relying on the actual video content:\n Due to the video compression, the black framing might not be completely black but instead may contain certain information on the chronological order. Hence black framing should be removed in the experiments. Large camera motion, like vertical translation or zoom-in/out, also provides strong signals for the arrow of time but independent of content. The processing stage should stabilize the camera motion. The AoT pretext task is shown to improve the performance on action classification downstream task when used as a pretraining step. Note that fine-tuning is still needed.\nVideo Colorization Vondrick et al. (2018) proposed video colorization as a self-supervised learning problem, resulting in a rich representation that can be used for video segmentation and unlabelled visual region tracking, without extra fine-tuning.\nUnlike the image-based colorization, here the task is to copy colors from a normal reference frame in color to another target frame in grayscale by leveraging the natural temporal coherency of colors across video frames (thus these two frames shouldn’t be too far apart in time). In order to copy colors consistently, the model is designed to learn to keep track of correlated pixels in different frames.\nFig. 16. Video colorization by copying colors from a reference frame to target frames in grayscale. (Image source: Vondrick et al. 2018) The idea is quite simple and smart. Let $c_i$ be the true color of the $i-th$ pixel in the reference frame and $c_j$ be the color of $j$-th pixel in the target frame. The predicted color of $j$-th color in the target $\\hat{c}_j$ is a weighted sum of colors of all the pixels in reference, where the weighting term measures the similarity:\n $$ \\hat{c}_j = \\sum_i A_{ij} c_i \\text{ where } A_{ij} = \\frac{\\exp(f_i f_j)}{\\sum_{i'} \\exp(f_{i'} f_j)} $$ where $f$ are learned embeddings for corresponding pixels; $i’$ indexes all the pixels in the reference frame. The weighting term implements an attention-based pointing mechanism, similar to matching network and pointer network. As the full similarity matrix could be really large, both frames are downsampled. The categorical cross-entropy loss between $c_j$ and $\\hat{c}_j$ is used with quantized colors, just like in Zhang et al. 2016.\nBased on how the reference frame are marked, the model can be used to complete several color-based downstream tasks such as tracking segmentation or human pose in time. No fine-tuning is needed. See Fig. 15.\nFig. 17. Use video colorization to track object segmentation and human pose in time. (Image source: Vondrick et al. (2018)) A couple common observations:\n Combining multiple pretext tasks improves performance; Deeper networks improve the quality of representation; Supervised learning baselines still beat all of them by far. Control-Based When running a RL policy in the real world, such as controlling a physical robot on visual inputs, it is non-trivial to properly track states, obtain reward signals or determine whether a goal is achieved for real. The visual data has a lot of noise that is irrelevant to the true state and thus the equivalence of states cannot be inferred from pixel-level comparison. Self-supervised representation learning has shown great potential in learning useful state embedding that can be used directly as input to a control policy.\nAll the cases discussed in this section are in robotic learning, mainly for state representation from multiple camera views and goal representation.\nMulti-View Metric Learning The concept of metric learning has been mentioned multiple times in the previous sections. A common setting is: Given a triple of samples, (anchor $s_a$, positive sample $s_p$, negative sample $s_n$), the learned representation embedding $\\phi(s)$ fulfills that $s_a$ stays close to $s_p$ but far away from $s_n$ in the latent space.\nGrasp2Vec (Jang \u0026amp; Devin et al., 2018) aims to learn an object-centric vision representation in the robot grasping task from free, unlabelled grasping activities. By object-centric, it means that, irrespective of how the environment or the robot looks like, if two images contain similar items, they should be mapped to similar representation; otherwise the embeddings should be far apart.\nFig. 18. A conceptual illustration of how grasp2vec learns an object-centric state embedding. (Image source: Jang \u0026 Devin et al., 2018) The grasping system can tell whether it moves an object but cannot tell which object it is. Cameras are set up to take images of the entire scene and the grasped object. During early training, the grasp robot is executed to grasp any object $o$ at random, producing a triple of images, $(s_\\text{pre}, s_\\text{post}, o)$:\n $o$ is an image of the grasped object held up to the camera; $s_\\text{pre}$ is an image of the scene before grasping, with the object $o$ in the tray; $s_\\text{post}$ is an image of the same scene after grasping, without the object $o$ in the tray. To learn object-centric representation, we expect the difference between embeddings of $s_\\text{pre}$ and $s_\\text{post}$ to capture the removed object $o$. The idea is quite interesting and similar to relationships that have been observed in word embedding, e.g. distance(\u0026ldquo;king\u0026rdquo;, \u0026ldquo;queen\u0026rdquo;) ≈ distance(\u0026ldquo;man\u0026rdquo;, \u0026ldquo;woman\u0026rdquo;).\nLet $\\phi_s$ and $\\phi_o$ be the embedding functions for the scene and the object respectively. The model learns the representation by minimizing the distance between $\\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post})$ and $\\phi_o(o)$ using n-pair loss:\n $$ \\begin{aligned} \\mathcal{L}_\\text{grasp2vec} \u0026= \\text{NPair}(\\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post}), \\phi_o(o)) + \\text{NPair}(\\phi_o(o), \\phi_s(s_\\text{pre}) - \\phi_s(s_\\text{post})) \\\\ \\text{where }\\text{NPair}(a, p) \u0026= \\sum_{iwhere $B$ refers to a batch of (anchor, positive) sample pairs.\nWhen framing representation learning as metric learning, n-pair loss is a common choice. Rather than processing explicit a triple of (anchor, positive, negative) samples, the n-pairs loss treats all other positive instances in one mini-batch across pairs as negatives.\nThe embedding function $\\phi_o$ works great for presenting a goal $g$ with an image. The reward function that quantifies how close the actually grasped object $o$ is close to the goal is defined as $r = \\phi_o(g) \\cdot \\phi_o(o)$. Note that computing rewards only relies on the learned latent space and doesn\u0026rsquo;t involve ground truth positions, so it can be used for training on real robots.\nFig. 19. Localization results of grasp2vec embedding. The heatmap of localizing a goal object in a pre-grasping scene is defined as $\\phi\\_o(o)^\\top \\phi\\_{s, \\text{spatial}} (s\\_\\text{pre})$, where $\\phi\\_{s, \\text{spatial}}$ is the output of the last resnet block after ReLU. The fourth column is a failure case and the last three columns take real images as goals. (Image source: Jang \u0026 Devin et al., 2018) Other than the embedding-similarity-based reward function, there are a few other tricks for training the RL policy in the grasp2vec framework:\n Posthoc labeling: Augment the dataset by labeling a randomly grasped object as a correct goal, like HER (Hindsight Experience Replay; Andrychowicz, et al., 2017). Auxiliary goal augmentation: Augment the replay buffer even further by relabeling transitions with unachieved goals; precisely, in each iteration, two goals are sampled $(g, g')$ and both are used to add new transitions into replay buffer. TCN (Time-Contrastive Networks; Sermanet, et al. 2018) learn from multi-camera view videos with the intuition that different viewpoints at the same timestep of the same scene should share the same embedding (like in FaceNet) while embedding should vary in time, even of the same camera viewpoint. Therefore embedding captures the semantic meaning of the underlying state rather than visual similarity. The TCN embedding is trained with triplet loss.\nThe training data is collected by taking videos of the same scene simultaneously but from different angles. All the videos are unlabelled.\nFig. 20. An illustration of time-contrastive approach for learning state embedding. The blue frames selected from two camera views at the same timestep are anchor and positive samples, while the red frame at a different timestep is the negative sample. TCN embedding extracts visual features that are invariant to camera configurations. It can be used to construct a reward function for imitation learning based on the euclidean distance between the demo video and the observations in the latent space.\nA further improvement over TCN is to learn embedding over multiple frames jointly rather than a single frame, resulting in mfTCN (Multi-frame Time-Contrastive Networks; Dwibedi et al., 2019). Given a set of videos from several synchronized camera viewpoints, $v_1, v_2, \\dots, v_k$, the frame at time $t$ and the previous $n-1$ frames selected with stride $s$ in each video are aggregated and mapped into one embedding vector, resulting in a lookback window of size $(n−1) \\times s + 1$. Each frame first goes through a CNN to extract low-level features and then we use 3D temporal convolutions to aggregate frames in time. The model is trained with n-pairs loss.\nFig. 21. The sampling process for training mfTCN. (Image source: Dwibedi et al., 2019) The training data is sampled as follows:\n First we construct two pairs of video clips. Each pair contains two clips from different camera views but with synchronized timesteps. These two sets of videos should be far apart in time. Sample a fixed number of frames from each video clip in the same pair simultaneously with the same stride. Frames with the same timesteps are trained as positive samples in the n-pair loss, while frames across pairs are negative samples. mfTCN embedding can capture the position and velocity of objects in the scene (e.g. in cartpole) and can also be used as inputs for policy.\nAutonomous Goal Generation RIG (Reinforcement learning with Imagined Goals; Nair et al., 2018) described a way to train a goal-conditioned policy with unsupervised representation learning. A policy learns from self-supervised practice by first imagining \u0026ldquo;fake\u0026rdquo; goals and then trying to achieve them.\nFig. 22. The workflow of RIG. (Image source: Nair et al., 2018) The task is to control a robot arm to push a small puck on a table to a desired position. The desired position, or the goal, is present in an image. During training, it learns latent embedding of both state $s$ and goal $g$ through $\\beta$-VAE encoder and the control policy operates entirely in the latent space.\nLet’s say a $\\beta$-VAE has an encoder $q_\\phi$ mapping input states to latent variable $z$ which is modeled by a Gaussian distribution and a decoder $p_\\psi$ mapping $z$ back to the states. The state encoder in RIG is set to be the mean of $\\beta$-VAE encoder.\n $$ \\begin{aligned} z \u0026\\sim q_\\phi(z \\vert s) = \\mathcal{N}(z; \\mu_\\phi(s), \\sigma^2_\\phi(s)) \\\\ \\mathcal{L}_{\\beta\\text{-VAE}} \u0026= - \\mathbb{E}_{z \\sim q_\\phi(z \\vert s)} [\\log p_\\psi (s \\vert z)] + \\beta D_\\text{KL}(q_\\phi(z \\vert s) \\| p_\\psi(s)) \\\\ e(s) \u0026\\triangleq \\mu_\\phi(s) \\end{aligned} $$ The reward is the Euclidean distance between state and goal embedding vectors: $r(s, g) = -|e(s) - e(g)|$. Similar to grasp2vec, RIG applies data augmentation as well by latent goal relabeling: precisely half of the goals are generated from the prior at random and the other half are selected using HER. Also same as grasp2vec, rewards do not depend on any ground truth states but only the learned state encoding, so it can be used for training on real robots.\nFig. 23. The algorithm of RIG. (Image source: Nair et al., 2018) The problem with RIG is a lack of object variations in the imagined goal pictures. If $\\beta$-VAE is only trained with a black puck, it would not be able to create a goal with other objects like blocks of different shapes and colors. A follow-up improvement replaces $\\beta$-VAE with a CC-VAE (Context-Conditioned VAE; Nair, et al., 2019), inspired by CVAE (Conditional VAE; Sohn, Lee \u0026amp; Yan, 2015), for goal generation.\nFig. 24. The workflow of context-conditioned RIG. (Image source: Nair, et al., 2019). A CVAE conditions on a context variable $c$. It trains an encoder $q_\\phi(z \\vert s, c)$ and a decoder $p_\\psi (s \\vert z, c)$ and note that both have access to $c$. The CVAE loss penalizes information passing from the input state $s$ through an information bottleneck but allows for unrestricted information flow from $c$ to both encoder and decoder.\n $$ \\mathcal{L}_\\text{CVAE} = - \\mathbb{E}_{z \\sim q_\\phi(z \\vert s,c)} [\\log p_\\psi (s \\vert z, c)] + \\beta D_\\text{KL}(q_\\phi(z \\vert s, c) \\| p_\\psi(s)) $$ To create plausible goals, CC-VAE conditions on a starting state $s_0$ so that the generated goal presents a consistent type of object as in $s_0$. This goal consistency is necessary; e.g. if the current scene contains a red puck but the goal has a blue block, it would confuse the policy.\nOther than the state encoder $e(s) \\triangleq \\mu_\\phi(s)$, CC-VAE trains a second convolutional encoder $e_0(.)$ to translate the starting state $s_0$ into a compact context representation $c = e_0(s_0)$. Two encoders, $e(.)$ and $e_0(.)$, are intentionally different without shared weights, as they are expected to encode different factors of image variation. In addition to the loss function of CVAE, CC-VAE adds an extra term to learn to reconstruct $c$ back to $s_0$, $\\hat{s}_0 = d_0(c)$.\n $$ \\mathcal{L}_\\text{CC-VAE} = \\mathcal{L}_\\text{CVAE} + \\log p(s_0\\vert c) $$ Fig. 25. Examples of imagined goals generated by CVAE that conditions on the context image (the first row), while VAE fails to capture the object consistency. (Image source: Nair, et al., 2019). Bisimulation Task-agnostic representation (e.g. a model that intends to represent all the dynamics in the system) may distract the RL algorithms as irrelevant information is also presented. For example, if we just train an auto-encoder to reconstruct the input image, there is no guarantee that the entire learned representation will be useful for RL. Therefore, we need to move away from reconstruction-based representation learning if we only want to learn information relevant to control, as irrelevant details are still important for reconstruction.\nRepresentation learning for control based on bisimulation does not depend on reconstruction, but aims to group states based on their behavioral similarity in MDP.\nBisimulation (Givan et al. 2003) refers to an equivalence relation between two states with similar long-term behavior. Bisimulation metrics quantify such relation so that we can aggregate states to compress a high-dimensional state space into a smaller one for more efficient computation. The bisimulation distance between two states corresponds to how behaviorally different these two states are.\nGiven a MDP $\\mathcal{M} = \\langle \\mathcal{S}, \\mathcal{A}, \\mathcal{P}, \\mathcal{R}, \\gamma \\rangle$ and a bisimulation relation $B$, two states that are equal under relation $B$ (i.e. $s_i B s_j$) should have the same immediate reward for all actions and the same transition probabilities over the next bisimilar states:\n $$ \\begin{aligned} \\mathcal{R}(s_i, a) \u0026= \\mathcal{R}(s_j, a) \\; \\forall a \\in \\mathcal{A} \\\\ \\mathcal{P}(G \\vert s_i, a) \u0026= \\mathcal{P}(G \\vert s_j, a) \\; \\forall a \\in \\mathcal{A} \\; \\forall G \\in \\mathcal{S}_B \\end{aligned} $$ where $\\mathcal{S}_B$ is a partition of the state space under the relation $B$.\nNote that $=$ is always a bisimulation relation. The most interesting one is the maximal bisimulation relation $\\sim$, which defines a partition $\\mathcal{S}_\\sim$ with fewest groups of states.\nFig. 26. DeepMDP learns a latent space model by minimizing two losses on a reward model and a dynamics model. (Image source: Gelada, et al. 2019) With a goal similar to bisimulation metric, DeepMDP (Gelada, et al. 2019) simplifies high-dimensional observations in RL tasks and learns a latent space model via minimizing two losses:\n prediction of rewards and prediction of the distribution over next latent states. $$ \\begin{aligned} \\mathcal{L}_{\\bar{\\mathcal{R}}}(s, a) = \\vert \\mathcal{R}(s, a) - \\bar{\\mathcal{R}}(\\phi(s), a) \\vert \\\\ \\mathcal{L}_{\\bar{\\mathcal{P}}}(s, a) = D(\\phi \\mathcal{P}(s, a), \\bar{\\mathcal{P}}(. \\vert \\phi(s), a)) \\end{aligned} $$ where $\\phi(s)$ is the embedding of state $s$; symbols with bar are functions (reward function $R$ and transition function $P$) in the same MDP but running in the latent low-dimensional observation space. Here the embedding representation $\\phi$ can be connected to bisimulation metrics, as the bisimulation distance is proved to be upper-bounded by the L2 distance in the latent space.\nThe function $D$ quantifies the distance between two probability distributions and should be chosen carefully. DeepMDP focuses on Wasserstein-1 metric (also known as “earth-mover distance”). The Wasserstein-1 distance between distributions $P$ and $Q$ on a metric space $(M, d)$ (i.e., $d: M \\times M \\to \\mathbb{R}$) is:\n $$ W_d (P, Q) = \\inf_{\\lambda \\in \\Pi(P, Q)} \\int_{M \\times M} d(x, y) \\lambda(x, y) \\; \\mathrm{d}x \\mathrm{d}y $$ where $\\Pi(P, Q)$ is the set of all couplings of $P$ and $Q$. $d(x, y)$ defines the cost of moving a particle from point $x$ to point $y$.\nThe Wasserstein metric has a dual form according to the Monge-Kantorovich duality:\n $$ W_d (P, Q) = \\sup_{f \\in \\mathcal{F}_d} \\vert \\mathbb{E}_{x \\sim P} f(x) - \\mathbb{E}_{y \\sim Q} f(y) \\vert $$ where $\\mathcal{F}_d$ is the set of 1-Lipschitz functions under the metric $d$ - $\\mathcal{F}_d = \\{ f: \\vert f(x) - f(y) \\vert \\leq d(x, y) \\}$.\nDeepMDP generalizes the model to the Norm Maximum Mean Discrepancy (Norm-MMD) metrics to improve the tightness of the bounds of its deep value function and, at the same time, to save computation (Wasserstein is expensive computationally). In their experiments, they found the model architecture of the transition prediction model can have a big impact on the performance. Adding these DeepMDP losses as auxiliary losses when training model-free RL agents leads to good improvement on most of the Atari games.\nDeep Bisimulatioin for Control (short for DBC; Zhang et al. 2020) learns the latent representation of observations that are good for control in RL tasks, without domain knowledge or pixel-level reconstruction.\nFig. 27. The Deep Bisimulation for Control algorithm learns a bisimulation metric representation via learning a reward model and a dynamics model. The model architecture is a siamese network. (Image source: Zhang et al. 2020) Similar to DeepMDP, DBC models the dynamics by learning a reward model and a transition model. Both models operate in the latent space, $\\phi(s)$. The optimization of embedding $\\phi$ depends on one important conclusion from Ferns, et al. 2004 (Theorem 4.5) and Ferns, et al 2011 (Theorem 2.6):\n Given $c \\in (0, 1)$ a discounting factor, $\\pi$ a policy that is being improved continuously, and $M$ the space of bounded pseudometric on the state space $\\mathcal{S}$, we can define $\\mathcal{F}: M \\mapsto M$:\n $$ \\mathcal{F}(d; \\pi)(s_i, s_j) = (1-c) \\vert \\mathcal{R}_{s_i}^\\pi - \\mathcal{R}_{s_j}^\\pi \\vert + c W_d (\\mathcal{P}_{s_i}^\\pi, \\mathcal{P}_{s_j}^\\pi) $$ Then, $\\mathcal{F}$ has a unique fixed point $\\tilde{d}$ which is a $\\pi^*$-bisimulation metric and $\\tilde{d}(s_i, s_j) = 0 \\iff s_i \\sim s_j$.\n [The proof is not trivial. I may or may not add it in the future _(:3」∠)_ \u0026hellip;]\nGiven batches of observations pairs, the training loss for $\\phi$, $J(\\phi)$, minimizes the mean square error between the on-policy bisimulation metric and Euclidean distance in the latent space:\n$$ J(\\phi) = \\Big( \\|\\phi(s_i) - \\phi(s_j)\\|_1 - \\vert \\hat{\\mathcal{R}}(\\bar{\\phi}(s_i)) - \\hat{\\mathcal{R}}(\\bar{\\phi}(s_j)) \\vert - \\gamma W_2(\\hat{\\mathcal{P}}(\\cdot \\vert \\bar{\\phi}(s_i), \\bar{\\pi}(\\bar{\\phi}(s_i))), \\hat{\\mathcal{P}}(\\cdot \\vert \\bar{\\phi}(s_j), \\bar{\\pi}(\\bar{\\phi}(s_j)))) \\Big)^2 $$ where $\\bar{\\phi}(s)$ denotes $\\phi(s)$ with stop gradient and $\\bar{\\pi}$ is the mean policy output. The learned reward model $\\hat{\\mathcal{R}}$ is deterministic and the learned forward dynamics model $\\hat{\\mathcal{P}}$ outputs a Gaussian distribution.\nDBC is based on SAC but operates on the latent space:\nFig. 28. The algorithm of Deep Bisimulation for Control. (Image source: Zhang et al. 2020) Cited as:\n@article{weng2019selfsup, title = \u0026quot;Self-Supervised Representation Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-11-10-self-supervised/\u0026quot; } References [1] Alexey Dosovitskiy, et al. \u0026ldquo;Discriminative unsupervised feature learning with exemplar convolutional neural networks.\u0026quot; IEEE transactions on pattern analysis and machine intelligence 38.9 (2015): 1734-1747.\n[2] Spyros Gidaris, Praveer Singh \u0026amp; Nikos Komodakis. \u0026ldquo;Unsupervised Representation Learning by Predicting Image Rotations\u0026rdquo; ICLR 2018.\n[3] Carl Doersch, Abhinav Gupta, and Alexei A. Efros. \u0026ldquo;Unsupervised visual representation learning by context prediction.\u0026quot; ICCV. 2015.\n[4] Mehdi Noroozi \u0026amp; Paolo Favaro. \u0026ldquo;Unsupervised learning of visual representations by solving jigsaw puzzles.\u0026quot; ECCV, 2016.\n[5] Mehdi Noroozi, Hamed Pirsiavash, and Paolo Favaro. \u0026ldquo;Representation learning by learning to count.\u0026quot; ICCV. 2017.\n[6] Richard Zhang, Phillip Isola \u0026amp; Alexei A. Efros. \u0026ldquo;Colorful image colorization.\u0026quot; ECCV, 2016.\n[7] Pascal Vincent, et al. \u0026ldquo;Extracting and composing robust features with denoising autoencoders.\u0026quot; ICML, 2008.\n[8] Jeff Donahue, Philipp Krähenbühl, and Trevor Darrell. \u0026ldquo;Adversarial feature learning.\u0026quot; ICLR 2017.\n[9] Deepak Pathak, et al. \u0026ldquo;Context encoders: Feature learning by inpainting.\u0026quot; CVPR. 2016.\n[10] Richard Zhang, Phillip Isola, and Alexei A. Efros. \u0026ldquo;Split-brain autoencoders: Unsupervised learning by cross-channel prediction.\u0026quot; CVPR. 2017.\n[11] Xiaolong Wang \u0026amp; Abhinav Gupta. \u0026ldquo;Unsupervised Learning of Visual Representations using Videos.\u0026quot; ICCV. 2015.\n[12] Carl Vondrick, et al. \u0026ldquo;Tracking Emerges by Colorizing Videos\u0026rdquo; ECCV. 2018.\n[13] Ishan Misra, C. Lawrence Zitnick, and Martial Hebert. \u0026ldquo;Shuffle and learn: unsupervised learning using temporal order verification.\u0026quot; ECCV. 2016.\n[14] Basura Fernando, et al. \u0026ldquo;Self-Supervised Video Representation Learning With Odd-One-Out Networks\u0026rdquo; CVPR. 2017.\n[15] Donglai Wei, et al. \u0026ldquo;Learning and Using the Arrow of Time\u0026rdquo; CVPR. 2018.\n[16] Florian Schroff, Dmitry Kalenichenko and James Philbin. \u0026ldquo;FaceNet: A Unified Embedding for Face Recognition and Clustering\u0026rdquo; CVPR. 2015.\n[17] Pierre Sermanet, et al. \u0026ldquo;Time-Contrastive Networks: Self-Supervised Learning from Video\u0026rdquo; CVPR. 2018.\n[18] Debidatta Dwibedi, et al. \u0026ldquo;Learning actionable representations from visual observations.\u0026quot; IROS. 2018.\n[19] Eric Jang \u0026amp; Coline Devin, et al. \u0026ldquo;Grasp2Vec: Learning Object Representations from Self-Supervised Grasping\u0026rdquo; CoRL. 2018.\n[20] Ashvin Nair, et al. \u0026ldquo;Visual reinforcement learning with imagined goals\u0026rdquo; NeuriPS. 2018.\n[21] Ashvin Nair, et al. \u0026ldquo;Contextual imagined goals for self-supervised robotic learning\u0026rdquo; CoRL. 2019.\n[22] Aaron van den Oord, Yazhe Li \u0026amp; Oriol Vinyals. \u0026ldquo;Representation Learning with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1807.03748, 2018.\n[23] Olivier J. Henaff, et al. \u0026ldquo;Data-Efficient Image Recognition with Contrastive Predictive Coding\u0026rdquo; arXiv preprint arXiv:1905.09272, 2019.\n[24] Kaiming He, et al. \u0026ldquo;Momentum Contrast for Unsupervised Visual Representation Learning.\u0026quot; CVPR 2020.\n[25] Zhirong Wu, et al. \u0026ldquo;Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination.\u0026quot; CVPR 2018.\n[26] Ting Chen, et al. \u0026ldquo;A Simple Framework for Contrastive Learning of Visual Representations.\u0026quot; arXiv preprint arXiv:2002.05709, 2020.\n[27] Aravind Srinivas, Michael Laskin \u0026amp; Pieter Abbeel \u0026ldquo;CURL: Contrastive Unsupervised Representations for Reinforcement Learning.\u0026quot; arXiv preprint arXiv:2004.04136, 2020.\n[28] Carles Gelada, et al. “DeepMDP: Learning Continuous Latent Space Models for Representation Learning” ICML 2019.\n[29] Amy Zhang, et al. “Learning Invariant Representations for Reinforcement Learning without Reconstruction” arXiv preprint arXiv:2006.10742, 2020.\n[30] Xinlei Chen, et al. “Improved Baselines with Momentum Contrastive Learning” arXiv preprint arXiv:2003.04297, 2020.\n[31] Jean-Bastien Grill, et al. “Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning” arXiv preprint arXiv:2006.07733, 2020.\n[32] Abe Fetterman \u0026amp; Josh Albrecht. “Understanding self-supervised and contrastive learning with Bootstrap Your Own Latent (BYOL)” Untitled blog. Aug 24, 2020.\n","permalink":"https://lilianweng.github.io/posts/2019-11-10-self-supervised/","summary":"[Updated on 2020-01-09: add a new section on Contrastive Predictive Coding]. [Updated on 2020-04-13: add a \u0026ldquo;Momentum Contrast\u0026rdquo; section on MoCo, SimCLR and CURL.] [Updated on 2020-07-08: add a \u0026ldquo;Bisimulation\u0026rdquo; section on DeepMDP and DBC.] [Updated on 2020-09-12: add MoCo V2 and BYOL in the \u0026ldquo;Momentum Contrast\u0026rdquo; section.] [Updated on 2021-05-31: remove section on \u0026ldquo;Momentum Contrast\u0026rdquo; and add a pointer to a full post on \u0026ldquo;Contrastive Representation Learning\u0026rdquo;]","title":"Self-Supervised Representation Learning"},{"content":"Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. With black-box optimization algorithms, you can evaluate a target function $f(x): \\mathbb{R}^n \\to \\mathbb{R}$, even when you don\u0026rsquo;t know the precise analytic form of $f(x)$ and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include Simulated Annealing, Hill Climbing and Nelder-Mead method.\nEvolution Strategies (ES) is one type of black-box optimization algorithms, born in the family of Evolutionary Algorithms (EA). In this post, I would dive into a couple of classic ES methods and introduce a few applications of how ES can play a role in deep reinforcement learning.\nWhat are Evolution Strategies? Evolution strategies (ES) belong to the big family of evolutionary algorithms. The optimization targets of ES are vectors of real numbers, $x \\in \\mathbb{R}^n$.\nEvolutionary algorithms refer to a division of population-based optimization algorithms inspired by natural selection. Natural selection believes that individuals with traits beneficial to their survival can live through generations and pass down the good characteristics to the next generation. Evolution happens by the selection process gradually and the population grows better adapted to the environment.\nFig. 1. How natural selection works. (Image source: Khan Academy: Darwin, evolution, \u0026 natural selection) Evolutionary algorithms can be summarized in the following format as a general optimization solution:\nLet\u0026rsquo;s say we want to optimize a function $f(x)$ and we are not able to compute gradients directly. But we still can evaluate $f(x)$ given any $x$ and the result is deterministic. Our belief in the probability distribution over $x$ as a good solution to $f(x)$ optimization is $p_\\theta(x)$, parameterized by $\\theta$. The goal is to find an optimal configuration of $\\theta$.\n Here given a fixed format of distribution (i.e. Gaussian), the parameter $\\theta$ carries the knowledge about the best solutions and is being iteratively updated across generations.\n Starting with an initial value of $\\theta$, we can continuously update $\\theta$ by looping three steps as follows:\n Generate a population of samples $D = \\{(x_i, f(x_i)\\}$ where $x_i \\sim p_\\theta(x)$. Evaluate the \u0026ldquo;fitness\u0026rdquo; of samples in $D$. Select the best subset of individuals and use them to update $\\theta$, generally based on fitness or rank. In Genetic Algorithms (GA), another popular subcategory of EA, $x$ is a sequence of binary codes, $x \\in \\{0, 1\\}^n$. While in ES, $x$ is just a vector of real numbers, $x \\in \\mathbb{R}^n$.\nSimple Gaussian Evolution Strategies This is the most basic and canonical version of evolution strategies. It models $p_\\theta(x)$ as a $n$-dimensional isotropic Gaussian distribution, in which $\\theta$ only tracks the mean $\\mu$ and standard deviation $\\sigma$.\n $$ \\theta = (\\mu, \\sigma),\\;p_\\theta(x) \\sim \\mathcal{N}(\\mathbf{\\mu}, \\sigma^2 I) = \\mu + \\sigma \\mathcal{N}(0, I) $$ The process of Simple-Gaussian-ES, given $x \\in \\mathcal{R}^n$:\n Initialize $\\theta = \\theta^{(0)}$ and the generation counter $t=0$ Generate the offspring population of size $\\Lambda$ by sampling from the Gaussian distribution:$D^{(t+1)}=\\{ x^{(t+1)}_i \\mid x^{(t+1)}_i = \\mu^{(t)} + \\sigma^{(t)} y^{(t+1)}_i \\text{ where } y^{(t+1)}_i \\sim \\mathcal{N}(x \\vert 0, \\mathbf{I}),;i = 1, \\dots, \\Lambda\\}$. Select a top subset of $\\lambda$ samples with optimal $f(x_i)$ and this subset is called elite set. Without loss of generality, we may consider the first $k$ samples in $D^{(t+1)}$ to belong to the elite group \u0026mdash; Let\u0026rsquo;s label them as $$ D^{(t+1)}\\_\\text{elite} = \\\\{x^{(t+1)}\\_i \\mid x^{(t+1)}\\_i \\in D^{(t+1)}, i=1,\\dots, \\lambda, \\lambda\\leq \\Lambda\\\\} $$ Then we estimate the new mean and std for the next generation using the elite set: $$ \\begin{aligned} \\mu^{(t+1)} \u0026= \\text{avg}(D^{(t+1)}_\\text{elite}) = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda x_i^{(t+1)} \\\\ {\\sigma^{(t+1)}}^2 \u0026= \\text{var}(D^{(t+1)}_\\text{elite}) = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)} -\\mu^{(t)})^2 \\end{aligned} $$ Repeat steps (2)-(4) until the result is good enough ✌️ Covariance Matrix Adaptation Evolution Strategies (CMA-ES) The standard deviation $\\sigma$ accounts for the level of exploration: the larger $\\sigma$ the bigger search space we can sample our offspring population. In vanilla ES, $\\sigma^{(t+1)}$ is highly correlated with $\\sigma^{(t)}$, so the algorithm is not able to rapidly adjust the exploration space when needed (i.e. when the confidence level changes).\nCMA-ES, short for \u0026ldquo;Covariance Matrix Adaptation Evolution Strategy\u0026rdquo;, fixes the problem by tracking pairwise dependencies between the samples in the distribution with a covariance matrix $C$. The new distribution parameter becomes:\n $$ \\theta = (\\mu, \\sigma, C),\\; p_\\theta(x) \\sim \\mathcal{N}(\\mu, \\sigma^2 C) \\sim \\mu + \\sigma \\mathcal{N}(0, C) $$ where $\\sigma$ controls for the overall scale of the distribution, often known as step size.\nBefore we dig into how the parameters are updated in CMA-ES, it is better to review how the covariance matrix works in the multivariate Gaussian distribution first. As a real symmetric matrix, the covariance matrix $C$ has the following nice features (See proof \u0026amp; proof):\n It is always diagonalizable. Always positive semi-definite. All of its eigenvalues are real non-negative numbers. All of its eigenvectors are orthogonal. There is an orthonormal basis of $\\mathbb{R}^n$ consisting of its eigenvectors. Let the matrix $C$ have an orthonormal basis of eigenvectors $B = [b_1, \\dots, b_n]$, with corresponding eigenvalues $\\lambda_1^2, \\dots, \\lambda_n^2$. Let $D=\\text{diag}(\\lambda_1, \\dots, \\lambda_n)$.\n $$ C = B^\\top D^2 B = \\begin{bmatrix} \\mid \u0026 \\mid \u0026 \u0026 \\mid \\\\ b_1 \u0026 b_2 \u0026 \\dots \u0026 b_n\\\\ \\mid \u0026 \\mid \u0026 \u0026 \\mid \\\\ \\end{bmatrix} \\begin{bmatrix} \\lambda_1^2 \u0026 0 \u0026 \\dots \u0026 0 \\\\ 0 \u0026 \\lambda_2^2 \u0026 \\dots \u0026 0 \\\\ \\vdots \u0026 \\dots \u0026 \\ddots \u0026 \\vdots \\\\ 0 \u0026 \\dots \u0026 0 \u0026 \\lambda_n^2 \\end{bmatrix} \\begin{bmatrix} - \u0026 b_1 \u0026 - \\\\ - \u0026 b_2 \u0026 - \\\\ \u0026 \\dots \u0026 \\\\ - \u0026 b_n \u0026 - \\\\ \\end{bmatrix} $$ The square root of $C$ is:\n $$ C^{\\frac{1}{2}} = B^\\top D B $$ Symbol Meaning $x_i^{(t)} \\in \\mathbb{R}^n$ the $i$-th samples at the generation (t) $y_i^{(t)} \\in \\mathbb{R}^n$ $x_i^{(t)} = \\mu^{(t-1)} + \\sigma^{(t-1)} y_i^{(t)} $ $\\mu^{(t)}$ mean of the generation (t) $\\sigma^{(t)}$ step size $C^{(t)}$ covariance matrix $B^{(t)}$ a matrix of $C$\u0026rsquo;s eigenvectors as row vectors $D^{(t)}$ a diagonal matrix with $C$\u0026rsquo;s eigenvalues on the diagnose. $p_\\sigma^{(t)}$ evaluation path for $\\sigma$ at the generation (t) $p_c^{(t)}$ evaluation path for $C$ at the generation (t) $\\alpha_\\mu$ learning rate for $\\mu$\u0026rsquo;s update $\\alpha_\\sigma$ learning rate for $p_\\sigma$ $d_\\sigma$ damping factor for $\\sigma$\u0026rsquo;s update $\\alpha_{cp}$ learning rate for $p_c$ $\\alpha_{c\\lambda}$ learning rate for $C$\u0026rsquo;s rank-min(λ, n) update $\\alpha_{c1}$ learning rate for $C$\u0026rsquo;s rank-1 update Updating the Mean $$ \\mu^{(t+1)} = \\mu^{(t)} + \\alpha_\\mu \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)} - \\mu^{(t)}) $$ CMA-ES has a learning rate $\\alpha_\\mu \\leq 1$ to control how fast the mean $\\mu$ should be updated. Usually it is set to 1 and thus the equation becomes the same as in vanilla ES, $\\mu^{(t+1)} = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda (x_i^{(t+1)}$.\nControlling the Step Size The sampling process can be decoupled from the mean and standard deviation:\n $$ x^{(t+1)}_i = \\mu^{(t)} + \\sigma^{(t)} y^{(t+1)}_i \\text{, where } y^{(t+1)}_i = \\frac{x_i^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, C) $$ The parameter $\\sigma$ controls the overall scale of the distribution. It is separated from the covariance matrix so that we can change steps faster than the full covariance. A larger step size leads to faster parameter update. In order to evaluate whether the current step size is proper, CMA-ES constructs an evolution path $p_\\sigma$ by summing up a consecutive sequence of moving steps, $\\frac{1}{\\lambda}\\sum_{i}^\\lambda y_i^{(j)}, j=1, \\dots, t$. By comparing this path length with its expected length under random selection (meaning single steps are uncorrelated), we are able to adjust $\\sigma$ accordingly (See Fig. 2).\nFig. 2. Three scenarios of how single steps are correlated in different ways and their impacts on step size update. (Image source: additional annotations on Fig 5 in CMA-ES tutorial paper) Each time the evolution path is updated with the average of moving step $y_i$ in the same generation.\n $$ \\begin{aligned} \u0026\\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y_i^{(t+1)} = \\frac{1}{\\lambda} \\frac{\\sum_{i=1}^\\lambda x_i^{(t+1)} - \\lambda \\mu^{(t)}}{\\sigma^{(t)}} = \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026\\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y_i^{(t+1)} \\sim \\frac{1}{\\lambda}\\mathcal{N}(0, \\lambda C^{(t)}) \\sim \\frac{1}{\\sqrt{\\lambda}}{C^{(t)}}^{\\frac{1}{2}}\\mathcal{N}(0, I) \\\\ \u0026\\text{Thus } \\sqrt{\\lambda}\\;{C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, I) \\end{aligned} $$ By multiplying with $C^{-\\frac{1}{2}}$, the evolution path is transformed to be independent of its direction. The term ${C^{(t)}}^{-\\frac{1}{2}} = {B^{(t)}}^\\top {D^{(t)}}^{-\\frac{1}{2}} {B^{(t)}}$ transformation works as follows:\n ${B^{(t)}}$ contains row vectors of $C$\u0026rsquo;s eigenvectors. It projects the original space onto the perpendicular principal axes. Then ${D^{(t)}}^{-\\frac{1}{2}} = \\text{diag}(\\frac{1}{\\lambda_1}, \\dots, \\frac{1}{\\lambda_n})$ scales the length of principal axes to be equal. ${B^{(t)}}^\\top$ transforms the space back to the original coordinate system. In order to assign higher weights to recent generations, we use polyak averaging to update the evolution path with learning rate $\\alpha_\\sigma$. Meanwhile, the weights are balanced so that $p_\\sigma$ is conjugate, $\\sim \\mathcal{N}(0, I)$ both before and after one update.\n $$ \\begin{aligned} p_\\sigma^{(t+1)} \u0026 = (1 - \\alpha_\\sigma) p_\\sigma^{(t)} + \\sqrt{1 - (1 - \\alpha_\\sigma)^2}\\;\\sqrt{\\lambda}\\; {C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026 = (1 - \\alpha_\\sigma) p_\\sigma^{(t)} + \\sqrt{c_\\sigma (2 - \\alpha_\\sigma)\\lambda}\\;{C^{(t)}}^{-\\frac{1}{2}} \\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\end{aligned} $$ The expected length of $p_\\sigma$ under random selection is $\\mathbb{E}|\\mathcal{N}(0,I)|$, that is the expectation of the L2-norm of a $\\mathcal{N}(0,I)$ random variable. Following the idea in Fig. 2, we adjust the step size according to the ratio of $|p_\\sigma^{(t+1)}| / \\mathbb{E}|\\mathcal{N}(0,I)|$:\n $$ \\begin{aligned} \\ln\\sigma^{(t+1)} \u0026= \\ln\\sigma^{(t)} + \\frac{\\alpha_\\sigma}{d_\\sigma} \\Big(\\frac{\\|p_\\sigma^{(t+1)}\\|}{\\mathbb{E}\\|\\mathcal{N}(0,I)\\|} - 1\\Big) \\\\ \\sigma^{(t+1)} \u0026= \\sigma^{(t)} \\exp\\Big(\\frac{\\alpha_\\sigma}{d_\\sigma} \\Big(\\frac{\\|p_\\sigma^{(t+1)}\\|}{\\mathbb{E}\\|\\mathcal{N}(0,I)\\|} - 1\\Big)\\Big) \\end{aligned} $$ where $d_\\sigma \\approx 1$ is a damping parameter, scaling how fast $\\ln\\sigma$ should be changed.\nAdapting the Covariance Matrix For the covariance matrix, it can be estimated from scratch using $y_i$ of elite samples (recall that $y_i \\sim \\mathcal{N}(0, C)$):\n $$ C_\\lambda^{(t+1)} = \\frac{1}{\\lambda}\\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top = \\frac{1}{\\lambda {\\sigma^{(t)}}^2} \\sum_{i=1}^\\lambda (x_i^{(t+1)} - \\mu^{(t)})(x_i^{(t+1)} - \\mu^{(t)})^\\top $$ The above estimation is only reliable when the selected population is large enough. However, we do want to run fast iteration with a small population of samples in each generation. That\u0026rsquo;s why CMA-ES invented a more reliable but also more complicated way to update $C$. It involves two independent routes,\n Rank-min(λ, n) update: uses the history of $\\{C_\\lambda\\}$, each estimated from scratch in one generation. Rank-one update: estimates the moving steps $y_i$ and the sign information from the history. The first route considers the estimation of $C$ from the entire history of $\\{C_\\lambda\\}$. For example, if we have experienced a large number of generations, $C^{(t+1)} \\approx \\text{avg}(C_\\lambda^{(i)}; i=1,\\dots,t)$ would be a good estimator. Similar to $p_\\sigma$, we also use polyak averaging with a learning rate to incorporate the history:\n $$ C^{(t+1)} = (1 - \\alpha_{c\\lambda}) C^{(t)} + \\alpha_{c\\lambda} C_\\lambda^{(t+1)} = (1 - \\alpha_{c\\lambda}) C^{(t)} + \\alpha_{c\\lambda} \\frac{1}{\\lambda} \\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top $$ A common choice for the learning rate is $\\alpha_{c\\lambda} \\approx \\min(1, \\lambda/n^2)$.\nThe second route tries to solve the issue that $y_i{y_i}^\\top = (-y_i)(-y_i)^\\top$ loses the sign information. Similar to how we adjust the step size $\\sigma$, an evolution path $p_c$ is used to track the sign information and it is constructed in a way that $p_c$ is conjugate, $\\sim \\mathcal{N}(0, C)$ both before and after a new generation.\nWe may consider $p_c$ as another way to compute $\\text{avg}_i(y_i)$ (notice that both $\\sim \\mathcal{N}(0, C)$) while the entire history is used and the sign information is maintained. Note that we\u0026rsquo;ve known $\\sqrt{k}\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\sim \\mathcal{N}(0, C)$ in the last section,\n $$ \\begin{aligned} p_c^{(t+1)} \u0026= (1-\\alpha_{cp}) p_c^{(t)} + \\sqrt{1 - (1-\\alpha_{cp})^2}\\;\\sqrt{\\lambda}\\;\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\\\ \u0026= (1-\\alpha_{cp}) p_c^{(t)} + \\sqrt{\\alpha_{cp}(2 - \\alpha_{cp})\\lambda}\\;\\frac{\\mu^{(t+1)} - \\mu^{(t)}}{\\sigma^{(t)}} \\end{aligned} $$ Then the covariance matrix is updated according to $p_c$:\n $$ C^{(t+1)} = (1-\\alpha_{c1}) C^{(t)} + \\alpha_{c1}\\;p_c^{(t+1)} {p_c^{(t+1)}}^\\top $$ The rank-one update approach is claimed to generate a significant improvement over the rank-min(λ, n)-update when $k$ is small, because the signs of moving steps and correlations between consecutive steps are all utilized and passed down through generations.\nEventually we combine two approaches together,\n $$ C^{(t+1)} = (1 - \\alpha_{c\\lambda} - \\alpha_{c1}) C^{(t)} + \\alpha_{c1}\\;\\underbrace{p_c^{(t+1)} {p_c^{(t+1)}}^\\top}_\\textrm{rank-one update} + \\alpha_{c\\lambda} \\underbrace{\\frac{1}{\\lambda} \\sum_{i=1}^\\lambda y^{(t+1)}_i {y^{(t+1)}_i}^\\top}_\\textrm{rank-min(lambda, n) update} $$ In all my examples above, each elite sample is considered to contribute an equal amount of weights, $1/\\lambda$. The process can be easily extended to the case where selected samples are assigned with different weights, $w_1, \\dots, w_\\lambda$, according to their performances. See more detail in tutorial.\nFig. 3. Illustration of how CMA-ES works on a 2D optimization problem (the lighter color the better). Black dots are samples in one generation. The samples are more spread out initially but when the model has higher confidence in finding a good solution in the late stage, the samples become very concentrated over the global optimum. (Image source: Wikipedia CMA-ES) Natural Evolution Strategies Natural Evolution Strategies (NES; Wierstra, et al, 2008) optimizes in a search distribution of parameters and moves the distribution in the direction of high fitness indicated by the natural gradient.\nNatural Gradients Given an objective function $\\mathcal{J}(\\theta)$ parameterized by $\\theta$, let\u0026rsquo;s say our goal is to find the optimal $\\theta$ to maximize the objective function value. A plain gradient finds the steepest direction within a small Euclidean distance from the current $\\theta$; the distance restriction is applied on the parameter space. In other words, we compute the plain gradient with respect to a small change of the absolute value of $\\theta$. The optimal step is:\n $$ d^{*} = \\operatorname*{argmax}_{\\|d\\| = \\epsilon} \\mathcal{J}(\\theta + d)\\text{, where }\\epsilon \\to 0 $$ Differently, natural gradient works with a probability distribution space parameterized by $\\theta$, $p_\\theta(x)$ (referred to as \u0026ldquo;search distribution\u0026rdquo; in NES paper). It looks for the steepest direction within a small step in the distribution space where the distance is measured by KL divergence. With this constraint we ensure that each update is moving along the distributional manifold with constant speed, without being slowed down by its curvature.\n $$ d^{*}_\\text{N} = \\operatorname*{argmax}_{\\text{KL}[p_\\theta \\| p_{\\theta+d}] = \\epsilon} \\mathcal{J}(\\theta + d) $$ Estimation using Fisher Information Matrix But, how to compute $\\text{KL}[p_\\theta | p_{\\theta+\\Delta\\theta}]$ precisely? By running Taylor expansion of $\\log p_{\\theta + d}$ at $\\theta$, we get:\n $$ \\begin{aligned} \u0026 \\text{KL}[p_\\theta \\| p_{\\theta+d}] \\\\ \u0026= \\mathbb{E}_{x \\sim p_\\theta} [\\log p_\\theta(x) - \\log p_{\\theta+d}(x)] \u0026 \\\\ \u0026\\approx \\mathbb{E}_{x \\sim p_\\theta} [ \\log p_\\theta(x) -( \\log p_{\\theta}(x) + \\nabla_\\theta \\log p_{\\theta}(x) d + \\frac{1}{2}d^\\top \\nabla^2_\\theta \\log p_{\\theta}(x) d)] \u0026 \\scriptstyle{\\text{; Taylor expand }\\log p_{\\theta+d}} \\\\ \u0026\\approx - \\mathbb{E}_x [\\nabla_\\theta \\log p_{\\theta}(x)] d - \\frac{1}{2}d^\\top \\mathbb{E}_x [\\nabla^2_\\theta \\log p_{\\theta}(x)] d \u0026 \\end{aligned} $$ where\n $$ \\begin{aligned} \\mathbb{E}_x [\\nabla_\\theta \\log p_{\\theta}] d \u0026= \\int_{x\\sim p_\\theta} p_\\theta(x) \\nabla_\\theta \\log p_\\theta(x) \u0026 \\\\ \u0026= \\int_{x\\sim p_\\theta} p_\\theta(x) \\frac{1}{p_\\theta(x)} \\nabla_\\theta p_\\theta(x) \u0026 \\\\ \u0026= \\nabla_\\theta \\Big( \\int_{x} p_\\theta(x) \\Big) \u0026 \\scriptstyle{\\textrm{; note that }p_\\theta(x)\\textrm{ is probability distribution.}} \\\\ \u0026= \\nabla_\\theta (1) = 0 \\end{aligned} $$ Finally we have,\n $$ \\text{KL}[p_\\theta \\| p_{\\theta+d}] = - \\frac{1}{2}d^\\top \\mathbf{F}_\\theta d \\text{, where }\\mathbf{F}_\\theta = \\mathbb{E}_x [(\\nabla_\\theta \\log p_{\\theta}) (\\nabla_\\theta \\log p_{\\theta})^\\top] $$ where $\\mathbf{F}_\\theta$ is called the Fisher Information Matrix and it is the covariance matrix of $\\nabla_\\theta \\log p_\\theta$ since $\\mathbb{E}[\\nabla_\\theta \\log p_\\theta] = 0$.\nThe solution to the following optimization problem:\n $$ \\max \\mathcal{J}(\\theta + d) \\approx \\max \\big( \\mathcal{J}(\\theta) + {\\nabla_\\theta\\mathcal{J}(\\theta)}^\\top d \\big)\\;\\text{ s.t. }\\text{KL}[p_\\theta \\| p_{\\theta+d}] - \\epsilon = 0 $$ can be found using a Lagrangian multiplier,\n $$ \\begin{aligned} \\mathcal{L}(\\theta, d, \\beta) \u0026= \\mathcal{J}(\\theta) + \\nabla_\\theta\\mathcal{J}(\\theta)^\\top d - \\beta (\\frac{1}{2}d^\\top \\mathbf{F}_\\theta d + \\epsilon) = 0 \\text{ s.t. } \\beta 0 \\\\ \\nabla_d \\mathcal{L}(\\theta, d, \\beta) \u0026= \\nabla_\\theta\\mathcal{J}(\\theta) - \\beta\\mathbf{F}_\\theta d = 0 \\\\ \\text{Thus } d_\\text{N}^* \u0026= \\nabla_\\theta^\\text{N} \\mathcal{J}(\\theta) = \\mathbf{F}_\\theta^{-1} \\nabla_\\theta\\mathcal{J}(\\theta) \\end{aligned} $$ where $d_\\text{N}^*$ only extracts the direction of the optimal moving step on $\\theta$, ignoring the scalar $\\beta^{-1}$.\nFig. 4. The natural gradient samples (black solid arrows) in the right are the plain gradient samples (black solid arrows) in the left multiplied by the inverse of their covariance. In this way, a gradient direction with high uncertainty (indicated by high covariance with other samples) are penalized with a small weight. The aggregated natural gradient (red dash arrow) is therefore more trustworthy than the natural gradient (green solid arrow). (Image source: additional annotations on Fig 2 in NES paper) NES Algorithm The fitness associated with one sample is labeled as $f(x)$ and the search distribution over $x$ is parameterized by $\\theta$. NES is expected to optimize the parameter $\\theta$ to achieve maximum expected fitness:\n $$ \\mathcal{J}(\\theta) = \\mathbb{E}_{x\\sim p_\\theta(x)} [f(x)] = \\int_x f(x) p_\\theta(x) dx $$ Using the same log-likelihood trick in REINFORCE:\n $$ \\begin{aligned} \\nabla_\\theta\\mathcal{J}(\\theta) \u0026= \\nabla_\\theta \\int_x f(x) p_\\theta(x) dx \\\\ \u0026= \\int_x f(x) \\frac{p_\\theta(x)}{p_\\theta(x)}\\nabla_\\theta p_\\theta(x) dx \\\\ \u0026 = \\int_x f(x) p_\\theta(x) \\nabla_\\theta \\log p_\\theta(x) dx \\\\ \u0026 = \\mathbb{E}_{x \\sim p_\\theta} [f(x) \\nabla_\\theta \\log p_\\theta(x)] \\end{aligned} $$ Besides natural gradients, NES adopts a couple of important heuristics to make the algorithm performance more robust.\n NES applies rank-based fitness shaping, that is to use the rank under monotonically increasing fitness values instead of using $f(x)$ directly. Or it can be a function of the rank (“utility function”), which is considered as a free parameter of NES. NES adopts adaptation sampling to adjust hyperparameters at run time. When changing $\\theta \\to \\theta’$, samples drawn from $p_\\theta$ are compared with samples from $p_{\\theta’}$ using [Mann-Whitney U-test(https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test)]; if there shows a positive or negative sign, the target hyperparameter decreases or increases by a multiplication constant. Note the score of a sample $x’_i \\sim p_{\\theta’}(x)$ has importance sampling weights applied $w_i’ = p_\\theta(x) / p_{\\theta’}(x)$. Applications: ES in Deep Reinforcement Learning OpenAI ES for RL The concept of using evolutionary algorithms in reinforcement learning can be traced back long ago, but only constrained to tabular RL due to computational limitations.\nInspired by NES, researchers at OpenAI (Salimans, et al. 2017) proposed to use NES as a gradient-free black-box optimizer to find optimal policy parameters $\\theta$ that maximizes the return function $F(\\theta)$. The key is to add Gaussian noise $\\epsilon$ on the model parameter $\\theta$ and then use the log-likelihood trick to write it as the gradient of the Gaussian pdf. Eventually only the noise term is left as a weighting scalar for measured performance.\nLet’s say the current parameter value is $\\hat{\\theta}$ (the added hat is to distinguish the value from the random variable $\\theta$). The search distribution of $\\theta$ is designed to be an isotropic multivariate Gaussian with a mean $\\hat{\\theta}$ and a fixed covariance matrix $\\sigma^2 I$,\n $$ \\theta \\sim \\mathcal{N}(\\hat{\\theta}, \\sigma^2 I) \\text{ equivalent to } \\theta = \\hat{\\theta} + \\sigma\\epsilon, \\epsilon \\sim \\mathcal{N}(0, I) $$ The gradient for $\\theta$ update is:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta \\mathbb{E}_{\\theta\\sim\\mathcal{N}(\\hat{\\theta}, \\sigma^2 I)} F(\\theta) \\\\ \u0026= \\nabla_\\theta \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} F(\\hat{\\theta} + \\sigma\\epsilon) \\\\ \u0026= \\nabla_\\theta \\int_{\\epsilon} p(\\epsilon) F(\\hat{\\theta} + \\sigma\\epsilon) d\\epsilon \u0026 \\scriptstyle{\\text{; Gaussian }p(\\epsilon)=(2\\pi)^{-\\frac{n}{2}} \\exp(-\\frac{1}{2}\\epsilon^\\top\\epsilon)} \\\\ \u0026= \\int_{\\epsilon} p(\\epsilon) \\nabla_\\epsilon \\log p(\\epsilon) \\nabla_\\theta \\epsilon\\;F(\\hat{\\theta} + \\sigma\\epsilon) d\\epsilon \u0026 \\scriptstyle{\\text{; log-likelihood trick}}\\\\ \u0026= \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\nabla_\\epsilon \\big(-\\frac{1}{2}\\epsilon^\\top\\epsilon\\big) \\nabla_\\theta \\big(\\frac{\\theta - \\hat{\\theta}}{\\sigma}\\big) F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\\\ \u0026= \\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ (-\\epsilon) (\\frac{1}{\\sigma}) F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\\\ \u0026= \\frac{1}{\\sigma}\\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\epsilon F(\\hat{\\theta} + \\sigma\\epsilon) ] \u0026 \\scriptstyle{\\text{; negative sign can be absorbed.}} \\end{aligned} $$ In one generation, we can sample many $epsilon_i, i=1,\\dots,n$ and evaluate the fitness in parallel. One beautiful design is that no large model parameter needs to be shared. By only communicating the random seeds between workers, it is enough for the master node to do parameter update. This approach is later extended to adaptively learn a loss function; see my previous post on Evolved Policy Gradient.\nFig. 5. The algorithm for training a RL policy using evolution strategies. (Image source: ES-for-RL paper) To make the performance more robust, OpenAI ES adopts virtual batch normalization (BN with mini-batch used for calculating statistics fixed), mirror sampling (sampling a pair of $(-\\epsilon, \\epsilon)$ for evaluation), and fitness shaping.\nExploration with ES Exploration (vs exploitation) is an important topic in RL. The optimization direction in the ES algorithm above is only extracted from the cumulative return $F(\\theta)$. Without explicit exploration, the agent might get trapped in a local optimum.\nNovelty-Search ES (NS-ES; Conti et al, 2018) encourages exploration by updating the parameter in the direction to maximize the novelty score. The novelty score depends on a domain-specific behavior characterization function $b(\\pi_\\theta)$. The choice of $b(\\pi_\\theta)$ is specific to the task and seems to be a bit arbitrary; for example, in the Humanoid locomotion task in the paper, $b(\\pi_\\theta)$ is the final $(x,y)$ location of the agent.\n Every policy\u0026rsquo;s $b(\\pi_\\theta)$ is pushed to an archive set $\\mathcal{A}$. Novelty of a policy $\\pi_\\theta$ is measured as the k-nearest neighbor score between $b(\\pi_\\theta)$ and all other entries in $\\mathcal{A}$. (The use case of the archive set sounds quite similar to episodic memory.) $$ N(\\theta, \\mathcal{A}) = \\frac{1}{\\lambda} \\sum_{i=1}^\\lambda \\| b(\\pi_\\theta), b^\\text{knn}_i \\|_2 \\text{, where }b^\\text{knn}_i \\in \\text{kNN}(b(\\pi_\\theta), \\mathcal{A}) $$ The ES optimization step relies on the novelty score instead of fitness:\n $$ \\nabla_\\theta \\mathbb{E}_{\\theta\\sim\\mathcal{N}(\\hat{\\theta}, \\sigma^2 I)} N(\\theta, \\mathcal{A}) = \\frac{1}{\\sigma}\\mathbb{E}_{\\epsilon\\sim\\mathcal{N}(0, I)} [ \\epsilon N(\\hat{\\theta} + \\sigma\\epsilon, \\mathcal{A}) ] $$ NS-ES maintains a group of $M$ independently trained agents (\u0026ldquo;meta-population\u0026rdquo;), $\\mathcal{M} = \\{\\theta_1, \\dots, \\theta_M \\}$ and picks one to advance proportional to the novelty score. Eventually we select the best policy. This process is equivalent to ensembling; also see the same idea in SVPG.\n $$ \\begin{aligned} m \u0026\\leftarrow \\text{pick } i=1,\\dots,M\\text{ according to probability}\\frac{N(\\theta_i, \\mathcal{A})}{\\sum_{j=1}^M N(\\theta_j, \\mathcal{A})} \\\\ \\theta_m^{(t+1)} \u0026\\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) \\text{ where }\\epsilon_i \\sim \\mathcal{N}(0, I) \\end{aligned} $$ where $N$ is the number of Gaussian perturbation noise vectors and $\\alpha$ is the learning rate.\nNS-ES completely discards the reward function and only optimizes for novelty to avoid deceptive local optima. To incorporate the fitness back into the formula, another two variations are proposed.\nNSR-ES:\n $$ \\theta_m^{(t+1)} \\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i \\frac{N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) + F(\\theta^{(t)}_m + \\epsilon_i)}{2} $$ NSRAdapt-ES (NSRA-ES): the adaptive weighting parameter $w = 1.0$ initially. We start decreasing $w$ if performance stays flat for a number of generations. Then when the performance starts to increase, we stop decreasing $w$ but increase it instead. In this way, fitness is preferred when the performance stops growing but novelty is preferred otherwise.\n $$ \\theta_m^{(t+1)} \\leftarrow \\theta_m^{(t)} + \\alpha \\frac{1}{\\sigma}\\sum_{i=1}^N \\epsilon_i \\big((1-w) N(\\theta^{(t)}_m + \\epsilon_i, \\mathcal{A}) + w F(\\theta^{(t)}_m + \\epsilon_i)\\big) $$ Fig. 6. (Left) The environment is Humanoid locomotion with a three-sided wall which plays a role as a deceptive trap to create local optimum. (Right) Experiments compare ES baseline and other variations that encourage exploration. (Image source: NS-ES paper) CEM-RL Fig. 7. Architectures of the (a) CEM-RL and (b) ERL algorithms (Image source: CEM-RL paper) The CEM-RL method (Pourchot \u0026amp; Sigaud, 2019) combines Cross Entropy Method (CEM) with either DDPG or TD3. CEM here works pretty much the same as the simple Gaussian ES described above and therefore the same function can be replaced using CMA-ES. CEM-RL is built on the framework of Evolutionary Reinforcement Learning (ERL; Khadka \u0026amp; Tumer, 2018) in which the standard EA algorithm selects and evolves a population of actors and the rollout experience generated in the process is then added into reply buffer for training both RL-actor and RL-critic networks.\nWorkflow:\n The mean actor of the CEM population is $\\pi_\\mu$ is initialized with a random actor network. The critic network $Q$ is initialized too, which will be updated by DDPG/TD3. Repeat until happy: a. Sample a population of actors $\\sim \\mathcal{N}(\\pi_\\mu, \\Sigma)$. b. Half of the population is evaluated. Their fitness scores are used as the cumulative reward $R$ and added into replay buffer. c. The other half are updated together with the critic. d. The new $\\pi_mu$ and $\\Sigma$ is computed using top performing elite samples. CMA-ES can be used for parameter update too. Extension: EA in Deep Learning (This section is not on evolution strategies, but still an interesting and relevant reading.)\nThe Evolutionary Algorithms have been applied on many deep learning problems. POET (Wang et al, 2019) is a framework based on EA and attempts to generate a variety of different tasks while the problems themselves are being solved. POET has been introduced in my last post on meta-RL. Evolutionary Reinforcement Learning (ERL) is another example; See Fig. 7 (b).\nBelow I would like to introduce two applications in more detail, Population-Based Training (PBT) and Weight-Agnostic Neural Networks (WANN).\nHyperparameter Tuning: PBT Fig. 8. Paradigms of comparing different ways of hyperparameter tuning. (Image source: PBT paper) Population-Based Training (Jaderberg, et al, 2017), short for PBT applies EA on the problem of hyperparameter tuning. It jointly trains a population of models and corresponding hyperparameters for optimal performance.\nPBT starts with a set of random candidates, each containing a pair of model weights initialization and hyperparameters, $\\{(\\theta_i, h_i)\\mid i=1, \\dots, N\\}$. Every sample is trained in parallel and asynchronously evaluates its own performance periodically. Whenever a member deems ready (i.e. after taking enough gradient update steps, or when the performance is good enough), it has a chance to be updated by comparing with the whole population:\n exploit(): When this model is under-performing, the weights could be replaced with a better performing model. explore(): If the model weights are overwritten, explore step perturbs the hyperparameters with random noise. In this process, only promising model and hyperparameter pairs can survive and keep on evolving, achieving better utilization of computational resources.\nFig. 9. The algorithm of population-based training. (Image source: PBT paper) Network Topology Optimization: WANN Weight Agnostic Neural Networks (short for WANN; Gaier \u0026amp; Ha 2019) experiments with searching for the smallest network topologies that can achieve the optimal performance without training the network weights. By not considering the best configuration of network weights, WANN puts much more emphasis on the architecture itself, making the focus different from NAS. WANN is heavily inspired by a classic genetic algorithm to evolve network topologies, called NEAT (\u0026ldquo;Neuroevolution of Augmenting Topologies\u0026rdquo;; Stanley \u0026amp; Miikkulainen 2002).\nThe workflow of WANN looks pretty much the same as standard GA:\n Initialize: Create a population of minimal networks. Evaluation: Test with a range of shared weight values. Rank and Selection: Rank by performance and complexity. Mutation: Create new population by varying best networks. Fig. 10. mutation operations for searching for new network topologies in WANN (Image source: WANN paper) At the \u0026ldquo;evaluation\u0026rdquo; stage, all the network weights are set to be the same. In this way, WANN is actually searching for network that can be described with a minimal description length. In the \u0026ldquo;selection\u0026rdquo; stage, both the network connection and the model performance are considered.\nFig. 11. Performance of WANN found network topologies on different RL tasks are compared with baseline FF networks commonly used in the literature. \"Tuned Shared Weight\" only requires adjusting one weight value. (Image source: WANN paper) As shown in Fig. 11, WANN results are evaluated with both random weights and shared weights (single weight). It is interesting that even when enforcing weight-sharing on all weights and tuning this single parameter, WANN can discover topologies that achieve non-trivial good performance.\n Cited as:\n@article{weng2019ES, title = \u0026quot;Evolution Strategies\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-09-05-evolution-strategies/\u0026quot; } References [1] Nikolaus Hansen. \u0026ldquo;The CMA Evolution Strategy: A Tutorial\u0026rdquo; arXiv preprint arXiv:1604.00772 (2016).\n[2] Marc Toussaint. Slides: \u0026ldquo;Introduction to Optimization\u0026rdquo;\n[3] David Ha. \u0026ldquo;A Visual Guide to Evolution Strategies\u0026rdquo; blog.otoro.net. Oct 2017.\n[4] Daan Wierstra, et al. \u0026ldquo;Natural evolution strategies.\u0026quot; IEEE World Congress on Computational Intelligence, 2008.\n[5] Agustinus Kristiadi. \u0026ldquo;Natural Gradient Descent\u0026rdquo; Mar 2018.\n[6] Razvan Pascanu \u0026amp; Yoshua Bengio. \u0026ldquo;Revisiting Natural Gradient for Deep Networks.\u0026quot; arXiv preprint arXiv:1301.3584 (2013).\n[7] Tim Salimans, et al. \u0026ldquo;Evolution strategies as a scalable alternative to reinforcement learning.\u0026quot; arXiv preprint arXiv:1703.03864 (2017).\n[8] Edoardo Conti, et al. \u0026ldquo;Improving exploration in evolution strategies for deep reinforcement learning via a population of novelty-seeking agents.\u0026quot; NIPS. 2018.\n[9] Aloïs Pourchot \u0026amp; Olivier Sigaud. \u0026ldquo;CEM-RL: Combining evolutionary and gradient-based methods for policy search.\u0026quot; ICLR 2019.\n[10] Shauharda Khadka \u0026amp; Kagan Tumer. \u0026ldquo;Evolution-guided policy gradient in reinforcement learning.\u0026quot; NIPS 2018.\n[11] Max Jaderberg, et al. \u0026ldquo;Population based training of neural networks.\u0026quot; arXiv preprint arXiv:1711.09846 (2017).\n[12] Adam Gaier \u0026amp; David Ha. \u0026ldquo;Weight Agnostic Neural Networks.\u0026quot; arXiv preprint arXiv:1906.04358 (2019).\n","permalink":"https://lilianweng.github.io/posts/2019-09-05-evolution-strategies/","summary":"Stochastic gradient descent is a universal choice for optimizing deep learning models. However, it is not the only option. With black-box optimization algorithms, you can evaluate a target function $f(x): \\mathbb{R}^n \\to \\mathbb{R}$, even when you don\u0026rsquo;t know the precise analytic form of $f(x)$ and thus cannot compute gradients or the Hessian matrix. Examples of black-box optimization methods include Simulated Annealing, Hill Climbing and Nelder-Mead method.\nEvolution Strategies (ES) is one type of black-box optimization algorithms, born in the family of Evolutionary Algorithms (EA).","title":"Evolution Strategies"},{"content":"In my earlier post on meta-learning, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to \u0026ldquo;meta-learn\u0026rdquo; Reinforcement Learning (RL) tasks by developing an agent that can solve unseen tasks fast and efficiently.\nTo recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a mini learning session, happens at test with limited exposure to the new configurations. Even without any explicit fine-tuning (no gradient backpropagation on trainable variables), the meta-learning model autonomously adjusts internal hidden states to learn.\nTraining RL algorithms can be notoriously difficult sometimes. If the meta-learning agent could become so smart that the distribution of solvable unseen tasks grows extremely broad, we are on track towards general purpose methods \u0026mdash; essentially building a \u0026ldquo;brain\u0026rdquo; which would solve all kinds of RL problems without much human interference or manual feature engineering. Sounds amazing, right? 💖\nOn the Origin of Meta-RL Back in 2001 I encountered a paper written in 2001 by Hochreiter et al. when reading Wang et al., 2016. Although the idea was proposed for supervised learning, there are so many resemblances to the current approach to meta-RL.\nFig. 1. The meta-learning system consists of the supervisory and the subordinate systems. The subordinate system is a recurrent neural network that takes as input both the observation at the current time step, $x\\_t$ and the label at the last time step, $y\\_{t-1}$. (Image source: Hochreiter et al., 2001) Hochreiter\u0026rsquo;s meta-learning model is a recurrent network with LSTM cell. LSTM is a good choice because it can internalize a history of inputs and tune its own weights effectively through BPTT. The training data contains $K$ sequences and each sequence is consist of $N$ samples generated by a target function $f_k(.), k=1, \\dots, K$,\n $$ \\{\\text{input: }(\\mathbf{x}^k_i, \\mathbf{y}^k_{i-1}) \\to \\text{label: }\\mathbf{y}^k_i\\}_{i=1}^N \\text{ where }\\mathbf{y}^k_i = f_k(\\mathbf{x}^k_i) $$ Noted that the last label $\\mathbf{y}^k_{i-1}$ is also provided as an auxiliary input so that the function can learn the presented mapping.\nIn the experiment of decoding two-dimensional quadratic functions, $a x_1^2 + b x_2^2 + c x_1 x_2 + d x_1 + e x_2 + f$, with coefficients $a$-$f$ are randomly sampled from [-1, 1], this meta-learning system was able to approximate the function after seeing only ~35 examples.\nProposal in 2016 In the modern days of DL, Wang et al. (2016) and Duan et al. (2017) simultaneously proposed the very similar idea of Meta-RL (it is called RL^2 in the second paper). A meta-RL model is trained over a distribution of MDPs, and at test time, it is able to learn to solve a new task quickly. The goal of meta-RL is ambitious, taking one step further towards general algorithms.\nDefine Meta-RL Meta Reinforcement Learning, in short, is to do meta-learning in the field of reinforcement learning. Usually the train and test tasks are different but drawn from the same family of problems; i.e., experiments in the papers included multi-armed bandit with different reward probabilities, mazes with different layouts, same robots but with different physical parameters in simulator, and many others.\nFormulation Let\u0026rsquo;s say we have a distribution of tasks, each formularized as an MDP (Markov Decision Process), $M_i \\in \\mathcal{M}$. An MDP is determined by a 4-tuple, $M_i= \\langle \\mathcal{S}, \\mathcal{A}, P_i, R_i \\rangle$:\n Symbol Meaning $\\mathcal{S}$ A set of states. $\\mathcal{A}$ A set of actions. $P_i: \\mathcal{S} \\times \\mathcal{A} \\times \\mathcal{S} \\to \\mathbb{R}_{+}$ Transition probability function. $R_i: \\mathcal{S} \\times \\mathcal{A} \\to \\mathbb{R}$ Reward function. (RL^2 paper adds an extra parameter, horizon $T$, into the MDP tuple to emphasize that each MDP should have a finite horizon.)\nNote that common state $\\mathcal{S}$ and action space $\\mathcal{A}$ are used above, so that a (stochastic) policy: $\\pi_\\theta: \\mathcal{S} \\times \\mathcal{A} \\to \\mathbb{R}_{+}$ would get inputs compatible across different tasks. The test tasks are sampled from the same distribution $\\mathcal{M}$ or slightly modified version.\nFig. 2. Illustration of meta-RL, containing two optimization loops. The outer loop samples a new environment in every iteration and adjusts parameters that determine the agent's behavior. In the inner loop, the agent interacts with the environment and optimizes for the maximal reward. (Image source: Botvinick, et al. 2019) Main Differences from RL The overall configure of meta-RL is very similar to an ordinary RL algorithm, except that the last reward $r_{t-1}$ and the last action $a_{t-1}$ are also incorporated into the policy observation in addition to the current state $s_t$.\n In RL: $\\pi_\\theta(s_t) \\to$ a distribution over $\\mathcal{A}$ In meta-RL: $\\pi_\\theta(a_{t-1}, r_{t-1}, s_t) \\to$ a distribution over $\\mathcal{A}$ The intention of this design is to feed a history into the model so that the policy can internalize the dynamics between states, rewards, and actions in the current MDP and adjust its strategy accordingly. This is well aligned with the setup in Hochreiter\u0026rsquo;s system. Both meta-RL and RL^2 implemented an LSTM policy and the LSTM\u0026rsquo;s hidden states serve as a memory for tracking characteristics of the trajectories. Because the policy is recurrent, there is no need to feed the last state as inputs explicitly.\nThe training procedure works as follows:\n Sample a new MDP, $M_i \\sim \\mathcal{M}$; Reset the hidden state of the model; Collect multiple trajectories and update the model weights; Repeat from step 1. Fig. 3. In the meta-RL paper, different actor-critic architectures all use a recurrent model. Last reward and last action are additional inputs. The observation is fed into the LSTM either as a one-hot vector or as an embedding vector after passed through an encoder model. (Image source: Wang et al., 2016) Fig. 4. As described in the RL^2 paper, illustration of the procedure of the model interacting with a series of MDPs in training time . (Image source: Duan et al., 2017) Key Components There are three key components in Meta-RL:\n ⭐ A Model with Memory A recurrent neural network maintains a hidden state. Thus, it could acquire and memorize the knowledge about the current task by updating the hidden state during rollouts. Without memory, meta-RL would not work.\n ⭐ Meta-learning Algorithm A meta-learning algorithm refers to how we can update the model weights to optimize for the purpose of solving an unseen task fast at test time. In both Meta-RL and RL^2 papers, the meta-learning algorithm is the ordinary gradient descent update of LSTM with hidden state reset between a switch of MDPs.\n ⭐ A Distribution of MDPs While the agent is exposed to a variety of environments and tasks during training, it has to learn how to adapt to different MDPs.\n According to Botvinick et al. (2019), one source of slowness in RL training is weak inductive bias ( = \u0026ldquo;a set of assumptions that the learner uses to predict outputs given inputs that it has not encountered\u0026rdquo;). As a general ML rule, a learning algorithm with weak inductive bias will be able to master a wider range of variance, but usually, will be less sample-efficient. Therefore, to narrow down the hypotheses with stronger inductive biases help improve the learning speed.\nIn meta-RL, we impose certain types of inductive biases from the task distribution and store them in memory. Which inductive bias to adopt at test time depends on the algorithm. Together, these three key components depict a compelling view of meta-RL: Adjusting the weights of a recurrent network is slow but it allows the model to work out a new task fast with its own RL algorithm implemented in its internal activity dynamics.\nMeta-RL interestingly and not very surprisingly matches the ideas in the AI-GAs (\u0026ldquo;AI-Generating Algorithms\u0026rdquo;) paper by Jeff Clune (2019). He proposed that one efficient way towards building general AI is to make learning as automatic as possible. The AI-GAs approach involves three pillars: (1) meta-learning architectures, (2) meta-learning algorithms, and (3) automatically generated environments for effective learning.\n The topic of designing good recurrent network architectures is a bit too broad to be discussed here, so I will skip it. Next, let\u0026rsquo;s look further into another two components: meta-learning algorithms in the context of meta-RL and how to acquire a variety of training MDPs.\nMeta-Learning Algorithms for Meta-RL My previous post on meta-learning has covered several classic meta-learning algorithms. Here I\u0026rsquo;m gonna include more related to RL.\nOptimizing Model Weights for Meta-learning Both MAML (Finn, et al. 2017) and Reptile (Nichol et al., 2018) are methods on updating model parameters in order to achieve good generalization performance on new tasks. See an earlier post section on MAML and Reptile.\nMeta-learning Hyperparameters The return function in an RL problem, $G_t^{(n)}$ or $G_t^\\lambda$, involves a few hyperparameters that are often set heuristically, like the discount factor $\\gamma$ and the bootstrapping parameter $\\lambda$. Meta-gradient RL (Xu et al., 2018) considers them as meta-parameters, $\\eta=\\{\\gamma, \\lambda \\}$, that can be tuned and learned online while an agent is interacting with the environment. Therefore, the return becomes a function of $\\eta$ and dynamically adapts itself to a specific task over time.\n $$ \\begin{aligned} G_\\eta^{(n)}(\\tau_t) \u0026= R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{n-1}R_{t+n} + \\gamma^n v_\\theta(s_{t+n}) \u0026 \\scriptstyle{\\text{; n-step return}} \\\\ G_\\eta^{\\lambda}(\\tau_t) \u0026= (1-\\lambda) \\sum_{n=1}^\\infty \\lambda^{n-1} G_\\eta^{(n)} \u0026 \\scriptstyle{\\text{; λ-return, mixture of n-step returns}} \\end{aligned} $$ During training, we would like to update the policy parameters with gradients as a function of all the information in hand, $\\theta' = \\theta + f(\\tau, \\theta, \\eta)$, where $\\theta$ are the current model weights, $\\tau$ is a sequence of trajectories, and $\\eta$ are the meta-parameters.\nMeanwhile, let\u0026rsquo;s say we have a meta-objective function $J(\\tau, \\theta, \\eta)$ as a performance measure. The training process follows the principle of online cross-validation, using a sequence of consecutive experiences:\n Starting with parameter $\\theta$, the policy $\\pi_\\theta$ is updated on the first batch of samples $\\tau$, resulting in $\\theta'$. Then we continue running the policy $\\pi_{\\theta'}$ to collect a new set of experiences $\\tau'$, just following $\\tau$ consecutively in time. The performance is measured as $J(\\tau', \\theta', \\bar{\\eta})$ with a fixed meta-parameter $\\bar{\\eta}$. The gradient of meta-objective $J(\\tau', \\theta', \\bar{\\eta})$ w.r.t. $\\eta$ is used to update $\\eta$: $$ \\begin{aligned} \\Delta \\eta \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\eta} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{d\\theta'}{d\\eta} \u0026 \\scriptstyle{\\text{ ; single variable chain rule.}} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{\\partial (\\theta + f(\\tau, \\theta, \\eta))}{\\partial\\eta} \\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\Big(\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\theta}\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta}\\frac{d\\eta}{d\\eta} \\Big) \u0026 \\scriptstyle{\\text{; multivariable chain rule.}}\\\\ \u0026= -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\Big( \\color{red}{\\big(\\mathbf{I} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\theta}\\big)}\\frac{d\\theta}{d\\eta} + \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta}\\Big) \u0026 \\scriptstyle{\\text{; secondary gradient term in red.}} \\end{aligned} $$ where $\\beta$ is the learning rate for $\\eta$.\nThe meta-gradient RL algorithm simplifies the computation by setting the secondary gradient term to zero, $\\mathbf{I} + \\partial g(\\tau, \\theta, \\eta)/\\partial\\theta = 0$ \u0026mdash; this choice prefers the immediate effect of the meta-parameters $\\eta$ on the parameters $\\theta$. Eventually we get:\n $$ \\Delta \\eta = -\\beta \\frac{\\partial J(\\tau', \\theta', \\bar{\\eta})}{\\partial \\theta'} \\frac{\\partial f(\\tau, \\theta, \\eta)}{\\partial\\eta} $$ Experiments in the paper adopted the meta-objective function same as $TD(\\lambda)$ algorithm, minimizing the error between the approximated value function $v_\\theta(s)$ and the $\\lambda$-return:\n $$ \\begin{aligned} J(\\tau, \\theta, \\eta) \u0026= (G^\\lambda_\\eta(\\tau) - v_\\theta(s))^2 \\\\ J(\\tau', \\theta', \\bar{\\eta}) \u0026= (G^\\lambda_{\\bar{\\eta}}(\\tau') - v_{\\theta'}(s'))^2 \\end{aligned} $$ Meta-learning the Loss Function In policy gradient algorithms, the expected total reward is maximized by updating the policy parameters $\\theta$ in the direction of estimated gradient (Schulman et al., 2016),\n $$ g = \\mathbb{E}[\\sum_{t=0}^\\infty \\Psi_t \\nabla_\\theta \\log \\pi_\\theta (a_t \\mid s_t)] $$ where the candidates for $\\Psi_t$ include the trajectory return $G_t$, the Q value $Q(s_t, a_t)$, or the advantage value $A(s_t, a_t)$. The corresponding surrogate loss function for the policy gradient can be reverse-engineered:\n $$ L_\\text{pg} = \\mathbb{E}[\\sum_{t=0}^\\infty \\Psi_t \\log \\pi_\\theta (a_t \\mid s_t)] $$ This loss function is a measure over a history of trajectories, $(s_0, a_0, r_0, \\dots, s_t, a_t, r_t, \\dots)$. Evolved Policy Gradient (EPG; Houthooft, et al, 2018) takes a step further by defining the policy gradient loss function as a temporal convolution (1-D convolution) over the agent\u0026rsquo;s past experience, $L_\\phi$. The parameters $\\phi$ of the loss function network are evolved in a way that an agent can achieve higher returns.\nSimilar to many meta-learning algorithms, EPG has two optimization loops:\n In the internal loop, an agent learns to improve its policy $\\pi_\\theta$. In the outer loop, the model updates the parameters $\\phi$ of the loss function $L_\\phi$. Because there is no explicit way to write down a differentiable equation between the return and the loss, EPG turned to Evolutionary Strategies (ES). A general idea is to train a population of $N$ agents, each of them is trained with the loss function $L_{\\phi + \\sigma \\epsilon_i}$ parameterized with $\\phi$ added with a small Gaussian noise $\\epsilon_i \\sim \\mathcal{N}(0, \\mathbf{I})$ of standard deviation $\\sigma$. During the inner loop\u0026rsquo;s training, EPG tracks a history of experience and updates the policy parameters according to the loss function $L_{\\phi + \\sigma\\epsilon_i}$ for each agent:\n $$ \\theta_i \\leftarrow \\theta - \\alpha_\\text{in} \\nabla_\\theta L_{\\phi + \\sigma \\epsilon_i} (\\pi_\\theta, \\tau_{t-K, \\dots, t}) $$ where $\\alpha_\\text{in}$ is the learning rate of the inner loop and $\\tau_{t-K, \\dots, t}$ is a sequence of $M$ transitions up to the current time step $t$.\nOnce the inner loop policy is mature enough, the policy is evaluated by the mean return $\\bar{G}_{\\phi+\\sigma\\epsilon_i}$ over multiple randomly sampled trajectories. Eventually, we are able to estimate the gradient of $\\phi$ according to NES numerically (Salimans et al, 2017). While repeating this process, both the policy parameters $\\theta$ and the loss function weights $\\phi$ are being updated simultaneously to achieve higher returns.\n $$ \\phi \\leftarrow \\phi + \\alpha_\\text{out} \\frac{1}{\\sigma N} \\sum_{i=1}^N \\epsilon_i G_{\\phi+\\sigma\\epsilon_i} $$ where $\\alpha_\\text{out}$ is the learning rate of the outer loop.\nIn practice, the loss $L_\\phi$ is bootstrapped with an ordinary policy gradient (such as REINFORCE or PPO) surrogate loss $L_\\text{pg}$, $\\hat{L} = (1-\\alpha) L_\\phi + \\alpha L_\\text{pg}$. The weight $\\alpha$ is annealing from 1 to 0 gradually during training. At test time, the loss function parameter $\\phi$ stays fixed and the loss value is computed over a history of experience to update the policy parameters $\\theta$.\nMeta-learning the Exploration Strategies The exploitation vs exploration dilemma is a critical problem in RL. Common ways to do exploration include $\\epsilon$-greedy, random noise on actions, or stochastic policy with built-in randomness on the action space.\nMAESN (Gupta et al, 2018) is an algorithm to learn structured action noise from prior experience for better and more effective exploration. Simply adding random noise on actions cannot capture task-dependent or time-correlated exploration strategies. MAESN changes the policy to condition on a per-task random variable $z_i \\sim \\mathcal{N}(\\mu_i, \\sigma_i)$, for $i$-th task $M_i$, so we would have a policy $a \\sim \\pi_\\theta(a\\mid s, z_i)$. The latent variable $z_i$ is sampled once and fixed during one episode. Intuitively, the latent variable determines one type of behavior (or skills) that should be explored more at the beginning of a rollout and the agent would adjust its actions accordingly. Both the policy parameters and latent space are optimized to maximize the total task rewards. In the meantime, the policy learns to make use of the latent variables for exploration.\nIn addition, the loss function includes a KL divergence between the learned latent variable and a unit Gaussian prior, $D_\\text{KL}(\\mathcal{N}(\\mu_i, \\sigma_i)|\\mathcal{N}(0, \\mathbf{I}))$. On one hand, it restricts the learned latent space not too far from a common prior. On the other hand, it creates the variational evidence lower bound (ELBO) for the reward function. Interestingly the paper found that $(\\mu_i, \\sigma_i)$ for each task are usually close to the prior at convergence.\nFig. 5. The policy is conditioned on a latent variable variable $z\\_i \\sim \\mathcal{N}(\\mu, \\sigma)$ that is sampled once every episode. Each task has different hyperparameters for the latent variable distribution, $(\\mu\\_i, \\sigma\\_i)$ and they are optimized in the outer loop. (Image source: Gupta et al, 2018) Episodic Control A major criticism of RL is on its sample inefficiency. A large number of samples and small learning steps are required for incremental parameter adjustment in RL in order to maximize generalization and avoid catastrophic forgetting of earlier learning (Botvinick et al., 2019).\nEpisodic control (Lengyel \u0026amp; Dayan, 2008) is proposed as a solution to avoid forgetting and improve generalization while training at a faster speed. It is partially inspired by hypotheses on instance-based hippocampal learning.\nAn episodic memory keeps explicit records of past events and uses these records directly as point of reference for making new decisions (i.e. just like metric-based meta-learning). In MFEC (Model-Free Episodic Control; Blundell et al., 2016), the memory is modeled as a big table, storing the state-action pair $(s, a)$ as key and the corresponding Q-value $Q_\\text{EC}(s, a)$ as value. When receiving a new observation $s$, the Q value is estimated in an non-parametric way as the average Q-value of top $k$ most similar samples:\n $$ \\hat{Q}_\\text{EC}(s, a) = \\begin{cases} Q_\\text{EC}(s, a) \u0026 \\text{if } (s,a) \\in Q_\\text{EC}, \\\\ \\frac{1}{k} \\sum_{i=1}^k Q(s^{(i)}, a) \u0026 \\text{otherwise} \\end{cases} $$ where $s^{(i)}, i=1, \\dots, k$ are top $k$ states with smallest distances to the state $s$. Then the action that yields the highest estimated Q value is selected. Then the memory table is updated according to the return received at $s_t$:\n $$ Q_\\text{EC}(s, a) \\leftarrow \\begin{cases} \\max\\{Q_\\text{EC}(s_t, a_t), G_t\\} \u0026 \\text{if } (s,a) \\in Q_\\text{EC}, \\\\ G_t \u0026 \\text{otherwise} \\end{cases} $$ As a tabular RL method, MFEC suffers from large memory consumption and a lack of ways to generalize among similar states. The first one can be fixed with an LRU cache. Inspired by metric-based meta-learning, especially Matching Networks (Vinyals et al., 2016), the generalization problem is improved in a follow-up algorithm, NEC (Neural Episodic Control; Pritzel et al., 2016).\nThe episodic memory in NEC is a Differentiable Neural Dictionary (DND), where the key is a convolutional embedding vector of input image pixels and the value stores estimated Q value. Given an inquiry key, the output is a weighted sum of values of top similar keys, where the weight is a normalized kernel measure between the query key and the selected key in the dictionary. This sounds like a hard attention machanism.\nFig. 6 Illustrations of episodic memory module in NEC and two operations on a differentiable neural dictionary. (Image source: Pritzel et al., 2016) Further, Episodic LSTM (Ritter et al., 2018) enhances the basic LSTM architecture with a DND episodic memory, which stores task context embeddings as keys and the LSTM cell states as values. The stored hidden states are retrieved and added directly to the current cell state through the same gating mechanism within LSTM:\nFig. 7. Illustration of the episodic LSTM architecture. The additional structure of episodic memory is in bold. (Image source: Ritter et al., 2018) $$ \\begin{aligned} \\mathbf{c}_t \u0026= \\mathbf{i}_t \\circ \\mathbf{c}_\\text{in} + \\mathbf{f}_t \\circ \\mathbf{c}_{t-1} + \\color{green}{\\mathbf{r}_t \\circ \\mathbf{c}_\\text{ep}} \u0026\\\\ \\mathbf{i}_t \u0026= \\sigma(\\mathbf{W}_{i} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_i) \u0026 \\scriptstyle{\\text{; input gate}} \\\\ \\mathbf{f}_t \u0026= \\sigma(\\mathbf{W}_{f} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_f) \u0026 \\scriptstyle{\\text{; forget gate}} \\\\ \\color{green}{\\mathbf{r}_t} \u0026 \\color{green}{=} \\color{green}{\\sigma(\\mathbf{W}_{r} \\cdot [\\mathbf{h}_{t-1}, \\mathbf{x}_t] + \\mathbf{b}_r)} \u0026 \\scriptstyle{\\text{; reinstatement gate}} \\end{aligned} $$ where $\\mathbf{c}_t$ and $\\mathbf{h}_t$ are hidden and cell state at time $t$; $\\mathbf{i}_t$, $\\mathbf{f}_t$ and $\\mathbf{r}_t$ are input, forget and reinstatement gates, respectively; $\\mathbf{c}_\\text{ep}$ is the retrieved cell state from episodic memory. The newly added episodic memory components are marked in green.\nThis architecture provides a shortcut to the prior experience through context-based retrieval. Meanwhile, explicitly saving the task-dependent experience in an external memory avoids forgetting. In the paper, all the experiments have manually designed context vectors. How to construct an effective and efficient format of task context embeddings for more free-formed tasks would be an interesting topic.\nOverall the capacity of episodic control is limited by the complexity of the environment. It is very rare for an agent to repeatedly visit exactly the same states in a real-world task, so properly encoding the states is critical. The learned embedding space compresses the observation data into a lower dimension space and, in the meantime, two states being close in this space are expected to demand similar strategies.\nTraining Task Acquisition Among three key components, how to design a proper distribution of tasks is the less studied and probably the most specific one to meta-RL itself. As described above, each task is a MDP: $M_i = \\langle \\mathcal{S}, \\mathcal{A}, P_i, R_i \\rangle \\in \\mathcal{M}$. We can build a distribution of MDPs by modifying:\n The reward configuration: Among different tasks, same behavior might get rewarded differently according to $R_i$. Or, the environment: The transition function $P_i$ can be reshaped by initializing the environment with varying shifts between states. Task Generation by Domain Randomization Randomizing parameters in a simulator is an easy way to obtain tasks with modified transition functions. If interested in learning further, check my last post on domain randomization.\nEvolutionary Algorithm on Environment Generation Evolutionary algorithm is a gradient-free heuristic-based optimization method, inspired by natural selection. A population of solutions follows a loop of evaluation, selection, reproduction, and mutation. Eventually, good solutions survive and thus get selected.\nPOET (Wang et al, 2019), a framework based on the evolutionary algorithm, attempts to generate tasks while the problems themselves are being solved. The implementation of POET is only specifically designed for a simple 2D bipedal walker environment but points out an interesting direction. It is noteworthy that the evolutionary algorithm has had some compelling applications in Deep Learning like EPG and PBT (Population-Based Training; Jaderberg et al, 2017).\nFig. 8. An example bipedal walking environment (top) and an overview of POET (bottom). (Image source: POET blog post) The 2D bipedal walking environment is evolving: from a simple flat surface to a much more difficult trail with potential gaps, stumps, and rough terrains. POET pairs the generation of environmental challenges and the optimization of agents together so as to (a) select agents that can resolve current challenges and (b) evolve environments to be solvable. The algorithm maintains a list of environment-agent pairs and repeats the following:\n Mutation: Generate new environments from currently active environments. Note that here types of mutation operations are created just for bipedal walker and a new environment would demand a new set of configurations. Optimization: Train paired agents within their respective environments. Selection: Periodically attempt to transfer current agents from one environment to another. Copy and update the best performing agent for every environment. The intuition is that skills learned in one environment might be helpful for a different environment. The procedure above is quite similar to PBT, but PBT mutates and evolves hyperparameters instead. To some extent, POET is doing domain randomization, as all the gaps, stumps and terrain roughness are controlled by some randomization probability parameters. Different from DR, the agents are not exposed to a fully randomized difficult environment all at once, but instead they are learning gradually with a curriculum configured by the evolutionary algorithm.\nLearning with Random Rewards An MDP without a reward function $R$ is known as a Controlled Markov process (CMP). Given a predefined CMP, $\\langle \\mathcal{S}, \\mathcal{A}, P\\rangle$, we can acquire a variety of tasks by generating a collection of reward functions $\\mathcal{R}$ that encourage the training of an effective meta-learning policy.\nGupta et al. (2018) proposed two unsupervised approaches for growing the task distribution in the context of CMP. Assuming there is an underlying latent variable $z \\sim p(z)$ associated with every task, it parameterizes/determines a reward function: $r_z(s) = \\log D(z|s)$, where a \u0026ldquo;discriminator\u0026rdquo; function $D(.)$ is used to extract the latent variable from the state. The paper described two ways to construct a discriminator function:\n Sample random weights $\\phi_\\text{rand}$ of the discriminator, $D_{\\phi_\\text{rand}}(z \\mid s)$. Learn a discriminator function to encourage diversity-driven exploration. This method is introduced in more details in another sister paper \u0026ldquo;DIAYN\u0026rdquo; (Eysenbach et al., 2018). DIAYN, short for \u0026ldquo;Diversity is all you need\u0026rdquo;, is a framework to encourage a policy to learn useful skills without a reward function. It explicitly models the latent variable $z$ as a skill embedding and makes the policy conditioned on $z$ in addition to state $s$, $\\pi_\\theta(a \\mid s, z)$. (Ok, this part is same as MAESN unsurprisingly, as the papers are from the same group.) The design of DIAYN is motivated by a few hypotheses:\n Skills should be diverse and lead to visitations of different states. → maximize the mutual information between states and skills, $I(S; Z)$ Skills should be distinguishable by states, not actions. → minimize the mutual information between actions and skills, conditioned on states $I(A; Z \\mid S)$ The objective function to maximize is as follows, where the policy entropy is also added to encourage diversity:\n $$ \\begin{aligned} \\mathcal{F}(\\theta) \u0026= I(S; Z) + H[A \\mid S] - I(A; Z \\mid S) \u0026 \\\\ \u0026= (H(Z) - H(Z \\mid S)) + H[A \\mid S] - (H[A\\mid S] - H[A\\mid S, Z]) \u0026 \\\\ \u0026= H[A\\mid S, Z] \\color{green}{- H(Z \\mid S) + H(Z)} \u0026 \\\\ \u0026= H[A\\mid S, Z] + \\mathbb{E}_{z\\sim p(z), s\\sim\\rho(s)}[\\log p(z \\mid s)] - \\mathbb{E}_{z\\sim p(z)}[\\log p(z)] \u0026 \\scriptstyle{\\text{; can infer skills from states \u0026 p(z) is diverse.}} \\\\ \u0026\\ge H[A\\mid S, Z] + \\mathbb{E}_{z\\sim p(z), s\\sim\\rho(s)}[\\color{red}{\\log D_\\phi(z \\mid s) - \\log p(z)}] \u0026 \\scriptstyle{\\text{; according to Jensen's inequality; \"pseudo-reward\" in red.}} \\end{aligned} $$ where $I(.)$ is mutual information and $H[.]$ is entropy measure. We cannot integrate all states to compute $p(z \\mid s)$, so approximate it with $D_\\phi(z \\mid s)$ \u0026mdash; that is the diversity-driven discriminator function.\nFig. 9. DIAYN Algorithm. (Image source: Eysenbach et al., 2019) Once the discriminator function is learned, sampling a new MDP for training is strainght-forward: First, sample a latent variable, $z \\sim p(z)$ and construct a reward function $r_z(s) = \\log(D(z \\vert s))$. Pairing the reward function with a predefined CMP creates a new MDP.\n Cited as:\n@article{weng2019metaRL, title = \u0026quot;Meta Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-06-23-meta-rl/\u0026quot; } References [1] Richard S. Sutton. \u0026ldquo;The Bitter Lesson.\u0026quot; March 13, 2019.\n[2] Sepp Hochreiter, A. Steven Younger, and Peter R. Conwell. \u0026ldquo;Learning to learn using gradient descent.\u0026quot; Intl. Conf. on Artificial Neural Networks. 2001.\n[3] Jane X Wang, et al. \u0026ldquo;Learning to reinforcement learn.\u0026quot; arXiv preprint arXiv:1611.05763 (2016).\n[4] Yan Duan, et al. \u0026ldquo;RL $^ 2$: Fast Reinforcement Learning via Slow Reinforcement Learning.\u0026quot; ICLR 2017.\n[5] Matthew Botvinick, et al. \u0026ldquo;Reinforcement Learning, Fast and Slow\u0026rdquo; Cell Review, Volume 23, Issue 5, P408-422, May 01, 2019.\n[6] Jeff Clune. \u0026ldquo;AI-GAs: AI-generating algorithms, an alternate paradigm for producing general artificial intelligence\u0026rdquo; arXiv preprint arXiv:1905.10985 (2019).\n[7] Zhongwen Xu, et al. \u0026ldquo;Meta-Gradient Reinforcement Learning\u0026rdquo; NIPS 2018.\n[8] Rein Houthooft, et al. \u0026ldquo;Evolved Policy Gradients.\u0026quot; NIPS 2018.\n[9] Tim Salimans, et al. \u0026ldquo;Evolution strategies as a scalable alternative to reinforcement learning.\u0026quot; arXiv preprint arXiv:1703.03864 (2017).\n[10] Abhishek Gupta, et al. \u0026ldquo;Meta-Reinforcement Learning of Structured Exploration Strategies.\u0026quot; NIPS 2018.\n[11] Alexander Pritzel, et al. \u0026ldquo;Neural episodic control.\u0026quot; Proc. Intl. Conf. on Machine Learning, Volume 70, 2017.\n[12] Charles Blundell, et al. \u0026ldquo;Model-free episodic control.\u0026quot; arXiv preprint arXiv:1606.04460 (2016).\n[13] Samuel Ritter, et al. \u0026ldquo;Been there, done that: Meta-learning with episodic recall.\u0026quot; ICML, 2018.\n[14] Rui Wang et al. \u0026ldquo;Paired Open-Ended Trailblazer (POET): Endlessly Generating Increasingly Complex and Diverse Learning Environments and Their Solutions\u0026rdquo; arXiv preprint arXiv:1901.01753 (2019).\n[15] Uber Engineering Blog: \u0026ldquo;POET: Endlessly Generating Increasingly Complex and Diverse Learning Environments and their Solutions through the Paired Open-Ended Trailblazer.\u0026quot; Jan 8, 2019.\n[16] Abhishek Gupta, et al.\u0026ldquo;Unsupervised meta-learning for Reinforcement Learning\u0026rdquo; arXiv preprint arXiv:1806.04640 (2018).\n[17] Eysenbach, Benjamin, et al. \u0026ldquo;Diversity is all you need: Learning skills without a reward function.\u0026quot; ICLR 2019.\n[18] Max Jaderberg, et al. \u0026ldquo;Population Based Training of Neural Networks.\u0026quot; arXiv preprint arXiv:1711.09846 (2017).\n","permalink":"https://lilianweng.github.io/posts/2019-06-23-meta-rl/","summary":"In my earlier post on meta-learning, the problem is mainly defined in the context of few-shot classification. Here I would like to explore more into cases when we try to \u0026ldquo;meta-learn\u0026rdquo; Reinforcement Learning (RL) tasks by developing an agent that can solve unseen tasks fast and efficiently.\nTo recap, a good meta-learning model is expected to generalize to new tasks or new environments that have never been encountered during training. The adaptation process, essentially a mini learning session, happens at test with limited exposure to the new configurations.","title":"Meta Reinforcement Learning"},{"content":"In Robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots. The gap is triggered by an inconsistency between physical parameters (i.e. friction, kp, damping, mass, density) and, more fatally, the incorrect physical modeling (i.e. collision between soft surfaces).\nTo close the sim2real gap, we need to improve the simulator and make it closer to reality. A couple of approaches:\n System identification System identification is to build a mathematical model for a physical system; in the context of RL, the mathematical model is the simulator. To make the simulator more realistic, careful calibration is necessary. Unfortunately, calibration is expensive. Furthermore, many physical parameters of the same machine might vary significantly due to temperature, humidity, positioning or its wear-and-tear in time. Domain adaptation Domain adaptation (DA) refers to a set of transfer learning techniques developed to update the data distribution in sim to match the real one through a mapping or regularization enforced by the task model. Many DA models, especially for image classification or end-to-end image-based RL task, are built on adversarial loss or GAN. Domain randomization With domain randomization (DR), we are able to create a variety of simulated environments with randomized properties and train a model that works across all of them. Likely this model can adapt to the real-world environment, as the real system is expected to be one sample in that rich distribution of training variations. Both DA and DR are unsupervised. Compared to DA which requires a decent amount of real data samples to capture the distribution, DR may need only a little or no real data. DR is the focus of this post.\nFig. 1. Conceptual illustrations of three approaches for sim2real transfer. What is Domain Randomization? To make the definition more general, let us call the environment that we have full access to (i.e. simulator) source domain and the environment that we would like to transfer the model to target domain (i.e. physical world). Training happens in the source domain. We can control a set of $N$ randomization parameters in the source domain $e_\\xi$ with a configuration $\\xi$, sampled from a randomization space, $\\xi \\in \\Xi \\subset \\mathbb{R}^N$.\nDuring policy training, episodes are collected from source domain with randomization applied. Thus the policy is exposed to a variety of environments and learns to generalize. The policy parameter $\\theta$ is trained to maximize the expected reward $R(.)$ average across a distribution of configurations:\n $$ \\theta^* = \\arg\\max_\\theta \\mathbb{E}_{\\xi \\sim \\Xi} [\\mathbb{E}_{\\pi_\\theta, \\tau \\sim e_\\xi} [R(\\tau)]] $$ where $\\tau_\\xi$ is a trajectory collected in source domain randomized with $\\xi$. In a way, \u0026ldquo;discrepancies between the source and target domains are modeled as variability in the source domain.\u0026quot; (quote from Peng et al. 2018).\nUniform Domain Randomization In the original form of DR (Tobin et al, 2017; Sadeghi et al. 2016), each randomization parameter $\\xi_i$ is bounded by an interval, $\\xi_i \\in [\\xi_i^\\text{low}, \\xi_i^\\text{high}], i=1,\\dots,N$ and each parameter is uniformly sampled within the range.\nThe randomization parameters can control appearances of the scene, including but not limited to the followings (see Fig. 2). A model trained on simulated and randomized images is able to transfer to real non-randomized images.\n Position, shape, and color of objects, Material texture, Lighting condition, Random noise added to images, Position, orientation, and field of view of the camera in the simulator. Fig. 2. Images captured in the training environment are randomized. (Image source: Tobin et al, 2017) Physical dynamics in the simulator can also be randomized (Peng et al. 2018). Studies have showed that a recurrent policy can adapt to different physical dynamics including the partially observable reality. A set of physical dynamics features include but are not limited to:\n Mass and dimensions of objects, Mass and dimensions of robot bodies, Damping, kp, friction of the joints, Gains for the PID controller (P term), Joint limit, Action delay, Observation noise. With visual and dynamics DR, at OpenAI Robotics, we were able to learn a policy that works on real dexterous robot hand (OpenAI, 2018). Our manipulation task is to teach the robot hand to rotate an object continously to achieve 50 successive random target orientations. The sim2real gap in this task is very large, due to (a) a high number of simultaneous contacts between the robot and the object and (b) imperfect simulation of object collision and other motions. At first, the policy could barely survive for more than 5 seconds without dropping the object. But with the help of DR, the policy evolved to work surprisingly well in reality eventually.\n Why does Domain Randomization Work? Now you may ask, why does domain randomization work so well? The idea sounds really simple. Here are two non-exclusive explanations I found most convincing.\nDR as Optimization One idea (Vuong, et al, 2019) is to view learning randomization parameters in DR as a bilevel optimization. Assuming we have access to the real environment $e_\\text{real}$ and the randomization config is sampled from a distribution parameterized by $\\phi$, $\\xi \\sim P_\\phi(\\xi)$, we would like to learn a distribution on which a policy $\\pi_\\theta$ is trained on can achieve maximal performance in $e_\\text{real}$:\n $$ \\begin{aligned} \u0026\\phi^* = \\arg\\min_{\\phi} \\mathcal{L}(\\pi_{\\theta^*(\\phi)}; e_\\text{real}) \\\\ \\text{where } \u0026\\theta^*(\\phi) = \\arg\\min_\\theta \\mathbb{E}_{\\xi \\sim P_\\phi(\\xi)}[\\mathcal{L}(\\pi_\\theta; e_\\xi)] \\end{aligned} $$ where $\\mathcal{L}(\\pi; e)$ is the loss function of policy $\\pi$ evaluated in the environment $e$.\nAlthough randomization ranges are hand-picked in uniform DR, it often involves domain knowledge and a couple rounds of trial-and-error adjustment based on the transfer performance. Essentially this is a manual optimization process on tuning $\\phi$ for the optimal $\\mathcal{L}(\\pi_{\\theta^*(\\phi)}; e_\\text{real})$.\nGuided domain randomization in the next section is largely inspired by this view, aiming to do bilevel optimization and learn the best parameter distribution automatically.\nDR as Meta-Learning In our learning dexterity project (OpenAI, 2018), we trained an LSTM policy to generalize across different environmental dynamics. We observed that once a robot achieved the first rotation, the time it needed for the following successes was much shorter. Also, a FF policy without memory was found not able to transfer to a physical robot. Both are evidence of the policy dynamically learning and adapting to a new environment.\nIn some ways, domain randomization composes a collection of different tasks. Memory in the recurrent network empowers the policy to achieve meta-learning across tasks and further work on a real-world setting.\nGuided Domain Randomization The vanilla DR assumes no access to the real data, and thus the randomization config is sampled as broadly and uniformly as possible in sim, hoping that the real environment could be covered under this broad distribution. It is reasonable to think of a more sophisticated strategy \u0026mdash; replacing uniform sampling with guidance from task performance, real data, or simulator.\nOne motivation for guided DR is to save computation resources by avoiding training models in unrealistic environments. Another is to avoid infeasible solutions that might arise from overly wide randomization distributions and thus might hinder successful policy learning.\nOptimization for Task Performance Say we train a family of policies with different randomization parameters $\\xi \\sim P_\\phi(\\xi)$, where $P_\\xi$ is the distribution for $\\xi$ parameterized by $\\phi$. Later we decide to try every one of them on the downstream task in the target domain (i.e. control a robot in reality or evaluate on a validation set) to collect feedback. This feedback tells us how good a configuration $\\xi$ is and provides signals for optimizing $\\phi$.\nInspired by NAS, AutoAugment (Cubuk, et al. 2018) frames the problem of learning best data augmentation operations (i.e. shearing, rotation, invert, etc.) for image classification as an RL problem. Note that AutoAugment is not proposed for sim2real transfer, but falls in the bucket of DR guided by task performance. Individual augmentation configuration is tested on the evaluation set and the performance improvement is used as a reward to train a PPO policy. This policy outputs different augmentation strategies for different datasets; for example, for CIFAR-10 AutoAugment mostly picks color-based transformations, while ImageNet prefers geometric based.\nRuiz (2019) considered the task feedback as reward in RL problem and proposed a RL-based method, named \u0026ldquo;learning to simulate\u0026rdquo;, for adjusting $\\xi$. A policy is trained to predict $\\xi$ using performance metrics on the validation data of the main task as rewards, which is modeled as a multivariate Gaussian. Overall the idea is similar to AutoAugment, applying NAS on data generation. According to their experiments, even if the main task model is not converged, it still can provide a reasonable signal to the data generation policy.\nFig. 3. An overview of the \"learning to simulate\" approach. (Image source: Ruiz (2019)) Evolutionary algorithm is another way to go, where the feedback is treated as fitness for guiding evolution (Yu et al, 2019). In this study, they used CMA-ES (covariance matrix adaptation evolution strategy) while fitness is the performance of a $\\xi$-conditional policy in target environment. In the appendix, they compared CMA-ES with other ways of modeling the dynamics of $\\xi$, including Bayesian optimization or a neural network. The main claim was those methods are not as stable or sample efficient as CMA-ES. Interestly, when modeling $P(\\xi)$ as a neural network, LSTM is found to notably outperform FF.\nSome believe that sim2real gap is a combination of appearance gap and content gap; i.e. most GAN-inspired DA models focus on appearance gap. Meta-Sim (Kar, et al. 2019) aims to close the content gap by generating task-specific synthetic datasets. Meta-Sim uses self-driving car training as an example and thus the scene could be very complicated. In this case, the synthetic scenes are parameterized by a hierarchy of objects with properties (i.e., location, color) as well as relationships between objects. The hierarchy is specified by a probabilistic scene grammar akin to structure domain randomization (SDR; Prakash et al., 2018) and it is assumed to be known beforehand. A model $G$ is trained to augment the distribution of scene properties $s$ by following:\n Learn the prior first: pre-train $G$ to learn the identity function $G(s) = s$. Minimize MMD loss between the real and sim data distributions. This involves backpropagation through non-differentiable renderer. The paper computes it numerically by perturbing the attributes of $G(s)$. Minimize REINFORCE task loss when trained on synthetic data but evaluated on real data. Again, very similar to AutoAugment. Unfortunately, this family of methods are not suitable for sim2real case. Either an RL policy or an EA model requires a large number of real samples. And it is really expensive to include real-time feedback collection on a physical robot into the training loop. Whether you want to trade less computation resource for real data collection would depend on your task.\nMatch Real Data Distribution Using real data to guide domain randomization feels a lot like doing system identification or DA. The core idea behind DA is to improve the synthetic data to match the real data distribution. In the case of real-data-guided DR, we would like to learn the randomization parameters $\\xi$ that bring the state distribution in simulator close to the state distribution in the real world.\nThe SimOpt model (Chebotar et al, 2019) is trained under an initial randomization distribution $P_\\phi(\\xi)$ first, getting a policy $\\pi_{\\theta, P_\\phi}$. Then this policy is deployed on both simulator and physical robot to collect trajectories $\\tau_\\xi$ and $\\tau_\\text{real}$ respectively. The optimization objective is to minimize the discrepancy between sim and real trajectories:\n $$ \\phi^* = \\arg\\min_{\\phi}\\mathbb{E}_{\\xi \\sim P_\\phi(\\xi)} [\\mathbb{E}_{\\pi_{\\theta, P_\\phi}} [D(\\tau_\\text{sim}, \\tau_\\text{real})]] $$ where $D(.)$ is a trajectory-based discrepancy measure. Like the \u0026ldquo;Learning to simulate\u0026rdquo; paper, SimOpt also has to solve the tricky problem of how to propagate gradient through non-differentiable simulator. It used a method called relative entropy policy search, see paper for more details.\nFig. 4. An overview of the SimOpt framework. (Image source: Chebotar et al, 2019) RCAN (James et al., 2019), short for \u0026ldquo;Randomized-to-Canonical Adaptation Networks\u0026rdquo;, is a nice combination of DA and DR for end-to-end RL tasks. An image-conditional GAN (cGAN) is trained in sim to translate a domain-randomized image into a non-randomized version (aka \u0026ldquo;canonical version\u0026rdquo;). Later the same model is used to translate real images into corresponding simulated version so that the agent would consume consistent observation as what it has encountered in training. Still, the underlying assumption is that the distribution of domain-randomized sim images is broad enough to cover real-world samples.\nFig. 5. RCAN is an image-conditional generator that can convert a domain-randomized or real image into its corresponding non-randomized simulator version. (Image source: James et al., 2019) The RL model is trained end-to-end in a simulator to do vision-based robot arm grasping. Randomization is applied at each timestep, including the position of tray divider, objects to grasp, random textures, as well as the position, direction, and color of the lighting. The canonical version is the default simulator look. RCAN is trying to learn a generator\n$G$: randomized image $\\to$ {canonical image, segmentation, depth}\nwhere segmentation masks and depth images are used as auxiliary tasks. RCAN had a better zero-shot transfer compared to uniform DR, although both were shown to be worse than the model trained on only real images. Conceptually, RCAN operates in a reverse direction of GraspGAN which translates synthetic images into real ones by domain adaptation.\nGuided by Data in Simulator Network-driven domain randomization (Zakharov et al., 2019), also known as DeceptionNet, is motivated by learning which randomizations are actually useful to bridge the domain gap for image classification tasks.\nRandomization is applied through a set of deception modules with encoder-decoder architecture. The deception modules are specifically designed to transform images; such as change backgrounds, add distortion, change lightings, etc. The other recognition network handles the main task by running classification on transformed images.\nThe training involves two steps:\n With the recognition network fixed, maximize the difference between the prediction and the labels by applying reversed gradients during backpropagation. So that the deception module can learn the most confusing tricks. With the deception modules fixed, train the recognition network with input images altered. Fig. 6. How DeceptionNet works. (Image source: Zakharov et al., 2019) The feedback for training deception modules is provided by the downstream classifier. But rather than trying to maximize the task performance like the section above, the randomization modules aim to create harder cases. One big disadvantage is you need to manually design different deception modules for different datasets or tasks, making it not easily scalable. Given the fact that it is zero-shot, the results are still worse than SOTA DA methods on MNIST and LineMOD.\nSimilarly, Active domain randomization (ADR; Mehta et al., 2019) also relies on sim data to create harder training samples. ADR searches for the most informative environment variations within the given randomization ranges, where the informativeness is measured as the discrepancies of policy rollouts in randomized and reference (original, non-randomized) environment instances. Sounds a bit like SimOpt? Well, noted that SimOpt measures the discrepancy between sim and real rollouts, while ADR measures between randomized and non-randomized sim, avoiding the expensive real data collection part.\nFig. 7. How active domain randomization (ADR) works. (Image source: Mehta et al., 2019) Precisely the training happens as follows:\n Given a policy, run it on both reference and randomized envs and collect two sets of trajectories respectively. Train a discriminator model to tell whether a rollout trajectory is randomized apart from reference run. The predicted $\\log p$ (probability of being randomized) is used as reward. The more different randomized and reference rollouts, the easier the prediction, the higher the reward. The intuition is that if an environment is easy, the same policy agent can produce similar trajectories as in the reference one. Then the model should reward and explore hard environments by encouraging different behaviors. The reward by discriminator is fed into Stein Variational Policy Gradient (SVPG) particles, outputting a diverse set of randomization configurations. The idea of ADR is very appealing with two small concerns. The similarity between trajectories might not be a good way to measure the env difficulty when running a stochastic policy. The sim2real results look unfortunately not as exciting, but the paper pointed out the win being ADR explores a smaller range of randomization parameters.\n Cited as:\n@article{weng2019DR, title = \u0026quot;Domain Randomization for Sim2Real Transfer\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-05-05-domain-randomization/\u0026quot; } Overall, after reading this post, I hope you like domain randomization as much as I do :).\nReferences [1] Josh Tobin, et al. \u0026ldquo;Domain randomization for transferring deep neural networks from simulation to the real world.\u0026quot; IROS, 2017.\n[2] Fereshteh Sadeghi and Sergey Levine. \u0026ldquo;CAD2RL: Real single-image flight without a single real image.\u0026quot; arXiv:1611.04201 (2016).\n[3] Xue Bin Peng, et al. \u0026ldquo;Sim-to-real transfer of robotic control with dynamics randomization.\u0026quot; ICRA, 2018.\n[4] Nataniel Ruiz, et al. \u0026ldquo;Learning to Simulate.\u0026quot; ICLR 2019\n[5] OpenAI. \u0026ldquo;Learning Dexterous In-Hand Manipulation.\u0026quot; arXiv:1808.00177 (2018).\n[6] OpenAI Blog. \u0026ldquo;Learning dexterity\u0026rdquo; July 30, 2018.\n[7] Quan Vuong, et al. \u0026ldquo;How to pick the domain randomization parameters for sim-to-real transfer of reinforcement learning policies?.\u0026quot; arXiv:1903.11774 (2019).\n[8] Ekin D. Cubuk, et al. \u0026ldquo;AutoAugment: Learning augmentation policies from data.\u0026quot; arXiv:1805.09501 (2018).\n[9] Wenhao Yu et al. \u0026ldquo;Policy Transfer with Strategy Optimization.\u0026quot; ICLR 2019\n[10] Yevgen Chebotar et al. \u0026ldquo;Closing the Sim-to-Real Loop: Adapting Simulation Randomization with Real World Experience.\u0026quot; Arxiv: 1810.05687 (2019).\n[11] Stephen James et al. \u0026ldquo;Sim-to-real via sim-to-sim: Data-efficient robotic grasping via randomized-to-canonical adaptation networks\u0026rdquo; CVPR 2019.\n[12] Bhairav Mehta et al. \u0026ldquo;Active Domain Randomization\u0026rdquo; arXiv:1904.04762\n[13] Sergey Zakharov,et al. \u0026ldquo;DeceptionNet: Network-Driven Domain Randomization.\u0026quot; arXiv:1904.02750 (2019).\n[14] Amlan Kar, et al. \u0026ldquo;Meta-Sim: Learning to Generate Synthetic Datasets.\u0026quot; arXiv:1904.11621 (2019).\n[15] Aayush Prakash, et al. \u0026ldquo;Structured Domain Randomization: Bridging the Reality Gap by Context-Aware Synthetic Data.\u0026quot; arXiv:1810.10093 (2018).\n","permalink":"https://lilianweng.github.io/posts/2019-05-05-domain-randomization/","summary":"In Robotics, one of the hardest problems is how to make your model transfer to the real world. Due to the sample inefficiency of deep RL algorithms and the cost of data collection on real robots, we often need to train models in a simulator which theoretically provides an infinite amount of data. However, the reality gap between the simulator and the physical world often leads to failure when working with physical robots.","title":"Domain Randomization for Sim2Real Transfer"},{"content":"[Updated on 2019-05-27: add the section on Lottery Ticket Hypothesis.]\nIf you are like me, entering into the field of deep learning with experience in traditional machine learning, you may often ponder over this question: Since a typical deep neural network has so many parameters and training error can easily be perfect, it should surely suffer from substantial overfitting. How could it be ever generalized to out-of-sample data points?\nThe effort in understanding why deep neural networks can generalize somehow reminds me of this interesting paper on System Biology \u0026mdash; \u0026ldquo;Can a biologist fix a radio?\u0026quot; (Lazebnik, 2002). If a biologist intends to fix a radio machine like how she works on a biological system, life could be hard. Because the full mechanism of the radio system is not revealed, poking small local functionalities might give some hints but it can hardly present all the interactions within the system, let alone the entire working flow. No matter whether you think it is relevant to DL, it is a very fun read.\nI would like to discuss a couple of papers on generalizability and complexity measurement of deep learning models in the post. Hopefully, it could shed light on your thinking path towards the understanding of why DNN can generalize.\nClassic Theorems on Compression and Model Selection Let\u0026rsquo;s say we have a classification problem and a dataset, we can develop many models to solve it, from fitting a simple linear regression to memorizing the full dataset in disk space. Which one is better? If we only care about the accuracy over training data (especially given that testing data is likely unknown), the memorization approach seems to be the best \u0026mdash; well, it doesn\u0026rsquo;t sound right.\nThere are many classic theorems to guide us when deciding what types of properties a good model should possess in such scenarios.\nOccam\u0026rsquo;s Razor Occam\u0026rsquo;s Razor is an informal principle for problem-solving, proposed by William of Ockham in the 14th century:\n \u0026ldquo;Simpler solutions are more likely to be correct than complex ones.\u0026rdquo;\n The statement is extremely powerful when we are facing multiple candidates of underlying theories to explain the world and have to pick one. Too many unnecessary assumptions might seem to be plausible for one problem, but harder to be generalized to other complications or to eventually lead to the basic principles of the universe.\nThink of this, it took people hundreds of years to figure out that the sky is blue in the daytime but reddish at sunset are because of the same reason (Rayleigh scattering), although two phenomena look very different. People must have proposed many other explanations for them separately but the unified and simple version won eventually.\nMinimum Description Length principle The principle of Occam\u0026rsquo;s Razor can be similarly applied to machine learning models. A formalized version of such concept is called the Minimum Description Length (MDL) principle, used for comparing competing models / explanations given data observed.\n \u0026ldquo;Comprehension is compression.\u0026rdquo;\n The fundamental idea in MDL is to view learning as data compression. By compressing the data, we need to discover regularity or patterns in the data with the high potentiality to generalize to unseen samples. Information bottleneck theory believes that a deep neural network is trained first to represent the data by minimizing the generalization error and then learn to compress this representation by trimming noise.\nMeanwhile, MDL considers the model description as part of the compression delivery, so the model cannot be arbitrarily large.\nA two-part version of MDL principle states that: Let $\\mathcal{H}^{(1)}, \\mathcal{H}^{(2)}, \\dots$ be a list of models that can explain the dataset $\\mathcal{D}$. The best hypothesis among them should be the one that minimizes the sum:\n $$ \\mathcal{H}^\\text{best} = \\arg\\min_\\mathcal{H} [L(\\mathcal{H}) + L(\\mathcal{D}\\vert\\mathcal{H})] $$ $L(\\mathcal{H})$ is the length of the description of model $\\mathcal{H}$ in bits. $L(\\mathcal{D}\\vert\\mathcal{H})$ is the length of the description of the data $\\mathcal{D}$ in bits when encoded with $\\mathcal{H}$. In simple words, the best model is the smallest model containing the encoded data and the model itself. Following this criterion, the memorization approach I proposed at the beginning of the section sounds horrible no matter how good accuracy it can achieve on the training data.\nPeople might argue Occam\u0026rsquo;s Razor is wrong, as given the real world can be arbitrarily complicated, why do we have to find simple models? One interesting view by MDL is to consider models as \u0026ldquo;languages\u0026rdquo; instead of fundamental generative theorems. We would like to find good compression strategies to describe regularity in a small set of samples, and they do not have to be the \u0026ldquo;real\u0026rdquo; generative model for explaining the phenomenon. Models can be wrong but still useful (i.e., think of any Bayesian prior).\nKolmogorov Complexity Kolmogorov Complexity relies on the concept of modern computers to define the algorithmic (descriptive) complexity of an object: It is the length of the shortest binary computer program that describes the object. Following MDL, a computer is essentially the most general form of data decompressor.\nThe formal definition of Kolmogorov Complexity states that: Given a universal computer $\\mathcal{U}$ and a program $p$, let\u0026rsquo;s denote $\\mathcal{U}(p)$ as the output of the computer processing the program and $L(p)$ as the descriptive length of the program. Then Kolmogorov Complexity $K_\\mathcal{U}$ of a string $s$ with respect to a universal computer $\\mathcal{U}$ is:\n $$ K_\\mathcal{U}(s) = \\min_{p: \\mathcal{U}(p)=s} L(p) $$ Note that a universal computer is one that can mimic the actions of any other computers. All modern computers are universal as they can all be reduced to Turing machines. The definition is universal no matter which computers we are using, because another universal computer can always be programmed to clone the behavior of $\\mathcal{U}$, while encoding this clone program is just a constant.\nThere are a lot of connections between Kolmogorov Complexity and Shannon Information Theory, as both are tied to universal coding. It is an amazing fact that the expected Kolmogorov Complexity of a random variable is approximately equal to its Shannon entropy (see Sec 2.3 of the report). More on this topic is out of the scope here, but there are many interesting readings online. Help yourself :)\nSolomonoff\u0026rsquo;s Inference Theory Another mathematical formalization of Occam\u0026rsquo;s Razor is Solomonoff\u0026rsquo;s theory of universal inductive inference (Solomonoff, 1964). The principle is to favor models that correspond to the \u0026ldquo;shortest program\u0026rdquo; to produce the training data, based on its Kolmogorov complexity\nExpressive Power of DL Models Deep neural networks have an extremely large number of parameters compared to the traditional statistical models. If we use MDL to measure the complexity of a deep neural network and consider the number of parameters as the model description length, it would look awful. The model description $L(\\mathcal{H})$ can easily grow out of control.\nHowever, having numerous parameters is necessary for a neural network to obtain high expressivity power. Because of its great capability to capture any flexible data representation, deep neural networks have achieved great success in many applications.\nUniversal Approximation Theorem The Universal Approximation Theorem states that a feedforward network with: 1) a linear output layer, 2) at least one hidden layer containing a finite number of neurons and 3) some activation function can approximate any continuous functions on a compact subset of $\\mathbb{R}^n$ to arbitrary accuracy. The theorem was first proved for sigmoid activation function (Cybenko, 1989). Later it was shown that the universal approximation property is not specific to the choice of activation (Hornik, 1991) but the multilayer feedforward architecture.\nAlthough a feedforward network with a single layer is sufficient to represent any function, the width has to be exponentially large. The universal approximation theorem does not guarantee whether the model can be learned or generalized properly. Often, adding more layers helps to reduce the number of hidden neurons needed in a shallow network.\nTo take advantage of the universal approximation theorem, we can always find a neural network to represent the target function with error under any desired threshold, but we need to pay the price \u0026mdash; the network might grow super large.\nProof: Finite Sample Expressivity of Two-layer NN The Universal Approximation Theorem we have discussed so far does not consider a finite sample set. Zhang, et al. (2017) provided a neat proof on the finite-sample expressivity of two-layer neural networks.\nA neural network $C$ can represent any function given a sample size $n$ in $d$ dimensions if: For every finite sample set $S \\subseteq \\mathbb{R}^d$ with $\\vert S \\vert = n$ and every function defined on this sample set: $f: S \\mapsto \\mathbb{R}$, we can find a set of weight configuration for $C$ so that $C(\\boldsymbol{x}) = f(\\boldsymbol{x}), \\forall \\boldsymbol{x} \\in S$.\nThe paper proposed a theorem:\n There exists a two-layer neural network with ReLU activations and $2n + d$ weights that can represent any function on a sample of size $n$ in $d$ dimensions.\n Proof. First we would like to construct a two-layer neural network $C: \\mathbb{R}^d \\mapsto \\mathbb{R}$. The input is a $d$-dimensional vector, $\\boldsymbol{x} \\in \\mathbb{R}^d$. The hidden layer has $h$ hidden units, associated with a weight matrix $\\mathbf{W} \\in \\mathbb{R}^{d\\times h}$, a bias vector $-\\mathbf{b} \\in \\mathbb{R}^h$ and ReLU activation function. The second layer outputs a scalar value with weight vector $\\boldsymbol{v} \\in \\mathbb{R}^h$ and zero biases.\nThe output of network $C$ for a input vector $\\boldsymbol{x}$ can be represented as follows:\n $$ C(\\boldsymbol{x}) = \\boldsymbol{v} \\max\\{ \\boldsymbol{x}\\mathbf{W} - \\boldsymbol{b}, 0\\}^\\top = \\sum_{i=1}^h v_i \\max\\{\\boldsymbol{x}\\boldsymbol{W}_{(:,i)} - b_i, 0\\} $$ where $\\boldsymbol{W}_{(:,i)}$ is the $i$-th column in the $d \\times h$ matrix.\nGiven a sample set $S = \\{\\boldsymbol{x}_1, \\dots, \\boldsymbol{x}_n\\}$ and target values $\\boldsymbol{y} = \\{y_1, \\dots, y_n \\}$, we would like to find proper weights $\\mathbf{W} \\in \\mathbb{R}^{d\\times h}$, $\\boldsymbol{b}, \\boldsymbol{v} \\in \\mathbb{R}^h$ so that $C(\\boldsymbol{x}_i) = y_i, \\forall i=1,\\dots,n$.\nLet\u0026rsquo;s combine all sample points into one batch as one input matrix $\\mathbf{X} \\in \\mathbb{R}^{n \\times d}$. If set $h=n$, $\\mathbf{X}\\mathbf{W} - \\boldsymbol{b}$ would be a square matrix of size $n \\times n$.\n $$ \\mathbf{M}_\\text{ReLU} = \\max\\{\\mathbf{X}\\mathbf{W} - \\boldsymbol{b}, 0 \\} = \\begin{bmatrix} \\boldsymbol{x}_1\\mathbf{W} - \\boldsymbol{b} \\\\ \\dots \\\\ \\boldsymbol{x}_n\\mathbf{W} - \\boldsymbol{b} \\\\ \\end{bmatrix} = [\\boldsymbol{x}_i\\boldsymbol{W}_{(:,j)} - b_j]_{i \\times j} $$ We can simplify $\\mathbf{W}$ to have the same column vectors across all the columns:\n $$ \\mathbf{W}_{(:,j)} = \\boldsymbol{w} \\in \\mathbb{R}^{d}, \\forall j = 1, \\dots, n $$ Let $a_i = \\boldsymbol{x}_i \\boldsymbol{w}$, we would like to find a suitable $\\boldsymbol{w}$ and $\\boldsymbol{b}$ such that $b_1 \u0026lt; a_1 \u0026lt; b_2 \u0026lt; a_2 \u0026lt; \\dots \u0026lt; b_n \u0026lt; a_n$. This is always achievable because we try to solve $n+d$ unknown variables with $n$ constraints and $\\boldsymbol{x}_i$ are independent (i.e. pick a random $\\boldsymbol{w}$, sort $\\boldsymbol{x}_i \\boldsymbol{w}$ and then set $b_j$\u0026rsquo;s as values in between). Then $\\mathbf{M}_\\text{ReLU}$ becomes a lower triangular matrix:\n $$ \\mathbf{M}_\\text{ReLU} = [a_i - b_j]_{i \\times j} = \\begin{bmatrix} a_1 - b_1 \u0026 0 \u0026 0 \u0026 \\dots \u0026 0 \\\\ \\vdots \u0026 \\ddots \u0026 \u0026 \u0026 \\vdots \\\\ a_i - b_1 \u0026 \\dots \u0026 a_i - b_i \u0026 \\dots \u0026 0\\\\ \\vdots \u0026 \u0026 \u0026 \\ddots \u0026 \\vdots \\\\ a_n - b_1 \u0026 a_n - b_2 \u0026 \\dots \u0026 \\dots \u0026 a_n - b_n \\\\ \\end{bmatrix} $$ It is a nonsingular square matrix as $\\det(\\mathbf{M}_\\text{ReLU}) \\neq 0$, so we can always find suitable $\\boldsymbol{v}$ to solve $\\boldsymbol{v}\\mathbf{M}_\\text{ReLU}=\\boldsymbol{y}$ (In other words, the column space of $\\mathbf{M}_\\text{ReLU}$ is all of $\\mathbb{R}^n$ and we can find a linear combination of column vectors to obtain any $\\boldsymbol{y}$).\nDeep NN can Learn Random Noise As we know two-layer neural networks are universal approximators, it is less surprising to see that they are able to learn unstructured random noise perfectly, as shown in Zhang, et al. (2017). If labels of image classification dataset are randomly shuffled, the high expressivity power of deep neural networks can still empower them to achieve near-zero training loss. These results do not change with regularization terms added.\nFig. 1. Fit models on CIFAR10 with random labels or random pixels: (a) learning curves; (b-c) label corruption ratio is the percentage of randomly shuffled labels. (Image source: Zhang et al. 2017) Are Deep Learning Models Dramatically Overfitted? Deep learning models are heavily over-parameterized and can often get to perfect results on training data. In the traditional view, like bias-variance trade-offs, this could be a disaster that nothing may generalize to the unseen test data. However, as is often the case, such \u0026ldquo;overfitted\u0026rdquo; (training error = 0) deep learning models still present a decent performance on out-of-sample test data. Hmm … interesting and why?\nModern Risk Curve for Deep Learning The traditional machine learning uses the following U-shape risk curve to measure the bias-variance trade-offs and quantify how generalizable a model is. If I get asked how to tell whether a model is overfitted, this would be the first thing popping into my mind.\nAs the model turns larger (more parameters added), the training error decreases to close to zero, but the test error (generalization error) starts to increase once the model complexity grows to pass the threshold between \u0026ldquo;underfitting\u0026rdquo; and \u0026ldquo;overfitting\u0026rdquo;. In a way, this is well aligned with Occam\u0026rsquo;s Razor.\nFig. 2. U-shaped bias-variance risk curve. (Image source: (left) paper (right) fig. 6 of this post) Unfortunately this does not apply to deep learning models. Belkin et al. (2018) reconciled the traditional bias-variance trade-offs and proposed a new double-U-shaped risk curve for deep neural networks. Once the number of network parameters is high enough, the risk curve enters another regime.\nFig. 3. A new double-U-shaped bias-variance risk curve for deep neural networks. (Image source: original paper) The paper claimed that it is likely due to two reasons:\n The number of parameters is not a good measure of inductive bias, defined as the set of assumptions of a learning algorithm used to predict for unknown samples. See more discussion on DL model complexity in later sections. Equipped with a larger model, we might be able to discover larger function classes and further find interpolating functions that have smaller norm and are thus \u0026ldquo;simpler\u0026rdquo;. The double-U-shaped risk curve was observed empirically, as shown in the paper. However I was struggling quite a bit to reproduce the results. There are some signs of life, but in order to generate a pretty smooth curve similar to the theorem, many details in the experiment have to be taken care of.\nFig. 4. Training and evaluation errors of a one hidden layer fc network of different numbers of hidden units, trained on 4000 data points sampled from MNIST. (Image source: original paper) Regularization is not the Key to Generalization Regularization is a common way to control overfitting and improve model generalization performance. Interestingly some research (Zhang, et al. 2017) has shown that explicit regularization (i.e. data augmentation, weight decay and dropout) is neither necessary or sufficient for reducing generalization error.\nTaking the Inception model trained on CIFAR10 as an example (see Fig. 5), regularization techniques help with out-of-sample generalization but not much. No single regularization seems to be critical independent of other terms. Thus, it is unlikely that regularizers are the fundamental reason for generalization.\nFig. 5. The accuracy of Inception model trained on CIFAR10 with different combinations of taking on or off data augmentation and weight decay. (Image source: Table 1 in the original paper) Intrinsic Dimension The number of parameters is not correlated with model overfitting in the field of deep learning, suggesting that parameter counting cannot indicate the true complexity of deep neural networks.\nApart from parameter counting, researchers have proposed many ways to quantify the complexity of these models, such as the number of degrees of freedom of models (Gao \u0026amp; Jojic, 2016), or prequential code (Blier \u0026amp; Ollivier, 2018).\nI would like to discuss a recent method on this matter, named intrinsic dimension (Li et al, 2018). Intrinsic dimension is intuitive, easy to measure, while still revealing many interesting properties of models of different sizes.\nConsidering a neural network with a great number of parameters, forming a high-dimensional parameter space, the learning happens on this high-dimensional objective landscape. The shape of the parameter space manifold is critical. For example, a smoother manifold is beneficial for optimization by providing more predictive gradients and allowing for larger learning rates\u0026mdash;this was claimed to be the reason why batch normalization has succeeded in stabilizing training (Santurkar, et al, 2019).\nEven though the parameter space is huge, fortunately we don\u0026rsquo;t have to worry too much about the optimization process getting stuck in local optima, as it has been shown that local optimal points in the objective landscape almost always lay in saddle-points rather than valleys. In other words, there is always a subset of dimensions containing paths to leave local optima and keep on exploring.\nFig. 6. Illustrations of various types of critical points on the parameter optimization landscape. (Image source: here) One intuition behind the measurement of intrinsic dimension is that, since the parameter space has such high dimensionality, it is probably not necessary to exploit all the dimensions to learn efficiently. If we only travel through a slice of objective landscape and still can learn a good solution, the complexity of the resulting model is likely lower than what it appears to be by parameter-counting. This is essentially what intrinsic dimension tries to assess.\nSay a model has $D$ dimensions and its parameters are denoted as $\\theta^{(D)}$. For learning, a smaller $d$-dimensional subspace is randomly sampled, $\\theta^{(d)}$, where $d \u0026lt; D$. During one optimization update, rather than taking a gradient step according to all $D$ dimensions, only the smaller subspace $\\theta^{(d)}$ is used and remapped to update model parameters.\nFig. 7. Illustration of parameter vectors for direct optimization when $D=3$. (Image source: original paper) The gradient update formula looks like the follows:\n $$ \\theta^{(D)} = \\theta_0^{(D)} + \\mathbf{P} \\theta^{(d)} $$ where $\\theta_0^{(D)}$ are the initialization values and $\\mathbf{P}$ is a $D \\times d$ projection matrix that is randomly sampled before training. Both $\\theta_0^{(D)}$ and $\\mathbf{P}$ are not trainable and fixed during training. $\\theta^{(d)}$ is initialized as all zeros.\nBy searching through the value of $d = 1, 2, \\dots, D$, the corresponding $d$ when the solution emerges is defined as the intrinsic dimension.\nIt turns out many problems have much smaller intrinsic dimensions than the number of parameters. For example, on CIFAR10 image classification, a fully-connected network with 650k+ parameters has only 9k intrinsic dimension and a convolutional network containing 62k parameters has an even lower intrinsic dimension of 2.9k.\nFig. 8. The measured intrinsic dimensions $d$ for various models achieving 90% of the best performance. (Image source: original paper) The measurement of intrinsic dimensions suggests that deep learning models are significantly simpler than what they might appear to be.\nHeterogeneous Layer Robustness Zhang et al. (2019) investigated the role of parameters in different layers. The fundamental question raised by the paper is: \u0026ldquo;are all layers created equal?\u0026quot; The short answer is: No. The model is more sensitive to changes in some layers but not others.\nThe paper proposed two types of operations that can be applied to parameters of the $\\ell$-th layer, $\\ell = 1, \\dots, L$, at time $t$, $\\theta^{(\\ell)}_t$ to test their impacts on model robustness:\n Re-initialization: Reset the parameters to the initial values, $\\theta^{(\\ell)}_t \\leftarrow \\theta^{(\\ell)}_0$. The performance of a network in which layer $\\ell$ was re-initialized is referred to as the re-initialization robustness of layer $\\ell$.\n Re-randomization: Re-sampling the layer\u0026rsquo;s parameters at random, $\\theta^{(\\ell)}_t \\leftarrow \\tilde{\\theta}^{(\\ell)} \\sim \\mathcal{P}^{(\\ell)}$. The corresponding network performance is called the re-randomization robustness of layer $\\ell$.\n Layers can be categorized into two categories with the help of these two operations:\n Robust Layers: The network has no or only negligible performance degradation after re-initializing or re-randomizing the layer. Critical Layers: Otherwise. Similar patterns are observed on fully-connected and convolutional networks. Re-randomizing any of the layers completely destroys the model performance, as the prediction drops to random guessing immediately. More interestingly and surprisingly, when applying re-initialization, only the first or the first few layers (those closest to the input layer) are critical, while re-initializing higher levels causes only negligible decrease in performance.\nFig. 9. (a) A fc network trained on MNIST. Each row corresponds to one layer in the network. The first column is re-randomization robustness of each layer and the rest of the columns indicate re-initialization robustness at different training time. (b) VGG11 model (conv net) trained on CIFAR 10. Similar representation as in (a) but rows and columns are transposed. (Image source: original paper) ResNet is able to use shortcuts between non-adjacent layers to re-distribute the sensitive layers across the networks rather than just at the bottom. With the help of residual block architecture, the network can evenly be robust to re-randomization. Only the first layer of each residual block is still sensitive to both re-initialization and re-randomization. If we consider each residual block as a local sub-network, the robustness pattern resembles the fc and conv nets above.\nFig. 10. Re-randomization (first row) and re-initialization (the reset rows) robustness of layers in ResNet-50 model trained on CIFAR10. (Image source: original paper) Based on the fact that many top layers in deep neural networks are not critical to the model performance after re-initialization, the paper loosely concluded that:\n \u0026ldquo;Over-capacitated deep networks trained with stochastic gradient have low-complexity due to self-restricting the number of critical layers.\u0026rdquo;\n We can consider re-initialization as a way to reduce the effective number of parameters, and thus the observation is aligned with what intrinsic dimension has demonstrated.\nThe Lottery Ticket Hypothesis The lottery ticket hypothesis (Frankle \u0026amp; Carbin, 2019) is another intriguing and inspiring discovery, supporting that only a subset of network parameters have impact on the model performance and thus the network is not overfitted. The lottery ticket hypothesis states that a randomly initialized, dense, feed-forward network contains a pool of subnetworks and among them only a subset are \u0026ldquo;winning tickets\u0026rdquo; which can achieve the optimal performance when trained in isolation.\nThe idea is motivated by network pruning techniques \u0026mdash; removing unnecessary weights (i.e. tiny weights that are almost negligible) without harming the model performance. Although the final network size can be reduced dramatically, it is hard to train such a pruned network architecture successfully from scratch. It feels like in order to successfully train a neural network, we need a large number of parameters, but we don\u0026rsquo;t need that many parameters to keep the accuracy high once the model is trained. Why is that?\nThe lottery ticket hypothesis did the following experiments:\n Randomly initialize a dense feed-forward network with initialization values $\\theta_0$; Train the network for multiple iterations to achieve a good performance with parameter config $\\theta$; Run pruning on $\\theta$ and creating a mask $m$. The \u0026ldquo;winning ticket\u0026rdquo; initialization config is $m \\odot \\theta_0$. Only training the small \u0026ldquo;winning ticket\u0026rdquo; subset of parameters with the initial values as found in step 1, the model is able to achieve the same level of accuracy as in step 2. It turns out a large parameter space is not needed in the final solution representation, but needed for training as it provides a big pool of initialization configs of many much smaller subnetworks.\nThe lottery ticket hypothesis opens a new perspective about interpreting and dissecting deep neural network results. Many interesting following-up works are on the way.\nExperiments After seeing all the interesting findings above, it should be pretty fun to reproduce them. Some results are easily to reproduce than others. Details are described below. My code is available on github lilianweng/generalization-experiment.\nNew Risk Curve for DL Models\nThis is the trickiest one to reproduce. The authors did give me a lot of good advice and I appreciate it a lot. Here are a couple of noticeable settings in their experiments:\n There are no regularization terms like weight decay, dropout. In Fig 3, the training set contains 4k samples. It is only sampled once and fixed for all the models. The evaluation uses the full MNIST test set. Each network is trained for a long time to achieve near-zero training risk. The learning rate is adjusted differently for models of different sizes. To make the model less sensitive to the initialization in the under-parameterization region, their experiments adopted a \u0026ldquo;weight reuse\u0026rdquo; scheme: the parameters obtained from training a smaller neural network are used as initialization for training larger networks. I did not train or tune each model long enough to get perfect training performance, but evaluation error indeed shows a special twist around the interpolation threshold, different from training error. For example, for MNIST, the threshold is the number of training samples times the number of classes (10), that is 40000.\nThe x-axis is the number of model parameters: (28 * 28 + 1) * num. units + num. units * 10, in logarithm.\nLayers are not Created Equal\nThis one is fairly easy to reproduce. See my implementation here.\nIn the first experiment, I used a three-layer fc networks with 256 units in each layer. Layer 0 is the input layer while layer 3 is the output. The network is trained on MNIST for 100 epochs.\nIn the second experiment, I used a four-layer fc networks with 128 units in each layer. Other settings are the same as experiment 1.\nIntrinsic Dimension Measurement\nTo correctly map the $d$-dimensional subspace to the full parameter space, the projection matrix $\\mathbf{P}$ should have orthogonal columns. Because the production $\\mathbf{P}\\theta^{(d)}$ is the sum of columns of $\\mathbf{P}$ scaled by corresponding scalar values in the $d$-dim vector, $\\sum_{i=1}^d \\theta^{(d)}_i \\mathbf{P}^\\top_{(:,i)}$, it is better to fully utilize the subspace with orthogonal columns in $\\mathbf{P}$.\nMy implementation follows a naive approach by sampling a large matrix with independent entries from a standard normal distribution. The columns are expected to be independent in a high dimension space and thus to be orthogonal. This works when the dimension is not too large. When exploring with a large $d$, there are methods for creating sparse projection matrices, which is what the intrinsic dimension paper suggested.\nHere are experiment runs on two networks: (left) a two-layer fc network with 64 units in each layer and (right) a one-layer fc network with 128 hidden units, trained on 10% of MNIST. For every $d$, the model is trained for 100 epochs. See the code here.\n Cited as:\n@article{weng2019overfit, title = \u0026quot;Are Deep Neural Networks Dramatically Overfitted?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-03-14-overfit/\u0026quot; } References [1] Wikipedia page on Occam\u0026rsquo;s Razor.\n[2] Occam\u0026rsquo;s Razor on Principia Cybernetica Web.\n[3] Peter Grunwald. \u0026ldquo;A Tutorial Introduction to the Minimum Description Length Principle\u0026rdquo;. 2004.\n[4] Ian Goodfellow, et al. Deep Learning. 2016. Sec 6.4.1.\n[5] Zhang, Chiyuan, et al. \u0026ldquo;Understanding deep learning requires rethinking generalization.\u0026quot; ICLR 2017.\n[6] Shibani Santurkar, et al. \u0026ldquo;How does batch normalization help optimization?.\u0026quot; NIPS 2018.\n[7] Mikhail Belkin, et al. \u0026ldquo;Reconciling modern machine learning and the bias-variance trade-off.\u0026quot; arXiv:1812.11118, 2018.\n[8] Chiyuan Zhang, et al. \u0026ldquo;Are All Layers Created Equal?\u0026quot; arXiv:1902.01996, 2019.\n[9] Chunyuan Li, et al. \u0026ldquo;Measuring the intrinsic dimension of objective landscapes.\u0026quot; ICLR 2018.\n[10] Jonathan Frankle and Michael Carbin. \u0026ldquo;The lottery ticket hypothesis: Finding sparse, trainable neural networks.\u0026quot; ICLR 2019.\n","permalink":"https://lilianweng.github.io/posts/2019-03-14-overfit/","summary":"[Updated on 2019-05-27: add the section on Lottery Ticket Hypothesis.]\nIf you are like me, entering into the field of deep learning with experience in traditional machine learning, you may often ponder over this question: Since a typical deep neural network has so many parameters and training error can easily be perfect, it should surely suffer from substantial overfitting. How could it be ever generalized to out-of-sample data points?\nThe effort in understanding why deep neural networks can generalize somehow reminds me of this interesting paper on System Biology \u0026mdash; \u0026ldquo;Can a biologist fix a radio?","title":"Are Deep Neural Networks Dramatically Overfitted?"},{"content":"[Updated on 2019-02-14: add ULMFiT and GPT-2.] [Updated on 2020-02-29: add ALBERT.] [Updated on 2020-10-25: add RoBERTa.] [Updated on 2020-12-13: add T5.] [Updated on 2020-12-30: add GPT-3.] [Updated on 2021-11-13: add XLNet, BART and ELECTRA; Also updated the Summary section.]\nFig. 0. I guess they are Elmo \u0026 Bert? (Image source: here) We have seen amazing progress in NLP in 2018. Large-scale pre-trained language modes like OpenAI GPT and BERT have achieved great performance on a variety of language tasks using generic model architectures. The idea is similar to how ImageNet classification pre-training helps many vision tasks (*). Even better than vision classification pre-training, this simple and powerful approach in NLP does not require labeled data for pre-training, allowing us to experiment with increased training scale, up to our very limit.\n(*) He et al. (2018) found that pre-training might not be necessary for image segmentation task.\nIn my previous NLP post on word embedding, the introduced embeddings are not context-specific \u0026mdash; they are learned based on word concurrency but not sequential context. So in two sentences, \u0026ldquo;I am eating an apple\u0026rdquo; and \u0026ldquo;I have an Apple phone\u0026rdquo;, two \u0026ldquo;apple\u0026rdquo; words refer to very different things but they would still share the same word embedding vector.\nDespite this, early adoption of word embeddings in problem-solving is to use them as additional features for an existing task-specific model and in a way the improvement is bounded.\nIn this post, we will discuss how various approaches were proposed to make embeddings dependent on context, and to make them easier and cheaper to be applied to downstream tasks in general form.\nCoVe CoVe (McCann et al. 2017), short for Contextual Word Vectors, is a type of word embeddings learned by an encoder in an attentional seq-to-seq machine translation model. Different from traditional word embeddings introduced here, CoVe word representations are functions of the entire input sentence.\nNMT Recap Here the Neural Machine Translation (NMT) model is composed of a standard, two-layer, bidirectional LSTM encoder and an attentional two-layer unidirectional LSTM decoder. It is pre-trained on the English-German translation task. The encoder learns and optimizes the embedding vectors of English words in order to translate them to German. With the intuition that the encoder should capture high-level semantic and syntactic meanings before transforming words into another language, the encoder output is used to provide contextualized word embeddings for various downstream language tasks.\nFig. 1. The NMT base model used in CoVe. A sequence of $n$ words in source language (English): $x = [x_1, \\dots, x_n]$. A sequence of $m$ words in target language (German): $y = [y_1, \\dots, y_m]$. The GloVe vectors of source words: $\\text{GloVe}(x)$. Randomly initialized embedding vectors of target words: $z = [z_1, \\dots, z_m]$. The biLSTM encoder outputs a sequence of hidden states: $h = [h_1, \\dots, h_n] = \\text{biLSTM}(\\text{GloVe}(x))$ and $h_t = [\\overrightarrow{h}_t; \\overleftarrow{h}_t]$ where the forward LSTM computes $\\overrightarrow{h}_t = \\text{LSTM}(x_t, \\overrightarrow{h}_{t-1})$ and the backward computation gives us $\\overleftarrow{h}_t = \\text{LSTM}(x_t, \\overleftarrow{h}_{t-1})$. The attentional decoder outputs a distribution over words: $p(y_t \\mid H, y_1, \\dots, y_{t-1})$ where $H$ is a stack of hidden states $\\{h\\}$ along the time dimension: $$ \\begin{aligned} \\text{decoder hidden state: } s_t \u0026= \\text{LSTM}([z_{t-1}; \\tilde{h}_{t-1}], s_{t-1}) \\\\ \\text{attention weights: } \\alpha_t \u0026= \\text{softmax}(H(W_1 s_t + b_1)) \\\\ \\text{context-adjusted hidden state: } \\tilde{h}_t \u0026= \\tanh(W_2[H^\\top\\alpha_t;s_t] + b_2) \\\\ \\text{decoder output: } p(y_t\\mid H, y_1, \\dots, y_{t-1}) \u0026= \\text{softmax}(W_\\text{out} \\tilde{h}_t + b_\\text{out}) \\end{aligned} $$ Use CoVe in Downstream Tasks The hidden states of NMT encoder are defined as context vectors for other language tasks:\n $$ \\text{CoVe}(x) = \\text{biLSTM}(\\text{GloVe}(x)) $$ The paper proposed to use the concatenation of GloVe and CoVe for question-answering and classification tasks. GloVe learns from the ratios of global word co-occurrences, so it has no sentence context, while CoVe is generated by processing text sequences is able to capture the contextual information.\n $$ v = [\\text{GloVe}(x); \\text{CoVe}(x)] $$ Given a downstream task, we first generate the concatenation of GloVe + CoVe vectors of input words and then feed them into the task-specific models as additional features.\nFig. 2. The CoVe embeddings are generated by an encoder trained for machine translation task. The encoder can be plugged into any downstream task-specific model. (Image source: original paper) Summary: The limitation of CoVe is obvious: (1) pre-training is bounded by available datasets on the supervised translation task; (2) the contribution of CoVe to the final performance is constrained by the task-specific model architecture.\nIn the following sections, we will see that ELMo overcomes issue (1) by unsupervised pre-training and OpenAI GPT \u0026amp; BERT further overcome both problems by unsupervised pre-training + using generative model architecture for different downstream tasks.\nELMo ELMo, short for Embeddings from Language Model (Peters, et al, 2018) learns contextualized word representation by pre-training a language model in an unsupervised way.\nBidirectional Language Model The bidirectional Language Model (biLM) is the foundation for ELMo. While the input is a sequence of $n$ tokens, $(x_1, \\dots, x_n)$, the language model learns to predict the probability of next token given the history.\nIn the forward pass, the history contains words before the target token,\n $$ p(x_1, \\dots, x_n) = \\prod_{i=1}^n p(x_i \\mid x_1, \\dots, x_{i-1}) $$ In the backward pass, the history contains words after the target token,\n $$ p(x_1, \\dots, x_n) = \\prod_{i=1}^n p(x_i \\mid x_{i+1}, \\dots, x_n) $$ The predictions in both directions are modeled by multi-layer LSTMs with hidden states $\\overrightarrow{\\mathbf{h}}_{i,\\ell}$ and $\\overleftarrow{\\mathbf{h}}_{i,\\ell}$ for input token $x_i$ at the layer level $\\ell=1,\\dots,L$. The final layer’s hidden state $\\mathbf{h}_{i,L} = [\\overrightarrow{\\mathbf{h}}_{i,L}; \\overleftarrow{\\mathbf{h}}_{i,L}]$ is used to output the probabilities over tokens after softmax normalization. They share the embedding layer and the softmax layer, parameterized by $\\Theta_e$ and $\\Theta_s$ respectively.\nFig. 3. The biLSTM base model of ELMo. (Image source: recreated based on the figure in [\"Neural Networks, Types, and Functional Programming\"](http://colah.github.io/posts/2015-09-NN-Types-FP/) by Christopher Olah.) The model is trained to minimize the negative log likelihood (= maximize the log likelihood for true words) in both directions:\n $$ \\begin{aligned} \\mathcal{L} = - \\sum_{i=1}^n \\Big( \\log p(x_i \\mid x_1, \\dots, x_{i-1}; \\Theta_e, \\overrightarrow{\\Theta}_\\text{LSTM}, \\Theta_s) + \\\\ \\log p(x_i \\mid x_{i+1}, \\dots, x_n; \\Theta_e, \\overleftarrow{\\Theta}_\\text{LSTM}, \\Theta_s) \\Big) \\end{aligned} $$ ELMo Representations On top of a $L$-layer biLM, ELMo stacks all the hidden states across layers together by learning a task-specific linear combination. The hidden state representation for the token $x_i$ contains $2L+1$ vectors:\n $$ R_i = \\{ \\mathbf{h}_{i,\\ell} \\mid \\ell = 0, \\dots, L \\} $$ where $\\mathbf{h}_{0, \\ell}$ is the embedding layer output and $\\mathbf{h}_{i, \\ell} = [\\overrightarrow{\\mathbf{h}}_{i,\\ell}; \\overleftarrow{\\mathbf{h}}_{i,\\ell}]$.\nThe weights, $\\mathbf{s}^\\text{task}$, in the linear combination are learned for each end task and normalized by softmax. The scaling factor $\\gamma^\\text{task}$ is used to correct the misalignment between the distribution of biLM hidden states and the distribution of task specific representations.\n $$ v_i = f(R_i; \\Theta^\\text{task}) = \\gamma^\\text{task} \\sum_{\\ell=0}^L s^\\text{task}_i \\mathbf{h}_{i,\\ell} $$ To evaluate what kind of information is captured by hidden states across different layers, ELMo is applied on semantic-intensive and syntax-intensive tasks respectively using representations in different layers of biLM:\n Semantic task: The word sense disambiguation (WSD) task emphasizes the meaning of a word given a context. The biLM top layer is better at this task than the first layer. Syntax task: The part-of-speech (POS) tagging task aims to infer the grammatical role of a word in one sentence. A higher accuracy can be achieved by using the biLM first layer than the top layer. The comparison study indicates that syntactic information is better represented at lower layers while semantic information is captured by higher layers. Because different layers tend to carry different type of information, stacking them together helps.\nUse ELMo in Downstream Tasks Similar to how CoVe can help different downstream tasks, ELMo embedding vectors are included in the input or lower levels of task-specific models. Moreover, for some tasks (i.e., SNLI and SQuAD, but not SRL), adding them into the output level helps too.\nThe improvements brought up by ELMo are largest for tasks with a small supervised dataset. With ELMo, we can also achieve similar performance with much less labeled data.\nSummary: The language model pre-training is unsupervised and theoretically the pre-training can be scaled up as much as possible since the unlabeled text corpora are abundant. However, it still has the dependency on task-customized models and thus the improvement is only incremental, while searching for a good model architecture for every task remains non-trivial.\nCross-View Training In ELMo the unsupervised pre-training and task-specific learning happen for two independent models in two separate training stages. Cross-View Training (abbr. CVT; Clark et al., 2018) combines them into one unified semi-supervised learning procedure where the representation of a biLSTM encoder is improved by both supervised learning with labeled data and unsupervised learning with unlabeled data on auxiliary tasks.\nModel Architecture The model consists of a two-layer bidirectional LSTM encoder and a primary prediction module. During training, the model is fed with labeled and unlabeled data batches alternatively.\n On labeled examples, all the model parameters are updated by standard supervised learning. The loss is the standard cross entropy. On unlabeled examples, the primary prediction module still can produce a \u0026ldquo;soft\u0026rdquo; target, even though we cannot know exactly how accurate they are. In a couple of auxiliary tasks, the predictor only sees and processes a restricted view of the input, such as only using encoder hidden state representation in one direction. The auxiliary task outputs are expected to match the primary prediction target for a full view of input. In this way, the encoder is forced to distill the knowledge of the full context into partial representation. At this stage, the biLSTM encoder is backpropagated but the primary prediction module is fixed. The loss is to minimize the distance between auxiliary and primary predictions. Fig. 4. The overview of semi-supervised language model cross-view training. (Image source: original paper) Multi-Task Learning When training for multiple tasks simultaneously, CVT adds several extra primary prediction models for additional tasks. They all share the same sentence representation encoder. During supervised training, once one task is randomly selected, parameters in its corresponding predictor and the representation encoder are updated. With unlabeled data samples, the encoder is optimized jointly across all the tasks by minimizing the differences between auxiliary outputs and primary prediction for every task.\nThe multi-task learning encourages better generality of representation and in the meantime produces a nice side-product: all-tasks-labeled examples from unlabeled data. They are precious data labels considering that cross-task labels are useful but fairly rare.\nUse CVT in Downstream Tasks Theoretically the primary prediction module can take any form, generic or task-specific design. The examples presented in the CVT paper include both cases.\nIn sequential tagging tasks (classification for every token) like NER or POS tagging, the predictor module contains two fully connected layers and a softmax layer on the output to produce a probability distribution over class labels. For each token $\\mathbf{x}_i$, we take the corresponding hidden states in two layers, $\\mathbf{h}_1^{(i)}$ and $\\mathbf{h}_2^{(i)}$:\n $$ \\begin{aligned} p_\\theta(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}(\\mathbf{h}^{(i)}) \\\\ \u0026= \\text{NN}([\\mathbf{h}_1^{(i)}; \\mathbf{h}_2^{(i)}]) \\\\ \u0026= \\text{softmax} \\big( \\mathbf{W}\\cdot\\text{ReLU}(\\mathbf{W'}\\cdot[\\mathbf{h}_1^{(i)}; \\mathbf{h}_2^{(i)}]) + \\mathbf{b} \\big) \\end{aligned} $$ The auxiliary tasks are only fed with forward or backward LSTM state in the first layer. Because they only observe partial context, either on the left or right, they have to learn like a language model, trying to predict the next token given the context. The fwd and bwd auxiliary tasks only take one direction. The future and past tasks take one step further in forward and backward direction, respectively.\n $$ \\begin{aligned} p_\\theta^\\text{fwd}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{fwd}(\\overrightarrow{\\mathbf{h}}^{(i)}) \\\\ p_\\theta^\\text{bwd}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{bwd}(\\overleftarrow{\\mathbf{h}}^{(i)}) \\\\ p_\\theta^\\text{future}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{future}(\\overrightarrow{\\mathbf{h}}^{(i-1)}) \\\\ p_\\theta^\\text{past}(y_i \\mid \\mathbf{x}_i) \u0026= \\text{NN}^\\text{past}(\\overleftarrow{\\mathbf{h}}^{(i+1)}) \\end{aligned} $$ Fig. 5. The sequential tagging task depends on four auxiliary prediction models, their inputs only involving hidden states in one direction: forward, backward, future and past. (Image source: original paper) Note that if the primary prediction module has dropout, the dropout layer works as usual when training with labeled data, but it is not applied when generating \u0026ldquo;soft\u0026rdquo; target for auxiliary tasks during training with unlabeled data.\nIn the machine translation task, the primary prediction module is replaced with a standard unidirectional LSTM decoder with attention. There are two auxiliary tasks: (1) apply dropout on the attention weight vector by randomly zeroing out some values; (2) predict the future word in the target sequence. The primary prediction for auxiliary tasks to match is the best predicted target sequence produced by running the fixed primary decoder on the input sequence with beam search.\nULMFiT The idea of using generative pretrained LM + task-specific fine-tuning was first explored in ULMFiT (Howard \u0026amp; Ruder, 2018), directly motivated by the success of using ImageNet pre-training for computer vision tasks. The base model is AWD-LSTM.\nULMFiT follows three steps to achieve good transfer learning results on downstream language classification tasks:\n General LM pre-training: on Wikipedia text.\n Target task LM fine-tuning: ULMFiT proposed two training techniques for stabilizing the fine-tuning process. See below.\n Discriminative fine-tuning is motivated by the fact that different layers of LM capture different types of information (see discussion above). ULMFiT proposed to tune each layer with different learning rates, $\\{\\eta^1, \\dots, \\eta^\\ell, \\dots, \\eta^L\\}$, where $\\eta$ is the base learning rate for the first layer, $\\eta^\\ell$ is for the $\\ell$-th layer and there are $L$ layers in total.\n Slanted triangular learning rates (STLR) refer to a special learning rate scheduling that first linearly increases the learning rate and then linearly decays it. The increase stage is short so that the model can converge to a parameter space suitable for the task fast, while the decay period is long allowing for better fine-tuning.\n Target task classifier fine-tuning: The pretrained LM is augmented with two standard feed-forward layers and a softmax normalization at the end to predict a target label distribution. Concat pooling extracts max-polling and mean-pooling over the history of hidden states and concatenates them with the final hidden state.\n Gradual unfreezing helps to avoid catastrophic forgetting by gradually unfreezing the model layers starting from the last one. First the last layer is unfrozen and fine-tuned for one epoch. Then the next lower layer is unfrozen. This process is repeated until all the layers are tuned.\n Fig. 6. Three training stages of ULMFiT. (Image source: original paper) GPT Following the similar idea of ELMo, OpenAI GPT, short for Generative Pre-training Transformer (Radford et al., 2018), expands the unsupervised language model to a much larger scale by training on a giant collection of free text corpora. Despite of the similarity, GPT has two major differences from ELMo.\n The model architectures are different: ELMo uses a shallow concatenation of independently trained left-to-right and right-to-left multi-layer LSTMs, while GPT is a multi-layer transformer decoder. The use of contextualized embeddings in downstream tasks are different: ELMo feeds embeddings into models customized for specific tasks as additional features, while GPT fine-tunes the same base model for all end tasks. Transformer Decoder as Language Model Compared to the original transformer architecture, the transformer decoder model discards the encoder part, so there is only one single input sentence rather than two separate source and target sequences.\nThis model applies multiple transformer blocks over the embeddings of input sequences. Each block contains a masked multi-headed self-attention layer and a pointwise feed-forward layer. The final output produces a distribution over target tokens after softmax normalization.\nFig. 7. The transformer decoder model architecture in OpenAI GPT. The loss is the negative log-likelihood, same as ELMo, but without backward computation. Let’s say, the context window of the size $k$ is located before the target word and the loss would look like:\n $$ \\mathcal{L}_\\text{LM} = -\\sum_{i} \\log p(x_i\\mid x_{i-k}, \\dots, x_{i-1}) $$ Byte Pair Encoding Byte Pair Encoding (BPE) is used to encode the input sequences. BPE was originally proposed as a data compression algorithm in 1990s and then was adopted to solve the open-vocabulary issue in machine translation, as we can easily run into rare and unknown words when translating into a new language. Motivated by the intuition that rare and unknown words can often be decomposed into multiple subwords, BPE finds the best word segmentation by iteratively and greedily merging frequent pairs of characters.\nSupervised Fine-Tuning The most substantial upgrade that OpenAI GPT proposed is to get rid of the task-specific model and use the pre-trained language model directly!\nLet’s take classification as an example. Say, in the labeled dataset, each input has $n$ tokens, $\\mathbf{x} = (x_1, \\dots, x_n)$, and one label $y$. GPT first processes the input sequence $\\mathbf{x}$ through the pre-trained transformer decoder and the last layer output for the last token $x_n$ is $\\mathbf{h}_L^{(n)}$. Then with only one new trainable weight matrix $\\mathbf{W}_y$, it can predict a distribution over class labels.\n $$ P(y\\mid x_1, \\dots, x_n) = \\text{softmax}(\\mathbf{h}_L^{(n)}\\mathbf{W}_y) $$ The loss is to minimize the negative log-likelihood for true labels. In addition, adding the LM loss as an auxiliary loss is found to be beneficial, because:\n (1) it helps accelerate convergence during training and (2) it is expected to improve the generalization of the supervised model. $$ \\begin{aligned} \\mathcal{L}_\\text{cls} \u0026= \\sum_{(\\mathbf{x}, y) \\in \\mathcal{D}} \\log P(y\\mid x_1, \\dots, x_n) = \\sum_{(\\mathbf{x}, y) \\in \\mathcal{D}} \\log \\text{softmax}(\\mathbf{h}_L^{(n)}(\\mathbf{x})\\mathbf{W}_y) \\\\ \\mathcal{L}_\\text{LM} \u0026= -\\sum_{i} \\log p(x_i\\mid x_{i-k}, \\dots, x_{i-1}) \\\\ \\mathcal{L} \u0026= \\mathcal{L}_\\text{cls} + \\lambda \\mathcal{L}_\\text{LM} \\end{aligned} $$ With similar designs, no customized model structure is needed for other end tasks (see Fig. 7). If the task input contains multiple sentences, a special delimiter token ($) is added between each pair of sentences. The embedding for this delimiter token is a new parameter we need to learn, but it should be pretty minimal.\nFor the sentence similarity task, because the ordering does not matter, both orderings are included. For the multiple choice task, the context is paired with every answer candidate.\nFig. 8. Training objects in slightly modified GPT transformer models for downstream tasks. (Image source: original paper) Summary: It is super neat and encouraging to see that such a general framework is capable to beat SOTA on most language tasks at that time (June 2018). At the first stage, generative pre-training of a language model can absorb as much free text as possible. Then at the second stage, the model is fine-tuned on specific tasks with a small labeled dataset and a minimal set of new parameters to learn.\nOne limitation of GPT is its uni-directional nature \u0026mdash; the model is only trained to predict the future left-to-right context.\nBERT BERT, short for Bidirectional Encoder Representations from Transformers (Devlin, et al., 2019) is a direct descendant to GPT: train a large language model on free text and then fine-tune on specific tasks without customized network architectures.\nCompared to GPT, the largest difference and improvement of BERT is to make training bi-directional. The model learns to predict both context on the left and right. The paper according to the ablation study claimed that:\n \u0026ldquo;bidirectional nature of our model is the single most important new contribution\u0026rdquo;\n Pre-training Tasks The model architecture of BERT is a multi-layer bidirectional Transformer encoder.\nFig. 9. Recap of Transformer Encoder model architecture. (Image source: Transformer paper) To encourage the bi-directional prediction and sentence-level understanding, BERT is trained with two tasks instead of the basic language task (that is, to predict the next token given context).\n*Task 1: Mask language model (MLM)\n From Wikipedia: \u0026ldquo;A cloze test (also cloze deletion test) is an exercise, test, or assessment consisting of a portion of language with certain items, words, or signs removed (cloze text), where the participant is asked to replace the missing language item. … The exercise was first described by W.L. Taylor in 1953.\u0026rdquo;\n It is unsurprising to believe that a representation that learns the context around a word rather than just after the word is able to better capture its meaning, both syntactically and semantically. BERT encourages the model to do so by training on the \u0026ldquo;mask language model\u0026rdquo; task:\n Randomly mask 15% of tokens in each sequence. Because if we only replace masked tokens with a special placeholder [MASK], the special token would never be encountered during fine-tuning. Hence, BERT employed several heuristic tricks: (a) with 80% probability, replace the chosen words with [MASK]; (b) with 10% probability, replace with a random word; (c) with 10% probability, keep it the same. The model only predicts the missing words, but it has no information on which words have been replaced or which words should be predicted. The output size is only 15% of the input size. Task 2: Next sentence prediction\nMotivated by the fact that many downstream tasks involve the understanding of relationships between sentences (i.e., QA, NLI), BERT added another auxiliary task on training a binary classifier for telling whether one sentence is the next sentence of the other:\n Sample sentence pairs (A, B) so that: (a) 50% of the time, B follows A; (b) 50% of the time, B does not follow A. The model processes both sentences and output a binary label indicating whether B is the next sentence of A. The training data for both auxiliary tasks above can be trivially generated from any monolingual corpus. Hence the scale of training is unbounded. The training loss is the sum of the mean masked LM likelihood and mean next sentence prediction likelihood.\nFig. 10. Comparison of BERT, OpenAI GPT and ELMo model architectures. (Image source: original paper) Input Embedding The input embedding is the sum of three parts:\n WordPiece tokenization embeddings: The WordPiece model was originally proposed for Japanese or Korean segmentation problem. Instead of using naturally split English word, they can be further divided into smaller sub-word units so that it is more effective to handle rare or unknown words. Please read linked papers for the optimal way to split words if interested. Segment embeddings: If the input contains two sentences, they have sentence A embeddings and sentence B embeddings respectively and they are separated by a special character [SEP]; Only sentence A embeddings are used if the input only contains one sentence. Position embeddings: Positional embeddings are learned rather than hard-coded. Fig. 11. BERT input representation. (Image source: original paper) Note that the first token is always forced to be [CLS] \u0026mdash; a placeholder that will be used later for prediction in downstream tasks.\nUse BERT in Downstream Tasks BERT fine-tuning requires only a few new parameters added, just like OpenAI GPT.\nFor classification tasks, we get the prediction by taking the final hidden state of the special first token [CLS], $\\mathbf{h}^\\text{[CLS]}_L$, and multiplying it with a small weight matrix, $\\text{softmax}(\\mathbf{h}^\\text{[CLS]}_L \\mathbf{W}_\\text{cls})$.\nFor QA tasks like SQuAD, we need to predict the text span in the given paragraph for an given question. BERT predicts two probability distributions of every token, being the start and the end of the text span. Only two new small matrices, $\\mathbf{W}_\\text{s}$ and $\\mathbf{W}_\\text{e}$, are newly learned during fine-tuning and $\\text{softmax}(\\mathbf{h}^\\text{(i)}_L \\mathbf{W}_\\text{s})$ and $\\text{softmax}(\\mathbf{h}^\\text{(i)}_L \\mathbf{W}_\\text{e})$ define two probability distributions.\nOverall the add-on part for end task fine-tuning is very minimal \u0026mdash; one or two weight matrices to convert the Transform hidden states to an interpretable format. Check the paper for implementation details for other cases.\nFig. 12. Training objects in slightly modified BERT models for downstream tasks. (Image source: original paper) A summary table compares differences between fine-tuning of OpenAI GPT and BERT.\n| | OpenAI GPT | BERT | | Special char | [SEP] and [CLS] are only introduced at fine-tuning stage. | [SEP] and [CLS] and sentence A/B embeddings are learned at the pre-training stage. | | Training process | 1M steps, batch size 32k words. | 1M steps, batch size 128k words. | | Fine-tuning | lr = 5e-5 for all fine-tuning tasks. | Use task-specific lr for fine-tuning. |\nALBERT ALBERT (Lan, et al. 2019), short for A Lite BERT, is a light-weighted version of BERT model. An ALBERT model can be trained 1.7x faster with 18x fewer parameters, compared to a BERT model of similar configuration. ALBERT incorporates three changes as follows: the first two help reduce parameters and memory consumption and hence speed up the training speed, while the third one proposes a more chanllenging training task to replace the next sentence prediction (NSP) objective.\nFactorized Embedding Parameterization In BERT, the WordPiece tokenization embedding size $E$ is configured to be the same as the hidden state size $H$. That is saying, if we want to increase the model size (larger $H$), we need to learn a larger tokenization embedding too, which is expensive because it depends on the vocabulary size ($V$).\nConceptually, because the tokenization embedding is expected to learn context-independent representation and the hidden states are context-dependent, it makes sense to separate the size of the hidden layers from the size of vocabulary embedding. Using factorized embedding parameterization, the large vocabulary embedding matrix of size $V \\times H$ is decomposed into two small matrices of size $V \\times E$ and $E \\times H$. Given $H \\gt E$ or even $H \\gg E$, factorization can result in significant parameter reduction.\nCross-layer Parameter Sharing Parameter sharing across layers can happen in many ways: (a) only share feed-forward part; (b) only share attention parameters; or (c) share all the parameters. This technique reduces the number of parameters by a ton and does not damage the performance too much.\nSentence-Order Prediction (SOP) Interestingly, the next sentence prediction (NSP) task of BERT turned out to be too easy. ALBERT instead adopted a sentence-order prediction (SOP) self-supervised loss,\n Positive sample: two consecutive segments from the same document. Negative sample: same as above, but the segment order is switched. For the NSP task, the model can make reasonable predictions if it is able to detect topics when A and B are from different contexts. In comparison, SOP is harder as it requires the model to fully understand the coherence and ordering between segments.\nGPT-2 The OpenAI GPT-2 language model is a direct successor to GPT. GPT-2 has 1.5B parameters, 10x more than the original GPT, and it achieves SOTA results on 7 out of 8 tested language modeling datasets in a zero-shot transfer setting without any task-specific fine-tuning. The pre-training dataset contains 8 million Web pages collected by crawling qualified outbound links from Reddit. Large improvements by OpenAI GPT-2 are specially noticeable on small datasets and datasets used for measuring long-term dependency.\nZero-Shot Transfer The pre-training task for GPT-2 is solely language modeling. All the downstream language tasks are framed as predicting conditional probabilities and there is no task-specific fine-tuning.\n Text generation is straightforward using LM. Machine translation task, for example, English to Chinese, is induced by conditioning LM on pairs of \u0026ldquo;English sentence = Chinese sentence\u0026rdquo; and \u0026ldquo;the target English sentence =\u0026rdquo; at the end. For example, the conditional probability to predict might look like: P(? | I like green apples. = 我喜欢绿苹果。 A cat meows at him. = 一只猫对他喵。It is raining cats and dogs. =\u0026quot;) QA task is formatted similar to translation with pairs of questions and answers in the context. Summarization task is induced by adding TL;DR: after the articles in the context. BPE on Byte Sequences Same as the original GPT, GPT-2 uses BPE but on UTF-8 byte sequences. Each byte can represent 256 different values in 8 bits, while UTF-8 can use up to 4 bytes for one character, supporting up to $2^{31}$ characters in total. Therefore, with byte sequence representation we only need a vocabulary of size 256 and do not need to worry about pre-processing, tokenization, etc. Despite of the benefit, current byte-level LMs still have non-negligible performance gap with the SOTA word-level LMs.\nBPE merges frequently co-occurred byte pairs in a greedy manner. To prevent it from generating multiple versions of common words (i.e. dog., dog! and dog? for the word dog), GPT-2 prevents BPE from merging characters across categories (thus dog would not be merged with punctuations like ., ! and ?). This tricks help increase the quality of the final byte segmentation.\nUsing the byte sequence representation, GPT-2 is able to assign a probability to any Unicode string, regardless of any pre-processing steps.\nModel Modifications Compared to GPT, other than having many more transformer layers and parameters, GPT-2 incorporates only a few architecture modifications:\n Layer normalization was moved to the input of each sub-block, similar to a residual unit of type \u0026ldquo;building block\u0026rdquo; (differently from the original type \u0026ldquo;bottleneck\u0026rdquo;, it has batch normalization applied before weight layers). An additional layer normalization was added after the final self-attention block. A modified initialization was constructed as a function of the model depth. The weights of residual layers were initially scaled by a factor of $1/ \\sqrt{N}$ where N is the number of residual layers. Use larger vocabulary size and context size. RoBERTa RoBERTa (short for Robustly optimized BERT approach; Liu, et al. 2019) refers to a new receipt for training BERT to achieve better results, as they found that the original BERT model is significantly undertrained. The receipt contains the following learnings:\n Train for longer with bigger batch size. Remove the next sentence prediction (NSP) task. Use longer sequences in training data format. The paper found that using individual sentences as inputs hurts downstream performance. Instead we should use multiple sentences sampled contiguously to form longer segments. Change the masking pattern dynamically. The original BERT applies masking once during the data preprocessing stage, resulting in a static mask across training epochs. RoBERTa applies masks in 10 different ways across 40 epochs. RoBERTa also added a new dataset CommonCrawl News and further confirmed that pretraining with more data helps improve the performance on downstream tasks. It was trained with the BPE on byte sequences, same as in GPT-2. They also found that choices of hyperparameters have a big impact on the model performance.\nT5 The language model T5 is short for \u0026ldquo;Text-to-Text Transfer Transformer\u0026rdquo; (Raffel et al., 2020). The encoder-decoder implementation follows the original Transformer architecture: tokens → embedding → encoder → decoder → output. T5 adopts the framework “Natural Language Decathlon” (McCann et al., 2018), where many common NLP tasks are translated into question-answering over a context. Instead of an explicit QA format, T5 uses short task prefixes to distinguish task intentions and separately fine-tunes the model on every individual task. The text-to-text framework enables easier transfer learning evaluation with the same model on a diverse set of tasks.\nFig. 13. A diagram of T5 task evaluation. The text-to-text framework casts every task into a generic form: feeding input text to predict some target text. (Image source: Raffel et al., 2020) The model is trained on Web corpus extracted from Apr 2019 with various filters applied. The model is fine-tuned for each downstream task separately via \u0026ldquo;adapter layers\u0026rdquo; (add an extra layer for training) or \u0026ldquo;gradual unfreezing\u0026rdquo; (see ULMFiT). Both fine-tuning approaches only update partial parameters while keeping the majority of the model parameters unchanged. T5-11B achieved SOTA results on many NLP tasks.\nAs the authors mentioned in the paper \u0026ldquo;\u0026hellip;our goal is not to propose new methods but instead to provide a comprehensive perspective on where the field stands\u0026rdquo;, the T5 long paper described a lot of training setup and evaluation processes in detail, a good read for people who are interested in training a LM from scratch.\nGPT-3 GPT-3 (Brown et al., 2020) has the same architecture as GPT-2 but contains 175B parameters, 10x larger than GPT-2 (1.5B). In addition, GPT-3 uses alternating dense and locally banded sparse attention patterns, same as in sparse transformer. In order to fit such a huge model across multiple GPUs, GPT-3 is trained with partitions along both width and depth dimension. The training data is a filtered version of Common Crawl mixed with a few other high-quality curated datasets. To avoid the contamination that downstream tasks might appear in the training data, the authors attempted to remove all the overlaps with all the studied benchmark dataset from the training dataset. Unfortunately the filtering process is not perfect due to a bug.\nFig. 14. Training datasets for GPT-3. Note that the occurrence of each dataset during training is not proportional to the dataset size. (Table source: Brown et al., 2020) For all the downstream evaluation, GPT-3 is tested in the few-shot setting without any gradient-based fine-tuning. Here the few-shot examples are provided as part of the prompt. GPT-3 achieves strong performance on many NLP datasets, comparable with fine-tuned BERT models.\nFig. 15. The evaluation performance increases with the model size and the number of examples. (Image source: Brown et al., 2020) XLNet The Autoregressive (AR) model such as GPT and autoencoder (AE) model such as BERT are two most common ways for language modeling. However, each has their own disadvantages: AR does not learn the bidirectional context, which is needed by downstream tasks like reading comprehension and AE assumes masked positions are independent given all other unmasked tokens which oversimplifies the long context dependency.\nXLNet (Yang et al. 2019) generalizes the AE method to incorporate the benefits of AR. XLNet proposed the permutation language modeling objective. For a text sequence, it samples a factorization order $\\mathbf{z}$ and decomposes the likelihood $p_\\theta(\\mathbf{x})$ according to this factorization order,\n $$ \\begin{aligned} \\mathcal{L}_\\text{XLNet} \u0026= - \\mathbb{E}_{\\mathbf{z} \\sim \\mathcal{Z}_T} \\Big[ \\sum_{t=1}^T \\log p_\\theta (X_{z_t} = x \\mid \\mathbf{x}_{\\mathbf{z}_{where $\\mathcal{Z}_T$ is a set of all possible permutation of length $T$; $z_t$ and $\\mathbf{z}_{\u0026lt;t}$ denote the $t$-th element and the first $t-1$ elements of a permutation $\\mathbf{z} \\in \\mathcal{Z}_T$.\nNote that the naive representation of the hidden state of the context, $h_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}})$ in red, does not depend on which position the model tries to predict, as the permutation breaks the default ordering. Therefore, XLNet re-parameterized it to a function of the target position too, $g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ in blue.\nHowever, two different requirements on $g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ lead to a two-stream self-attention design to accommodate:\n When predicting $x_{z_t}$, it should only encode the position $z_t$ but not the content $x_{z_t}$; otherwise it is trivial. This is wrapped into the \u0026ldquo;query representation\u0026rdquo; $g_{z_t} = g_\\theta (\\mathbf{x}_{\\mathbf{z}_{\u0026lt;t}}, z_t)$ does not encode $x_{z_t}$. When predicting $x_j$ where $j \u0026gt; t$, it should encode the content $x_{z_t}$ as well to provide the full context. This is the \u0026ldquo;content representation\u0026rdquo; $h_{z_t} = h_\\theta(\\mathbf{x}_{\\leq t})$. Fig. 16. The illustration of two-stream self-attention mechanism in XLNet. (Image source: Yang et al. 2019) Conceptually, the two streams of representations are updated as follows,\n $$ \\begin{aligned} g_{z_t}^{(m)} \u0026\\gets \\text{Attention}(Q = g^{(m-1)}_{z_t}, KV=\\mathbf{h}^{(m-1)}_{\\color{red}{\\mathbf{z}_{Given the difficulty of optimization in permutation language modeling, XLNet is set to only predict the last chunk of tokens in a factorization order.\nThe name in XLNet actually comes from Transformer-XL. It incorporates the design of Transformer-XL to extend the attention span by reusing hidden states from previous segments.\nFig. 17. Comparison of model performance of XLNet with a couple other language models on GLUE, all single-task, no ensembles. (Image source: Yang et al. 2019) BART BART (Lewis et al., 2019) is a denoising autoencoder to recover the original text from a randomly corrupted version. It combines Bidirectional and AutoRegressive Transformer: precisely, jointly training BERT-like bidirectional encoder and GPT-like autoregressive decoder together. The loss is simply just to minimize the negative log-likelihood.\nFig. 18. A schematic comparison of BART with BERT and GPT. (Image source: Lewis et al., 2019) They experimented with a variety of noising transformations, including token masking, token deletion, text infilling (i.e. A randomly sampled text span, which may contain multiple tokens, is replaced with a [MASK] token), sentence permutation, documentation rotation (i.e. A document is rotated to begin with a random token.). The best noising approach they discovered is text infilling and sentence shuffling.\nFig. 19. Comparison of different language modeling pre-training objectives. (Image source: Lewis et al., 2019) Learnings from their experiments:\n The performance of pre-training methods varies significantly across downstream tasks. Token masking is crucial, as the performance is poor when only sentence permutation or documentation rotation is applied. Left-to-right pre-training improves generation. Bidirectional encoders are crucial for SQuAD. The pre-training objective is not the only important factor. Architectural improvements such as relative-position embeddings or segment-level recurrence matter too. Autoregressive language models perform best on ELI5. BART achieves the most consistently strong performance. ELECTRA Most current pre-training large language models demand a lot of computation resources, raising concerns about their cost and accessibility. ELECTRA (\u0026ldquo;Efficiently Learning an Encoder that Classifies Token Replacements Accurately\u0026rdquo;; Clark et al. 2020) aims to improve the pre-training efficiency, which frames the language modeling as a discrimination task instead of generation task.\nFig. 20. Illustration of ELECTRA model architecture. (Image source: Clark et al. 2020) ELECTRA proposes a new pretraining task, called \u0026ldquo;Replaced Token Detection\u0026rdquo; (RTD). Let\u0026rsquo;s randomly sample $k$ positions to be masked. Each selected token in the original text is replaced by a plausible alternative predicted by a small language model, known as the generator $G$. The discriminator $D$ predicts whether each token is original or replaced.\n $$ \\begin{aligned} \\boldsymbol{m} \u0026= [m_1, \\dots, m_k] \\text{ where } m_i \\sim \\text{unif}\\{1, n\\}\\text{ for } i=1, \\dots, k \\\\ \\boldsymbol{x}^\\text{masked} \u0026= \\text{REPLACE}(\\boldsymbol{x}, \\boldsymbol{m}, \\texttt{[MASK]}) \\\\ \\boldsymbol{x}^\\text{corrupt} \u0026= \\text{REPLACE}(\\boldsymbol{x}, \\boldsymbol{m}, \\tilde{\\boldsymbol{x}}) \\text{ where } \\tilde{x}_t \\sim p_G(x_i \\mid \\boldsymbol{x}^\\text{masked}) \\text{ for } i \\in \\boldsymbol{m} \\\\ \\end{aligned} $$ The loss for the generator is the negative log-likelihood just as in other language models. The loss for the discriminator is the cross-entropy. Note that the generator is not adversarially trained to fool the discriminator but simply to optimize the NLL, since their experiments show negative results.\n $$ \\begin{aligned} \\mathcal{L}_\\text{MLM}(\\mathbf{x}, \\theta_G) \u0026= \\mathbb{E}\\Big(\\sum_{i \\in \\boldsymbol{m}} -\\log p_G (x_i \\mid \\boldsymbol{x}^\\text{masked} )\\Big) \\\\ \\mathcal{L}_\\text{Disc}(\\mathbf{x}, \\theta_D) \u0026= \\mathbb{E}\\Big( - \\mathbb{1}[x^\\text{corrupt}_t = x_t] \\log D(\\boldsymbol{x}^\\text{corrupt}, t) - \\mathbb{1}[x^\\text{corrupt}_t \\neq x_t] \\log (1 - \\log D(\\boldsymbol{x}^\\text{corrupt}, t)) \\Big) \\end{aligned} $$ They found it more beneficial to only share the embeddings between generator \u0026amp; discriminator while using a small generator (1/4 to 1/2 the discriminator size), rather than sharing all the weights (i.e. two models have to be the same size then). In addition, joint training of the generator and discriminator works better than two-stage training of each alternatively.\nAfter pretraining the generator is discarded and only the ELECTRA discriminator is fine-tuned further for downstream tasks. The following table shows ELECTRA\u0026rsquo;s performance on the GLUE dev set.\nFig. 21. Comparison of ELECTRA with other language models on the GLUE dev set. (Image source: Clark et al. 2020) Summary Base model Pretraining Tasks CoVe seq2seq NMT model supervised learning using translation dataset. ELMo two-layer biLSTM next token prediction CVT two-layer biLSTM semi-supervised learning using both labeled and unlabeled datasets ULMFiT AWD-LSTM autoregressive pretraining on Wikitext-103 GPT Transformer decoder next token prediction BERT Transformer encoder mask language model + next sentence prediction ALBERT same as BERT but light-weighted mask language model + sentence order prediction GPT-2 Transformer decoder next token prediction RoBERTa same as BERT mask language model (dynamic masking) T5 Transformer encoder + decoder pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which each task is converted into a text-to-text format. GPT-3 Transformer decoder next token prediction XLNet same as BERT permutation language modeling BART BERT encoder + GPT decoder reconstruct text from a noised version ELECTRA same as BERT replace token detection Metric: Perplexity Perplexity is often used as an intrinsic evaluation metric for gauging how well a language model can capture the real word distribution conditioned on the context.\nA perplexity of a discrete proability distribution $p$ is defined as the exponentiation of the entropy:\n $$ 2^{H(p)} = 2^{-\\sum_x p(x) \\log_2 p(x)} $$ Given a sentence with $N$ words, $s = (w_1, \\dots, w_N)$, the entropy looks as follows, simply assuming that each word has the same frequency, $\\frac{1}{N}$:\n $$ H(s) = -\\sum_{i=1}^N P(w_i) \\log_2 p(w_i) = -\\sum_{i=1}^N \\frac{1}{N} \\log_2 p(w_i) $$ The perplexity for the sentence becomes:\n $$ \\begin{aligned} 2^{H(s)} \u0026= 2^{-\\frac{1}{N} \\sum_{i=1}^N \\log_2 p(w_i)} = (2^{\\sum_{i=1}^N \\log_2 p(w_i)})^{-\\frac{1}{N}} = (p(w_1) \\dots p(w_N))^{-\\frac{1}{N}} \\end{aligned} $$ A good language model should predict high word probabilities. Therefore, the smaller perplexity the better.\nCommon Tasks and Datasets Question-Answering\n SQuAD (Stanford Question Answering Dataset): A reading comprehension dataset, consisting of questions posed on a set of Wikipedia articles, where the answer to every question is a span of text. RACE (ReAding Comprehension from Examinations): A large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students. See more QA datasets in a later post. Commonsense Reasoning\n Story Cloze Test: A commonsense reasoning framework for evaluating story understanding and generation. The test requires a system to choose the correct ending to multi-sentence stories from two options. SWAG (Situations With Adversarial Generations): multiple choices; contains 113k sentence-pair completion examples that evaluate grounded common-sense inference Natural Language Inference (NLI): also known as Text Entailment, an exercise to discern in logic whether one sentence can be inferred from another.\n RTE (Recognizing Textual Entailment): A set of datasets initiated by text entailment challenges. SNLI (Stanford Natural Language Inference): A collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels entailment, contradiction, and neutral. MNLI (Multi-Genre NLI): Similar to SNLI, but with a more diverse variety of text styles and topics, collected from transcribed speech, popular fiction, and government reports. QNLI (Question NLI): Converted from SQuAD dataset to be a binary classification task over pairs of (question, sentence). SciTail: An entailment dataset created from multiple-choice science exams and web sentences. Named Entity Recognition (NER): labels sequences of words in a text which are the names of things, such as person and company names, or gene and protein names\n CoNLL 2003 NER task: consists of newswire from the Reuters, concentrating on four types of named entities: persons, locations, organizations and names of miscellaneous entities. OntoNotes 5.0: This corpus contains text in English, Arabic and Chinese, tagged with four different entity types (PER, LOC, ORG, MISC). Reuters Corpus: A large collection of Reuters News stories. Fine-Grained NER (FGN) Sentiment Analysis\n SST (Stanford Sentiment Treebank) IMDb: A large dataset of movie reviews with binary sentiment classification labels. Semantic Role Labeling (SRL): models the predicate-argument structure of a sentence, and is often described as answering \u0026ldquo;Who did what to whom\u0026rdquo;.\n CoNLL-2004 \u0026amp; CoNLL-2005 Sentence similarity: also known as paraphrase detection\n MRPC (MicRosoft Paraphrase Corpus): It contains pairs of sentences extracted from news sources on the web, with annotations indicating whether each pair is semantically equivalent. QQP (Quora Question Pairs) STS Benchmark: Semantic Textual Similarity Sentence Acceptability: a task to annotate sentences for grammatical acceptability.\n CoLA (Corpus of Linguistic Acceptability): a binary single-sentence classification task. Text Chunking: To divide a text in syntactically correlated parts of words.\n CoNLL-2000 Part-of-Speech (POS) Tagging: tag parts of speech to each token, such as noun, verb, adjective, etc. the Wall Street Journal portion of the Penn Treebank (Marcus et al., 1993).\nMachine Translation: See Standard NLP page.\n WMT 2015 English-Czech data (Large) WMT 2014 English-German data (Medium) IWSLT 2015 English-Vietnamese data (Small) Coreference Resolution: cluster mentions in text that refer to the same underlying real world entities.\n CoNLL-2012 Long-range Dependency\n LAMBADA (LAnguage Modeling Broadened to Account for Discourse Aspects): A collection of narrative passages extracted from the BookCorpus and the task is to predict the last word, which require at least 50 tokens of context for a human to successfully predict. Children’s Book Test: is built from books that are freely available in Project Gutenberg. The task is to predict the missing word among 10 candidates. Multi-task benchmark\n GLUE multi-task benchmark: https://gluebenchmark.com decaNLP benmark: https://decanlp.com Unsupervised pretraining dataset\n Books corpus: The corpus contains \u0026ldquo;over 7,000 unique unpublished books from a variety of genres including Adventure, Fantasy, and Romance.\u0026rdquo; 1B Word Language Model Benchmark English Wikipedia: ~2500M words Cited as:\n@article{weng2019LM, title = \u0026quot;Generalized Language Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2019\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2019-01-31-lm/\u0026quot; } Reference [1] Bryan McCann, et al. \u0026ldquo;Learned in translation: Contextualized word vectors.\u0026quot; NIPS. 2017.\n[2] Kevin Clark et al. \u0026ldquo;Semi-Supervised Sequence Modeling with Cross-View Training.\u0026quot; EMNLP 2018.\n[3] Matthew E. Peters, et al. \u0026ldquo;Deep contextualized word representations.\u0026quot; NAACL-HLT 2017.\n[4] OpenAI Blog \u0026ldquo;Improving Language Understanding with Unsupervised Learning\u0026rdquo;, June 11, 2018.\n[5] OpenAI Blog \u0026ldquo;Better Language Models and Their Implications.\u0026quot; Feb 14, 2019.\n[6] Jeremy Howard and Sebastian Ruder. \u0026ldquo;Universal language model fine-tuning for text classification.\u0026quot; ACL 2018.\n[7] Alec Radford et al. \u0026ldquo;Improving Language Understanding by Generative Pre-Training\u0026rdquo;. OpenAI Blog, June 11, 2018.\n[8] Jacob Devlin, et al. \u0026ldquo;BERT: Pre-training of deep bidirectional transformers for language understanding.\u0026quot; arXiv:1810.04805 (2018).\n[9] Mike Schuster, and Kaisuke Nakajima. \u0026ldquo;Japanese and Korean voice search.\u0026quot; ICASSP. 2012.\n[10] Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation\n[11] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[12] Peter J. Liu, et al. \u0026ldquo;Generating wikipedia by summarizing long sequences.\u0026quot; ICLR 2018.\n[13] Sebastian Ruder. \u0026ldquo;10 Exciting Ideas of 2018 in NLP\u0026rdquo; Dec 2018.\n[14] Alec Radford, et al. \u0026ldquo;Language Models are Unsupervised Multitask Learners.\u0026quot;. 2019.\n[15] Rico Sennrich, et al. \u0026ldquo;Neural machine translation of rare words with subword units.\u0026quot; arXiv preprint arXiv:1508.07909. 2015.\n[16] Zhenzhong Lan, et al. \u0026ldquo;ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.\u0026quot; arXiv Preprint arXiv:1909.11942 (2019).\n[17] Yinhan Liu, et al. \u0026ldquo;RoBERTa: A Robustly Optimized BERT Pretraining Approach.\u0026quot; arXiv Preprint arXiv:1907.11692 (2019).\n[18] Tom B Brown, et al. \u0026ldquo;Language Models are Few-Shot Learners\u0026rdquo; NeuriPS 2020.\n[19] Zhilin Yang et al. “XLNet: Generalized Autoregressive Pretraining for Language Understanding.” NeuriPS 2019.\n[20] Mike Lewis et al. “BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension.” ACL 2020.\n[21] Kevin Clark et al. “ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators.” ICLR 2020.\n[22] Colin Raffel, et al. \u0026ldquo;Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer\u0026rdquo; JMLR 2020.\n","permalink":"https://lilianweng.github.io/posts/2019-01-31-lm/","summary":"[Updated on 2019-02-14: add ULMFiT and GPT-2.] [Updated on 2020-02-29: add ALBERT.] [Updated on 2020-10-25: add RoBERTa.] [Updated on 2020-12-13: add T5.] [Updated on 2020-12-30: add GPT-3.] [Updated on 2021-11-13: add XLNet, BART and ELECTRA; Also updated the Summary section.]\nFig. 0. I guess they are Elmo \u0026 Bert? (Image source: here) We have seen amazing progress in NLP in 2018. Large-scale pre-trained language modes like OpenAI GPT and BERT have achieved great performance on a variety of language tasks using generic model architectures.","title":"Generalized Language Models"},{"content":"In Part 3, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nTwo-stage vs One-stage Detectors Models in the R-CNN family are all region-based. The detection happens in two stages: (1) First, the model proposes a set of regions of interests by select search or regional proposal network. The proposed regions are sparse as the potential bounding box candidates can be infinite. (2) Then a classifier only processes the region candidates.\nThe other different approach skips the region proposal stage and runs detection directly over a dense sampling of possible locations. This is how a one-stage object detection algorithm works. This is faster and simpler, but might potentially drag down the performance a bit.\nAll the models introduced in this post are one-stage detectors.\nYOLO: You Only Look Once The YOLO model (\u0026ldquo;You Only Look Once\u0026rdquo;; Redmon et al., 2016) is the very first attempt at building a fast real-time object detector. Because YOLO does not undergo the region proposal step and only predicts over a limited number of bounding boxes, it is able to do inference super fast.\nWorkflow Pre-train a CNN network on image classification task.\n Split an image into $S \\times S$ cells. If an object\u0026rsquo;s center falls into a cell, that cell is \u0026ldquo;responsible\u0026rdquo; for detecting the existence of that object. Each cell predicts (a) the location of $B$ bounding boxes, (b) a confidence score, and (c) a probability of object class conditioned on the existence of an object in the bounding box.\n The coordinates of bounding box are defined by a tuple of 4 values, (center x-coord, center y-coord, width, height) \u0026mdash; $(x, y, w, h)$, where $x$ and $y$ are set to be offset of a cell location. Moreover, $x$, $y$, $w$ and $h$ are normalized by the image width and height, and thus all between (0, 1]. A confidence score indicates the likelihood that the cell contains an object: Pr(containing an object) x IoU(pred, truth); where Pr = probability and IoU = interaction under union. If the cell contains an object, it predicts a probability of this object belonging to every class $C_i, i=1, \\dots, K$: Pr(the object belongs to the class C_i | containing an object). At this stage, the model only predicts one set of class probabilities per cell, regardless of the number of bounding boxes, $B$. In total, one image contains $S \\times S \\times B$ bounding boxes, each box corresponding to 4 location predictions, 1 confidence score, and K conditional probabilities for object classification. The total prediction values for one image is $S \\times S \\times (5B + K)$, which is the tensor shape of the final conv layer of the model. The final layer of the pre-trained CNN is modified to output a prediction tensor of size $S \\times S \\times (5B + K)$.\n Fig. 1. The workflow of YOLO model. (Image source: original paper) Network Architecture The base model is similar to GoogLeNet with inception module replaced by 1x1 and 3x3 conv layers. The final prediction of shape $S \\times S \\times (5B + K)$ is produced by two fully connected layers over the whole conv feature map.\nFig. 2. The network architecture of YOLO. Loss Function The loss consists of two parts, the localization loss for bounding box offset prediction and the classification loss for conditional class probabilities. Both parts are computed as the sum of squared errors. Two scale parameters are used to control how much we want to increase the loss from bounding box coordinate predictions ($\\lambda_\\text{coord}$) and how much we want to decrease the loss of confidence score predictions for boxes without objects ($\\lambda_\\text{noobj}$). Down-weighting the loss contributed by background boxes is important as most of the bounding boxes involve no instance. In the paper, the model sets $\\lambda_\\text{coord} = 5$ and $\\lambda_\\text{noobj} = 0.5$.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loc} \u0026= \\lambda_\\text{coord} \\sum_{i=0}^{S^2} \\sum_{j=0}^B \\mathbb{1}_{ij}^\\text{obj} [(x_i - \\hat{x}_i)^2 + (y_i - \\hat{y}_i)^2 + (\\sqrt{w_i} - \\sqrt{\\hat{w}_i})^2 + (\\sqrt{h_i} - \\sqrt{\\hat{h}_i})^2 ] \\\\ \\mathcal{L}_\\text{cls} \u0026= \\sum_{i=0}^{S^2} \\sum_{j=0}^B \\big( \\mathbb{1}_{ij}^\\text{obj} + \\lambda_\\text{noobj} (1 - \\mathbb{1}_{ij}^\\text{obj})\\big) (C_{ij} - \\hat{C}_{ij})^2 + \\sum_{i=0}^{S^2} \\sum_{c \\in \\mathcal{C}} \\mathbb{1}_i^\\text{obj} (p_i(c) - \\hat{p}_i(c))^2\\\\ \\mathcal{L} \u0026= \\mathcal{L}_\\text{loc} + \\mathcal{L}_\\text{cls} \\end{aligned} $$ NOTE: In the original YOLO paper, the loss function uses $C_i$ instead of $C_{ij}$ as confidence score. I made the correction based on my own understanding, since every bounding box should have its own confidence score. Please kindly let me if you do not agree. Many thanks.\n where,\n $\\mathbb{1}_i^\\text{obj}$: An indicator function of whether the cell i contains an object. $\\mathbb{1}_{ij}^\\text{obj}$: It indicates whether the j-th bounding box of the cell i is \u0026ldquo;responsible\u0026rdquo; for the object prediction (see Fig. 3). $C_{ij}$: The confidence score of cell i, Pr(containing an object) * IoU(pred, truth). $\\hat{C}_{ij}$: The predicted confidence score. $\\mathcal{C}$: The set of all classes. $p_i(c)$: The conditional probability of whether cell i contains an object of class $c \\in \\mathcal{C}$. $\\hat{p}_i(c)$: The predicted conditional class probability. Fig. 3. At one location, in cell i, the model proposes B bounding box candidates and the one that has highest overlap with the ground truth is the \"responsible\" predictor. The loss function only penalizes classification error if an object is present in that grid cell, $\\mathbb{1}_i^\\text{obj} = 1$. It also only penalizes bounding box coordinate error if that predictor is \u0026ldquo;responsible\u0026rdquo; for the ground truth box, $\\mathbb{1}_{ij}^\\text{obj} = 1$.\nAs a one-stage object detector, YOLO is super fast, but it is not good at recognizing irregularly shaped objects or a group of small objects due to a limited number of bounding box candidates.\nSSD: Single Shot MultiBox Detector The Single Shot Detector (SSD; Liu et al, 2016) is one of the first attempts at using convolutional neural network\u0026rsquo;s pyramidal feature hierarchy for efficient detection of objects of various sizes.\nImage Pyramid SSD uses the VGG-16 model pre-trained on ImageNet as its base model for extracting useful image features. On top of VGG16, SSD adds several conv feature layers of decreasing sizes. They can be seen as a pyramid representation of images at different scales. Intuitively large fine-grained feature maps at earlier levels are good at capturing small objects and small coarse-grained feature maps can detect large objects well. In SSD, the detection happens in every pyramidal layer, targeting at objects of various sizes.\nFig. 4. The model architecture of SSD. Workflow Unlike YOLO, SSD does not split the image into grids of arbitrary size but predicts offset of predefined anchor boxes (this is called \u0026ldquo;default boxes\u0026rdquo; in the paper) for every location of the feature map. Each box has a fixed size and position relative to its corresponding cell. All the anchor boxes tile the whole feature map in a convolutional manner.\nFeature maps at different levels have different receptive field sizes. The anchor boxes on different levels are rescaled so that one feature map is only responsible for objects at one particular scale. For example, in Fig. 5 the dog can only be detected in the 4x4 feature map (higher level) while the cat is just captured by the 8x8 feature map (lower level).\nFig. 5. The SSD framework. (a) The training data contains images and ground truth boxes for every object. (b) In a fine-grained feature maps (8 x 8), the anchor boxes of different aspect ratios correspond to smaller area of the raw input. (c) In a coarse-grained feature map (4 x 4), the anchor boxes cover larger area of the raw input. (Image source: original paper) The width, height and the center location of an anchor box are all normalized to be (0, 1). At a location $(i, j)$ of the $\\ell$-th feature layer of size $m \\times n$, $i=1,\\dots,n, j=1,\\dots,m$, we have a unique linear scale proportional to the layer level and 5 different box aspect ratios (width-to-height ratios), in addition to a special scale (why we need this? the paper didn’t explain. maybe just a heuristic trick) when the aspect ratio is 1. This gives us 6 anchor boxes in total per feature cell.\n $$ \\begin{aligned} \\text{level index: } \u0026\\ell = 1, \\dots, L \\\\ \\text{scale of boxes: } \u0026s_\\ell = s_\\text{min} + \\frac{s_\\text{max} - s_\\text{min}}{L - 1} (\\ell - 1) \\\\ \\text{aspect ratio: } \u0026r \\in \\{1, 2, 3, 1/2, 1/3\\}\\\\ \\text{additional scale: } \u0026 s'_\\ell = \\sqrt{s_\\ell s_{\\ell + 1}} \\text{ when } r = 1 \\text{thus, 6 boxes in total.}\\\\ \\text{width: } \u0026w_\\ell^r = s_\\ell \\sqrt{r} \\\\ \\text{height: } \u0026h_\\ell^r = s_\\ell / \\sqrt{r} \\\\ \\text{center location: } \u0026 (x^i_\\ell, y^j_\\ell) = (\\frac{i+0.5}{m}, \\frac{j+0.5}{n}) \\end{aligned} $$ Fig. 6. An example of how the anchor box size is scaled up with the layer index $\\ell$ for $L=6, s\\_\\text{min} = 0.2, s\\_\\text{max} = 0.9$. Only the boxes of aspect ratio $r=1$ are illustrated. At every location, the model outputs 4 offsets and $c$ class probabilities by applying a $3 \\times 3 \\times p$ conv filter (where $p$ is the number of channels in the feature map) for every one of $k$ anchor boxes. Therefore, given a feature map of size $m \\times n$, we need $kmn(c+4)$ prediction filters.\nLoss Function Same as YOLO, the loss function is the sum of a localization loss and a classification loss.\n$\\mathcal{L} = \\frac{1}{N}(\\mathcal{L}_\\text{cls} + \\alpha \\mathcal{L}_\\text{loc})$\nwhere $N$ is the number of matched bounding boxes and $\\alpha$ balances the weights between two losses, picked by cross validation.\nThe localization loss is a smooth L1 loss between the predicted bounding box correction and the true values. The coordinate correction transformation is same as what R-CNN does in bounding box regression.\n $$ \\begin{aligned} \\mathcal{L}_\\text{loc} \u0026= \\sum_{i,j} \\sum_{m\\in\\{x, y, w, h\\}} \\mathbb{1}_{ij}^\\text{match} L_1^\\text{smooth}(d_m^i - t_m^j)^2\\\\ L_1^\\text{smooth}(x) \u0026= \\begin{cases} 0.5 x^2 \u0026 \\text{if } \\vert x \\vert where $\\mathbb{1}_{ij}^\\text{match}$ indicates whether the $i$-th bounding box with coordinates $(p^i_x, p^i_y, p^i_w, p^i_h)$ is matched to the $j$-th ground truth box with coordinates $(g^j_x, g^j_y, g^j_w, g^j_h)$ for any object. $d^i_m, m\\in\\{x, y, w, h\\}$ are the predicted correction terms. See this for how the transformation works.\nThe classification loss is a softmax loss over multiple classes (softmax_cross_entropy_with_logits in tensorflow):\n $$ \\mathcal{L}_\\text{cls} = -\\sum_{i \\in \\text{pos}} \\mathbb{1}_{ij}^k \\log(\\hat{c}_i^k) - \\sum_{i \\in \\text{neg}} \\log(\\hat{c}_i^0)\\text{, where }\\hat{c}_i^k = \\text{softmax}(c_i^k) $$ where $\\mathbb{1}_{ij}^k$ indicates whether the $i$-th bounding box and the $j$-th ground truth box are matched for an object in class $k$. $\\text{pos}$ is the set of matched bounding boxes ($N$ items in total) and $\\text{neg}$ is the set of negative examples. SSD uses hard negative mining to select easily misclassified negative examples to construct this $\\text{neg}$ set: Once all the anchor boxes are sorted by objectiveness confidence score, the model picks the top candidates for training so that neg:pos is at most 3:1.\nYOLOv2 / YOLO9000 YOLOv2 (Redmon \u0026amp; Farhadi, 2017) is an enhanced version of YOLO. YOLO9000 is built on top of YOLOv2 but trained with joint dataset combining the COCO detection dataset and the top 9000 classes from ImageNet.\nYOLOv2 Improvement A variety of modifications are applied to make YOLO prediction more accurate and faster, including:\n1. BatchNorm helps: Add batch norm on all the convolutional layers, leading to significant improvement over convergence.\n2. Image resolution matters: Fine-tuning the base model with high resolution images improves the detection performance.\n3. Convolutional anchor box detection: Rather than predicts the bounding box position with fully-connected layers over the whole feature map, YOLOv2 uses convolutional layers to predict locations of anchor boxes, like in faster R-CNN. The prediction of spatial locations and class probabilities are decoupled. Overall, the change leads to a slight decrease in mAP, but an increase in recall.\n4. K-mean clustering of box dimensions: Different from faster R-CNN that uses hand-picked sizes of anchor boxes, YOLOv2 runs k-mean clustering on the training data to find good priors on anchor box dimensions. The distance metric is designed to rely on IoU scores:\n $$ \\text{dist}(x, c_i) = 1 - \\text{IoU}(x, c_i), i=1,\\dots,k $$ where $x$ is a ground truth box candidate and $c_i$ is one of the centroids. The best number of centroids (anchor boxes) $k$ can be chosen by the elbow method.\nThe anchor boxes generated by clustering provide better average IoU conditioned on a fixed number of boxes.\n5. Direct location prediction: YOLOv2 formulates the bounding box prediction in a way that it would not diverge from the center location too much. If the box location prediction can place the box in any part of the image, like in regional proposal network, the model training could become unstable.\nGiven the anchor box of size $(p_w, p_h)$ at the grid cell with its top left corner at $(c_x, c_y)$, the model predicts the offset and the scale, $(t_x, t_y, t_w, t_h)$ and the corresponding predicted bounding box $b$ has center $(b_x, b_y)$ and size $(b_w, b_h)$. The confidence score is the sigmoid ($\\sigma$) of another output $t_o$.\n $$ \\begin{aligned} b_x \u0026= \\sigma(t_x) + c_x\\\\ b_y \u0026= \\sigma(t_y) + c_y\\\\ b_w \u0026= p_w e^{t_w}\\\\ b_h \u0026= p_h e^{t_h}\\\\ \\text{Pr}(\\text{object}) \u0026\\cdot \\text{IoU}(b, \\text{object}) = \\sigma(t_o) \\end{aligned} $$ Fig. 7. YOLOv2 bounding box location prediction. (Image source: original paper) 6. Add fine-grained features: YOLOv2 adds a passthrough layer to bring fine-grained features from an earlier layer to the last output layer. The mechanism of this passthrough layer is similar to identity mappings in ResNet to extract higher-dimensional features from previous layers. This leads to 1% performance increase.\n7. Multi-scale training: In order to train the model to be robust to input images of different sizes, a new size of input dimension is randomly sampled every 10 batches. Since conv layers of YOLOv2 downsample the input dimension by a factor of 32, the newly sampled size is a multiple of 32.\n8. Light-weighted base model: To make prediction even faster, YOLOv2 adopts a light-weighted base model, DarkNet-19, which has 19 conv layers and 5 max-pooling layers. The key point is to insert avg poolings and 1x1 conv filters between 3x3 conv layers.\nYOLO9000: Rich Dataset Training Because drawing bounding boxes on images for object detection is much more expensive than tagging images for classification, the paper proposed a way to combine small object detection dataset with large ImageNet so that the model can be exposed to a much larger number of object categories. The name of YOLO9000 comes from the top 9000 classes in ImageNet. During joint training, if an input image comes from the classification dataset, it only backpropagates the classification loss.\nThe detection dataset has much fewer and more general labels and, moreover, labels cross multiple datasets are often not mutually exclusive. For example, ImageNet has a label “Persian cat” while in COCO the same image would be labeled as “cat”. Without mutual exclusiveness, it does not make sense to apply softmax over all the classes.\nIn order to efficiently merge ImageNet labels (1000 classes, fine-grained) with COCO/PASCAL (\u0026lt; 100 classes, coarse-grained), YOLO9000 built a hierarchical tree structure with reference to WordNet so that general labels are closer to the root and the fine-grained class labels are leaves. In this way, \u0026ldquo;cat\u0026rdquo; is the parent node of \u0026ldquo;Persian cat\u0026rdquo;.\nFig. 8. The WordTree hierarchy merges labels from COCO and ImageNet. Blue nodes are COCO labels and red nodes are ImageNet labels. (Image source: original paper) To predict the probability of a class node, we can follow the path from the node to the root:\nPr(\u0026quot;persian cat\u0026quot; | contain a \u0026quot;physical object\u0026quot;) = Pr(\u0026quot;persian cat\u0026quot; | \u0026quot;cat\u0026quot;) Pr(\u0026quot;cat\u0026quot; | \u0026quot;animal\u0026quot;) Pr(\u0026quot;animal\u0026quot; | \u0026quot;physical object\u0026quot;) Pr(contain a \u0026quot;physical object\u0026quot;) # confidence score. Note that Pr(contain a \u0026quot;physical object\u0026quot;) is the confidence score, predicted separately in the bounding box detection pipeline. The path of conditional probability prediction can stop at any step, depending on which labels are available.\nRetinaNet The RetinaNet (Lin et al., 2018) is a one-stage dense object detector. Two crucial building blocks are featurized image pyramid and the use of focal loss.\nFocal Loss One issue for object detection model training is an extreme imbalance between background that contains no object and foreground that holds objects of interests. Focal loss is designed to assign more weights on hard, easily misclassified examples (i.e. background with noisy texture or partial object) and to down-weight easy examples (i.e. obviously empty background).\nStarting with a normal cross entropy loss for binary classification,\n $$ \\text{CE}(p, y) = -y\\log p - (1-y)\\log(1-p) $$ where $y \\in \\{0, 1\\}$ is a ground truth binary label, indicating whether a bounding box contains a object, and $p \\in [0, 1]$ is the predicted probability of objectiveness (aka confidence score).\nFor notational convenience,\n $$ \\text{let } p_t = \\begin{cases} p \u0026 \\text{if } y = 1\\\\ 1-p \u0026 \\text{otherwise} \\end{cases}, \\text{then } \\text{CE}(p, y)=\\text{CE}(p_t) = -\\log p_t $$ Easily classified examples with large $p_t \\gg 0.5$, that is, when $p$ is very close to 0 (when y=0) or 1 (when y=1), can incur a loss with non-trivial magnitude. Focal loss explicitly adds a weighting factor $(1-p_t)^\\gamma, \\gamma \\geq 0$ to each term in cross entropy so that the weight is small when $p_t$ is large and therefore easy examples are down-weighted.\n $$ \\text{FL}(p_t) = -(1-p_t)^\\gamma \\log p_t $$ Fig. 9. The focal loss focuses less on easy examples with a factor of $(1-p\\_t)^\\gamma$. (Image source: original paper) For a better control of the shape of the weighting function (see Fig. 10.), RetinaNet uses an $\\alpha$-balanced variant of the focal loss, where $\\alpha=0.25, \\gamma=2$ works the best.\n $$ \\text{FL}(p_t) = -\\alpha (1-p_t)^\\gamma \\log p_t $$ Fig. 10. The plot of focal loss weights $\\alpha (1-p\\_t)^\\gamma$ as a function of $p\\_t$, given different values of $\\alpha$ and $\\gamma$. Featurized Image Pyramid The featurized image pyramid (Lin et al., 2017) is the backbone network for RetinaNet. Following the same approach by image pyramid in SSD, featurized image pyramids provide a basic vision component for object detection at different scales.\nThe key idea of feature pyramid network is demonstrated in Fig. 11. The base structure contains a sequence of pyramid levels, each corresponding to one network stage. One stage contains multiple convolutional layers of the same size and the stage sizes are scaled down by a factor of 2. Let\u0026rsquo;s denote the last layer of the $i$-th stage as $C_i$.\nFig. 11. The illustration of the featurized image pyramid module. (Replot based on figure 3 in FPN paper) Two pathways connect conv layers:\n Bottom-up pathway is the normal feedforward computation. Top-down pathway goes in the inverse direction, adding coarse but semantically stronger feature maps back into the previous pyramid levels of a larger size via lateral connections. First, the higher-level features are upsampled spatially coarser to be 2x larger. For image upscaling, the paper used nearest neighbor upsampling. While there are many image upscaling algorithms such as using deconv, adopting another image scaling method might or might not improve the performance of RetinaNet. The larger feature map undergoes a 1x1 conv layer to reduce the channel dimension. Finally, these two feature maps are merged by element-wise addition. The lateral connections only happen at the last layer in stages, denoted as $\\{C_i\\}$, and the process continues until the finest (largest) merged feature map is generated. The prediction is made out of every merged map after a 3x3 conv layer, $\\{P_i\\}$. According to ablation studies, the importance rank of components of the featurized image pyramid design is as follows: 1x1 lateral connection \u0026gt; detect object across multiple layers \u0026gt; top-down enrichment \u0026gt; pyramid representation (compared to only check the finest layer).\nModel Architecture The featurized pyramid is constructed on top of the ResNet architecture. Recall that ResNet has 5 conv blocks (= network stages / pyramid levels). The last layer of the $i$-th pyramid level, $C_i$, has resolution $2^i$ lower than the raw input dimension.\nRetinaNet utilizes feature pyramid levels $P_3$ to $P_7$:\n $P_3$ to $P_5$ are computed from the corresponding ResNet residual stage from $C_3$ to $C_5$. They are connected by both top-down and bottom-up pathways. $P_6$ is obtained via a 3×3 stride-2 conv on top of $C_5$ $P_7$ applies ReLU and a 3×3 stride-2 conv on $P_6$. Adding higher pyramid levels on ResNet improves the performance for detecting large objects.\nSame as in SSD, detection happens in all pyramid levels by making a prediction out of every merged feature map. Because predictions share the same classifier and the box regressor, they are all formed to have the same channel dimension d=256.\nThere are A=9 anchor boxes per level:\n The base size corresponds to areas of $32^2$ to $512^2$ pixels on $P_3$ to $P_7$ respectively. There are three size ratios, $\\{2^0, 2^{1/3}, 2^{2/3}\\}$. For each size, there are three aspect ratios {1/2, 1, 2}. As usual, for each anchor box, the model outputs a class probability for each of $K$ classes in the classification subnet and regresses the offset from this anchor box to the nearest ground truth object in the box regression subnet. The classification subnet adopts the focal loss introduced above.\nFig. 12. The RetinaNet model architecture uses a FPN backbone on top of ResNet. (Image source: the FPN paper) YOLOv3 YOLOv3 is created by applying a bunch of design tricks on YOLOv2. The changes are inspired by recent advances in the object detection world.\nHere are a list of changes:\n1. Logistic regression for confidence scores: YOLOv3 predicts an confidence score for each bounding box using logistic regression, while YOLO and YOLOv2 uses sum of squared errors for classification terms (see the loss function above). Linear regression of offset prediction leads to a decrease in mAP.\n2. No more softmax for class prediction: When predicting class confidence, YOLOv3 uses multiple independent logistic classifier for each class rather than one softmax layer. This is very helpful especially considering that one image might have multiple labels and not all the labels are guaranteed to be mutually exclusive.\n3. Darknet + ResNet as the base model: The new Darknet-53 still relies on successive 3x3 and 1x1 conv layers, just like the original dark net architecture, but has residual blocks added.\n4. Multi-scale prediction: Inspired by image pyramid, YOLOv3 adds several conv layers after the base feature extractor model and makes prediction at three different scales among these conv layers. In this way, it has to deal with many more bounding box candidates of various sizes overall.\n5. Skip-layer concatenation: YOLOv3 also adds cross-layer connections between two prediction layers (except for the output layer) and earlier finer-grained feature maps. The model first up-samples the coarse feature maps and then merges it with the previous features by concatenation. The combination with finer-grained information makes it better at detecting small objects.\nInterestingly, focal loss does not help YOLOv3, potentially it might be due to the usage of $\\lambda_\\text{noobj}$ and $\\lambda_\\text{coord}$ \u0026mdash; they increase the loss from bounding box location predictions and decrease the loss from confidence predictions for background boxes.\nOverall YOLOv3 performs better and faster than SSD, and worse than RetinaNet but 3.8x faster.\nFig. 13. The comparison of various fast object detection models on speed and mAP performance. (Image source: focal loss paper with additional labels from the YOLOv3 paper.) Cited as:\n@article{weng2018detection4, title = \u0026quot;Object Detection Part 4: Fast Detection Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-12-27-object-recognition-part-4/\u0026quot; } Reference [1] Joseph Redmon, et al. \u0026ldquo;You only look once: Unified, real-time object detection.\u0026quot; CVPR 2016.\n[2] Joseph Redmon and Ali Farhadi. \u0026ldquo;YOLO9000: Better, Faster, Stronger.\u0026quot; CVPR 2017.\n[3] Joseph Redmon, Ali Farhadi. \u0026ldquo;YOLOv3: An incremental improvement.\u0026quot;.\n[4] Wei Liu et al. \u0026ldquo;SSD: Single Shot MultiBox Detector.\u0026quot; ECCV 2016.\n[5] Tsung-Yi Lin, et al. \u0026ldquo;Feature Pyramid Networks for Object Detection.\u0026quot; CVPR 2017.\n[6] Tsung-Yi Lin, et al. \u0026ldquo;Focal Loss for Dense Object Detection.\u0026quot; IEEE transactions on pattern analysis and machine intelligence, 2018.\n[7] \u0026ldquo;What\u0026rsquo;s new in YOLO v3?\u0026quot; by Ayoosh Kathuria on \u0026ldquo;Towards Data Science\u0026rdquo;, Apr 23, 2018.\n","permalink":"https://lilianweng.github.io/posts/2018-12-27-object-recognition-part-4/","summary":"In Part 3, we have reviewed models in the R-CNN family. All of them are region-based object detection algorithms. They can achieve high accuracy but could be too slow for certain applications such as autonomous driving. In Part 4, we only focus on fast object detection models, including SSD, RetinaNet, and models in the YOLO family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].","title":"Object Detection Part 4: Fast Detection Models"},{"content":"[Updated on 2019-10-01: thanks to Tianhao, we have this post translated in Chinese!]\nA good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to discover the way to ride a motorcycle fast with little or even no demonstration. Is it possible to design a machine learning model with similar properties \u0026mdash; learning new concepts and skills fast with a few training examples? That\u0026rsquo;s essentially what meta-learning aims to solve.\nWe expect a good meta-learning model capable of well adapting or generalizing to new tasks and new environments that have never been encountered during training time. The adaptation process, essentially a mini learning session, happens during test but with a limited exposure to the new task configurations. Eventually, the adapted model can complete new tasks. This is why meta-learning is also known as learning to learn.\nThe tasks can be any well-defined family of machine learning problems: supervised learning, reinforcement learning, etc. For example, here are a couple concrete meta-learning tasks:\n A classifier trained on non-cat images can tell whether a given image contains a cat after seeing a handful of cat pictures. A game bot is able to quickly master a new game. A mini robot completes the desired task on an uphill surface during test even through it was only trained in a flat surface environment. Define the Meta-Learning Problem In this post, we focus on the case when each desired task is a supervised learning problem like image classification. There is a lot of interesting literature on meta-learning with reinforcement learning problems (aka \u0026ldquo;Meta Reinforcement Learning\u0026rdquo;), but we would not cover them here.\nA Simple View A good meta-learning model should be trained over a variety of learning tasks and optimized for the best performance on a distribution of tasks, including potentially unseen tasks. Each task is associated with a dataset $\\mathcal{D}$, containing both feature vectors and true labels. The optimal model parameters are:\n $$ \\theta^* = \\arg\\min_\\theta \\mathbb{E}_{\\mathcal{D}\\sim p(\\mathcal{D})} [\\mathcal{L}_\\theta(\\mathcal{D})] $$ It looks very similar to a normal learning task, but one dataset is considered as one data sample.\nFew-shot classification is an instantiation of meta-learning in the field of supervised learning. The dataset $\\mathcal{D}$ is often split into two parts, a support set $S$ for learning and a prediction set $B$ for training or testing, $\\mathcal{D}=\\langle S, B\\rangle$. Often we consider a K-shot N-class classification task: the support set contains K labelled examples for each of N classes.\nFig. 1. An example of 4-shot 2-class image classification. (Image thumbnails are from Pinterest) Training in the Same Way as Testing A dataset $\\mathcal{D}$ contains pairs of feature vectors and labels, $\\mathcal{D} = \\{(\\mathbf{x}_i, y_i)\\}$ and each label belongs to a known label set $\\mathcal{L}^\\text{label}$. Let\u0026rsquo;s say, our classifier $f_\\theta$ with parameter $\\theta$ outputs a probability of a data point belonging to the class $y$ given the feature vector $\\mathbf{x}$, $P_\\theta(y\\vert\\mathbf{x})$.\nThe optimal parameters should maximize the probability of true labels across multiple training batches $B \\subset \\mathcal{D}$:\n $$ \\begin{aligned} \\theta^* \u0026= {\\arg\\max}_{\\theta} \\mathbb{E}_{(\\mathbf{x}, y)\\in \\mathcal{D}}[P_\\theta(y \\vert \\mathbf{x})] \u0026\\\\ \\theta^* \u0026= {\\arg\\max}_{\\theta} \\mathbb{E}_{B\\subset \\mathcal{D}}[\\sum_{(\\mathbf{x}, y)\\in B}P_\\theta(y \\vert \\mathbf{x})] \u0026 \\scriptstyle{\\text{; trained with mini-batches.}} \\end{aligned} $$ In few-shot classification, the goal is to reduce the prediction error on data samples with unknown labels given a small support set for \u0026ldquo;fast learning\u0026rdquo; (think of how \u0026ldquo;fine-tuning\u0026rdquo; works). To make the training process mimics what happens during inference, we would like to \u0026ldquo;fake\u0026rdquo; datasets with a subset of labels to avoid exposing all the labels to the model and modify the optimization procedure accordingly to encourage fast learning:\n Sample a subset of labels, $L\\subset\\mathcal{L}^\\text{label}$. Sample a support set $S^L \\subset \\mathcal{D}$ and a training batch $B^L \\subset \\mathcal{D}$. Both of them only contain data points with labels belonging to the sampled label set $L$, $y \\in L, \\forall (x, y) \\in S^L, B^L$. The support set is part of the model input. The final optimization uses the mini-batch $B^L$ to compute the loss and update the model parameters through backpropagation, in the same way as how we use it in the supervised learning. You may consider each pair of sampled dataset $(S^L, B^L)$ as one data point. The model is trained such that it can generalize to other datasets. Symbols in red are added for meta-learning in addition to the supervised learning objective.\n $$ \\theta = \\arg\\max_\\theta \\color{red}{E_{L\\subset\\mathcal{L}}[} E_{\\color{red}{S^L \\subset\\mathcal{D}, }B^L \\subset\\mathcal{D}} [\\sum_{(x, y)\\in B^L} P_\\theta(x, y\\color{red}{, S^L})] \\color{red}{]} $$ The idea is to some extent similar to using a pre-trained model in image classification (ImageNet) or language modeling (big text corpora) when only a limited set of task-specific data samples are available. Meta-learning takes this idea one step further, rather than fine-tuning according to one down-steam task, it optimizes the model to be good at many, if not all.\nLearner and Meta-Learner Another popular view of meta-learning decomposes the model update into two stages:\n A classifier $f_\\theta$ is the \u0026ldquo;learner\u0026rdquo; model, trained for operating a given task; In the meantime, a optimizer $g_\\phi$ learns how to update the learner model\u0026rsquo;s parameters via the support set $S$, $\\theta' = g_\\phi(\\theta, S)$. Then in final optimization step, we need to update both $\\theta$ and $\\phi$ to maximize:\n $$ \\mathbb{E}_{L\\subset\\mathcal{L}}[ \\mathbb{E}_{S^L \\subset\\mathcal{D}, B^L \\subset\\mathcal{D}} [\\sum_{(\\mathbf{x}, y)\\in B^L} P_{g_\\phi(\\theta, S^L)}(y \\vert \\mathbf{x})]] $$ Common Approaches There are three common approaches to meta-learning: metric-based, model-based, and optimization-based. Oriol Vinyals has a nice summary in his talk at meta-learning symposium @ NIPS 2018:\n| \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- | \u0026mdash;\u0026mdash;\u0026mdash;\u0026mdash;- |\n Model-based Metric-based Optimization-based Key idea RNN; memory Metric learning Gradient descent How $P_\\theta(y \\vert \\mathbf{x})$ is modeled? $f_\\theta(\\mathbf{x}, S)$ $\\sum_{(\\mathbf{x}_i, y_i) \\in S} k_\\theta(\\mathbf{x}, \\mathbf{x}_i)y_i$ (*) $P_{g_\\phi(\\theta, S^L)}(y \\vert \\mathbf{x})$ (*) $k_\\theta$ is a kernel function measuring the similarity between $\\mathbf{x}_i$ and $\\mathbf{x}$.\nNext we are gonna review classic models in each approach.\nMetric-Based The core idea in metric-based meta-learning is similar to nearest neighbors algorithms (i.e., k-NN classificer and k-means clustering) and kernel density estimation. The predicted probability over a set of known labels $y$ is a weighted sum of labels of support set samples. The weight is generated by a kernel function $k_\\theta$, measuring the similarity between two data samples.\n $$ P_\\theta(y \\vert \\mathbf{x}, S) = \\sum_{(\\mathbf{x}_i, y_i) \\in S} k_\\theta(\\mathbf{x}, \\mathbf{x}_i)y_i $$ To learn a good kernel is crucial to the success of a metric-based meta-learning model. Metric learning is well aligned with this intention, as it aims to learn a metric or distance function over objects. The notion of a good metric is problem-dependent. It should represent the relationship between inputs in the task space and facilitate problem solving.\nAll the models introduced below learn embedding vectors of input data explicitly and use them to design proper kernel functions.\nConvolutional Siamese Neural Network The Siamese Neural Network is composed of two twin networks and their outputs are jointly trained on top with a function to learn the relationship between pairs of input data samples. The twin networks are identical, sharing the same weights and network parameters. In other words, both refer to the same embedding network that learns an efficient embedding to reveal relationship between pairs of data points.\nKoch, Zemel \u0026amp; Salakhutdinov (2015) proposed a method to use the siamese neural network to do one-shot image classification. First, the siamese network is trained for a verification task for telling whether two input images are in the same class. It outputs the probability of two images belonging to the same class. Then, during test time, the siamese network processes all the image pairs between a test image and every image in the support set. The final prediction is the class of the support image with the highest probability.\nFig. 2. The architecture of convolutional siamese neural network for few-show image classification. First, convolutional siamese network learns to encode two images into feature vectors via a embedding function $f_\\theta$ which contains a couple of convolutional layers. The L1-distance between two embeddings is $\\vert f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\vert$. The distance is converted to a probability $p$ by a linear feedforward layer and sigmoid. It is the probability of whether two images are drawn from the same class. Intuitively the loss is cross entropy because the label is binary. $$ \\begin{aligned} p(\\mathbf{x}_i, \\mathbf{x}_j) \u0026= \\sigma(\\mathbf{W}\\vert f_\\theta(\\mathbf{x}_i) - f_\\theta(\\mathbf{x}_j) \\vert) \\\\ \\mathcal{L}(B) \u0026= \\sum_{(\\mathbf{x}_i, \\mathbf{x}_j, y_i, y_j)\\in B} \\mathbf{1}_{y_i=y_j}\\log p(\\mathbf{x}_i, \\mathbf{x}_j) + (1-\\mathbf{1}_{y_i=y_j})\\log (1-p(\\mathbf{x}_i, \\mathbf{x}_j)) \\end{aligned} $$ Images in the training batch $B$ can be augmented with distortion. Of course, you can replace the L1 distance with other distance metric, L2, cosine, etc. Just make sure they are differential and then everything else works the same.\nGiven a support set $S$ and a test image $\\mathbf{x}$, the final predicted class is:\n $$ \\hat{c}_S(\\mathbf{x}) = c(\\arg\\max_{\\mathbf{x}_i \\in S} P(\\mathbf{x}, \\mathbf{x}_i)) $$ where $c(\\mathbf{x})$ is the class label of an image $\\mathbf{x}$ and $\\hat{c}(.)$ is the predicted label.\nThe assumption is that the learned embedding can be generalized to be useful for measuring the distance between images of unknown categories. This is the same assumption behind transfer learning via the adoption of a pre-trained model; for example, the convolutional features learned in the model pre-trained with ImageNet are expected to help other image tasks. However, the benefit of a pre-trained model decreases when the new task diverges from the original task that the model was trained on.\nMatching Networks The task of Matching Networks (Vinyals et al., 2016) is to learn a classifier $c_S$ for any given (small) support set $S=\\{x_i, y_i\\}_{i=1}^k$ (k-shot classification). This classifier defines a probability distribution over output labels $y$ given a test example $\\mathbf{x}$. Similar to other metric-based models, the classifier output is defined as a sum of labels of support samples weighted by attention kernel $a(\\mathbf{x}, \\mathbf{x}_i)$ - which should be proportional to the similarity between $\\mathbf{x}$ and $\\mathbf{x}_i$.\nFig. 3. The architecture of Matching Networks. (Image source: original paper) $$ c_S(\\mathbf{x}) = P(y \\vert \\mathbf{x}, S) = \\sum_{i=1}^k a(\\mathbf{x}, \\mathbf{x}_i) y_i \\text{, where }S=\\{(\\mathbf{x}_i, y_i)\\}_{i=1}^k $$ The attention kernel depends on two embedding functions, $f$ and $g$, for encoding the test sample and the support set samples respectively. The attention weight between two data points is the cosine similarity, $\\text{cosine}(.)$, between their embedding vectors, normalized by softmax:\n $$ a(\\mathbf{x}, \\mathbf{x}_i) = \\frac{\\exp(\\text{cosine}(f(\\mathbf{x}), g(\\mathbf{x}_i))}{\\sum_{j=1}^k\\exp(\\text{cosine}(f(\\mathbf{x}), g(\\mathbf{x}_j))} $$ Simple Embedding In the simple version, an embedding function is a neural network with a single data sample as input. Potentially we can set $f=g$.\nFull Context Embeddings The embedding vectors are critical inputs for building a good classifier. Taking a single data point as input might not be enough to efficiently gauge the entire feature space. Therefore, the Matching Network model further proposed to enhance the embedding functions by taking as input the whole support set $S$ in addition to the original input, so that the learned embedding can be adjusted based on the relationship with other support samples.\n $g_\\theta(\\mathbf{x}_i, S)$ uses a bidirectional LSTM to encode $\\mathbf{x}_i$ in the context of the entire support set $S$.\n $f_\\theta(\\mathbf{x}, S)$ encodes the test sample $\\mathbf{x}$ visa an LSTM with read attention over the support set $S$.\n First the test sample goes through a simple neural network, such as a CNN, to extract basic features, $f'(\\mathbf{x})$. Then an LSTM is trained with a read attention vector over the support set as part of the hidden state: $$ \\begin{aligned} \\hat{\\mathbf{h}}_t, \\mathbf{c}_t \u0026= \\text{LSTM}(f'(\\mathbf{x}), [\\mathbf{h}_{t-1}, \\mathbf{r}_{t-1}], \\mathbf{c}_{t-1}) \\\\ \\mathbf{h}_t \u0026= \\hat{\\mathbf{h}}_t + f'(\\mathbf{x}) \\\\ \\mathbf{r}_{t-1} \u0026= \\sum_{i=1}^k a(\\mathbf{h}_{t-1}, g(\\mathbf{x}_i)) g(\\mathbf{x}_i) \\\\ a(\\mathbf{h}_{t-1}, g(\\mathbf{x}_i)) \u0026= \\text{softmax}(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_i)) = \\frac{\\exp(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_i))}{\\sum_{j=1}^k \\exp(\\mathbf{h}_{t-1}^\\top g(\\mathbf{x}_j))} \\end{aligned} $$ Eventually $f(\\mathbf{x}, S)=\\mathbf{h}_K$ if we do K steps of \u0026ldquo;read\u0026rdquo;. This embedding method is called \u0026ldquo;Full Contextual Embeddings (FCE)\u0026rdquo;. Interestingly it does help improve the performance on a hard task (few-shot classification on mini ImageNet), but makes no difference on a simple task (Omniglot).\nThe training process in Matching Networks is designed to match inference at test time, see the details in the earlier section. It is worthy of mentioning that the Matching Networks paper refined the idea that training and testing conditions should match.\n $$ \\theta^* = \\arg\\max_\\theta \\mathbb{E}_{L\\subset\\mathcal{L}}[ \\mathbb{E}_{S^L \\subset\\mathcal{D}, B^L \\subset\\mathcal{D}} [\\sum_{(\\mathbf{x}, y)\\in B^L} P_\\theta(y\\vert\\mathbf{x}, S^L)]] $$ Relation Network Relation Network (RN) (Sung et al., 2018) is similar to siamese network but with a few differences:\n The relationship is not captured by a simple L1 distance in the feature space, but predicted by a CNN classifier $g_\\phi$. The relation score between a pair of inputs, $\\mathbf{x}_i$ and $\\mathbf{x}_j$, is $r_{ij} = g_\\phi([\\mathbf{x}_i, \\mathbf{x}_j])$ where $[.,.]$ is concatenation. The objective function is MSE loss instead of cross-entropy, because conceptually RN focuses more on predicting relation scores which is more like regression, rather than binary classification, $\\mathcal{L}(B) = \\sum_{(\\mathbf{x}_i, \\mathbf{x}_j, y_i, y_j)\\in B} (r_{ij} - \\mathbf{1}_{y_i=y_j})^2$. Fig. 4. Relation Network architecture for a 5-way 1-shot problem with one query example. (Image source: original paper) (Note: There is another Relation Network for relational reasoning, proposed by DeepMind. Don\u0026rsquo;t get confused.)\nPrototypical Networks Prototypical Networks (Snell, Swersky \u0026amp; Zemel, 2017) use an embedding function $f_\\theta$ to encode each input into a $M$-dimensional feature vector. A prototype feature vector is defined for every class $c \\in \\mathcal{C}$, as the mean vector of the embedded support data samples in this class.\n $$ \\mathbf{v}_c = \\frac{1}{|S_c|} \\sum_{(\\mathbf{x}_i, y_i) \\in S_c} f_\\theta(\\mathbf{x}_i) $$ Fig. 5. Prototypical networks in the few-shot and zero-shot scenarios. (Image source: original paper) The distribution over classes for a given test input $\\mathbf{x}$ is a softmax over the inverse of distances between the test data embedding and prototype vectors.\n $$ P(y=c\\vert\\mathbf{x})=\\text{softmax}(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_c)) = \\frac{\\exp(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_c))}{\\sum_{c' \\in \\mathcal{C}}\\exp(-d_\\varphi(f_\\theta(\\mathbf{x}), \\mathbf{v}_{c'}))} $$ where $d_\\varphi$ can be any distance function as long as $\\varphi$ is differentiable. In the paper, they used the squared euclidean distance.\nThe loss function is the negative log-likelihood: $\\mathcal{L}(\\theta) = -\\log P_\\theta(y=c\\vert\\mathbf{x})$.\nModel-Based Model-based meta-learning models make no assumption on the form of $P_\\theta(y\\vert\\mathbf{x})$. Rather it depends on a model designed specifically for fast learning \u0026mdash; a model that updates its parameters rapidly with a few training steps. This rapid parameter update can be achieved by its internal architecture or controlled by another meta-learner model.\nMemory-Augmented Neural Networks A family of model architectures use external memory storage to facilitate the learning process of neural networks, including Neural Turing Machines and Memory Networks. With an explicit storage buffer, it is easier for the network to rapidly incorporate new information and not to forget in the future. Such a model is known as MANN, short for \u0026ldquo;Memory-Augmented Neural Network\u0026rdquo;. Note that recurrent neural networks with only internal memory such as vanilla RNN or LSTM are not MANNs.\nBecause MANN is expected to encode new information fast and thus to adapt to new tasks after only a few samples, it fits well for meta-learning. Taking the Neural Turing Machine (NTM) as the base model, Santoro et al. (2016) proposed a set of modifications on the training setup and the memory retrieval mechanisms (or \u0026ldquo;addressing mechanisms\u0026rdquo;, deciding how to assign attention weights to memory vectors). Please go through the NTM section in my other post first if you are not familiar with this matter before reading forward.\nAs a quick recap, NTM couples a controller neural network with external memory storage. The controller learns to read and write memory rows by soft attention, while the memory serves as a knowledge repository. The attention weights are generated by its addressing mechanism: content-based + location based.\nFig. 6. The architecture of Neural Turing Machine (NTM). The memory at time t, $\\mathbf{M}\\_t$ is a matrix of size $N \\times M$, containing N vector rows and each has M dimensions. MANN for Meta-Learning To use MANN for meta-learning tasks, we need to train it in a way that the memory can encode and capture information of new tasks fast and, in the meantime, any stored representation is easily and stably accessible.\nThe training described in Santoro et al., 2016 happens in an interesting way so that the memory is forced to hold information for longer until the appropriate labels are presented later. In each training episode, the truth label $y_t$ is presented with one step offset, $(\\mathbf{x}_{t+1}, y_t)$: it is the true label for the input at the previous time step t, but presented as part of the input at time step t+1.\nFig. 7. Task setup in MANN for meta-learning (Image source: original paper). In this way, MANN is motivated to memorize the information of a new dataset, because the memory has to hold the current input until the label is present later and then retrieve the old information to make a prediction accordingly.\nNext let us see how the memory is updated for efficient information retrieval and storage.\nAddressing Mechanism for Meta-Learning Aside from the training process, a new pure content-based addressing mechanism is utilized to make the model better suitable for meta-learning.\n\u0026raquo; How to read from memory? The read attention is constructed purely based on the content similarity.\nFirst, a key feature vector $\\mathbf{k}_t$ is produced at the time step t by the controller as a function of the input $\\mathbf{x}$. Similar to NTM, a read weighting vector $\\mathbf{w}_t^r$ of N elements is computed as the cosine similarity between the key vector and every memory vector row, normalized by softmax. The read vector $\\mathbf{r}_t$ is a sum of memory records weighted by such weightings:\n $$ \\mathbf{r}_i = \\sum_{i=1}^N w_t^r(i)\\mathbf{M}_t(i) \\text{, where } w_t^r(i) = \\text{softmax}(\\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(i)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(i)\\|}) $$ where $M_t$ is the memory matrix at time t and $M_t(i)$ is the i-th row in this matrix.\n\u0026raquo; How to write into memory? The addressing mechanism for writing newly received information into memory operates a lot like the cache replacement policy. The Least Recently Used Access (LRUA) writer is designed for MANN to better work in the scenario of meta-learning. A LRUA write head prefers to write new content to either the least used memory location or the most recently used memory location.\n Rarely used locations: so that we can preserve frequently used information (see LFU); The last used location: the motivation is that once a piece of information is retrieved once, it probably won\u0026rsquo;t be called again for a while (see MRU). There are many cache replacement algorithms and each of them could potentially replace the design here with better performance in different use cases. Furthermore, it would be a good idea to learn the memory usage pattern and addressing strategies rather than arbitrarily set it.\nThe preference of LRUA is carried out in a way that everything is differentiable:\n The usage weight $\\mathbf{w}^u_t$ at time t is a sum of current read and write vectors, in addition to the decayed last usage weight, $\\gamma \\mathbf{w}^u_{t-1}$, where $\\gamma$ is a decay factor. The write vector is an interpolation between the previous read weight (prefer \u0026ldquo;the last used location\u0026rdquo;) and the previous least-used weight (prefer \u0026ldquo;rarely used location\u0026rdquo;). The interpolation parameter is the sigmoid of a hyperparameter $\\alpha$. The least-used weight $\\mathbf{w}^{lu}$ is scaled according to usage weights $\\mathbf{w}_t^u$, in which any dimension remains at 1 if smaller than the n-th smallest element in the vector and 0 otherwise. $$ \\begin{aligned} \\mathbf{w}_t^u \u0026= \\gamma \\mathbf{w}_{t-1}^u + \\mathbf{w}_t^r + \\mathbf{w}_t^w \\\\ \\mathbf{w}_t^r \u0026= \\text{softmax}(\\text{cosine}(\\mathbf{k}_t, \\mathbf{M}_t(i))) \\\\ \\mathbf{w}_t^w \u0026= \\sigma(\\alpha)\\mathbf{w}_{t-1}^r + (1-\\sigma(\\alpha))\\mathbf{w}^{lu}_{t-1}\\\\ \\mathbf{w}_t^{lu} \u0026= \\mathbf{1}_{w_t^u(i) \\leq m(\\mathbf{w}_t^u, n)} \\text{, where }m(\\mathbf{w}_t^u, n)\\text{ is the }n\\text{-th smallest element in vector }\\mathbf{w}_t^u\\text{.} \\end{aligned} $$ Finally, after the least used memory location, indicated by $\\mathbf{w}_t^{lu}$, is set to zero, every memory row is updated:\n $$ \\mathbf{M}_t(i) = \\mathbf{M}_{t-1}(i) + w_t^w(i)\\mathbf{k}_t, \\forall i $$ Meta Networks Meta Networks (Munkhdalai \u0026amp; Yu, 2017), short for MetaNet, is a meta-learning model with architecture and training process designed for rapid generalization across tasks.\nFast Weights The rapid generalization of MetaNet relies on \u0026ldquo;fast weights\u0026rdquo;. There are a handful of papers on this topic, but I haven\u0026rsquo;t read all of them in detail and I failed to find a very concrete definition, only a vague agreement on the concept. Normally weights in the neural networks are updated by stochastic gradient descent in an objective function and this process is known to be slow. One faster way to learn is to utilize one neural network to predict the parameters of another neural network and the generated weights are called fast weights. In comparison, the ordinary SGD-based weights are named slow weights.\nIn MetaNet, loss gradients are used as meta information to populate models that learn fast weights. Slow and fast weights are combined to make predictions in neural networks.\nFig. 8. Combining slow and fast weights in a MLP. $\\bigoplus$ is element-wise sum. (Image source: original paper). Model Components Disclaimer: Below you will find my annotations are different from those in the paper. imo, the paper is poorly written, but the idea is still interesting. So I\u0026rsquo;m presenting the idea in my own language.\n Key components of MetaNet are:\n An embedding function $f_\\theta$, parameterized by $\\theta$, encodes raw inputs into feature vectors. Similar to Siamese Neural Network, these embeddings are trained to be useful for telling whether two inputs are of the same class (verification task). A base learner model $g_\\phi$, parameterized by weights $\\phi$, completes the actual learning task. If we stop here, it looks just like Relation Network. MetaNet, in addition, explicitly models the fast weights of both functions and then aggregates them back into the model (See Fig. 8).\nTherefore we need additional two functions to output fast weights for $f$ and $g$ respectively.\n $F_w$: a LSTM parameterized by $w$ for learning fast weights $\\theta^+$ of the embedding function $f$. It takes as input gradients of $f$\u0026rsquo;s embedding loss for verification task. $G_v$: a neural network parameterized by $v$ learning fast weights $\\phi^+$ for the base learner $g$ from its loss gradients. In MetaNet, the learner\u0026rsquo;s loss gradients are viewed as the meta information of the task. Ok, now let\u0026rsquo;s see how meta networks are trained. The training data contains multiple pairs of datasets: a support set $S=\\{\\mathbf{x}'_i, y'_i\\}_{i=1}^K$ and a test set $U=\\{\\mathbf{x}_i, y_i\\}_{i=1}^L$. Recall that we have four networks and four sets of model parameters to learn, $(\\theta, \\phi, w, v)$.\nFig.9. The MetaNet architecture. Training Process Sample a random pair of inputs at each time step t from the support set $S$, $(\\mathbf{x}'_i, y'_i)$ and $(\\mathbf{x}'_j, y_j)$. Let $\\mathbf{x}_{(t,1)}=\\mathbf{x}'_i$ and $\\mathbf{x}_{(t,2)}=\\mathbf{x}'_j$. for $t = 1, \\dots, K$:\n a. Compute a loss for representation learning; i.e., cross entropy for the verification task: $\\mathcal{L}^\\text{emb}_t = \\mathbf{1}_{y'_i=y'_j} \\log P_t + (1 - \\mathbf{1}_{y'_i=y'_j})\\log(1 - P_t)\\text{, where }P_t = \\sigma(\\mathbf{W}\\vert f_\\theta(\\mathbf{x}_{(t,1)}) - f_\\theta(\\mathbf{x}_{(t,2)})\\vert)$ Compute the task-level fast weights: $\\theta^+ = F_w(\\nabla_\\theta \\mathcal{L}^\\text{emb}_1, \\dots, \\mathcal{L}^\\text{emb}_T)$\n Next go through examples in the support set $S$ and compute the example-level fast weights. Meanwhile, update the memory with learned representations. for $i=1, \\dots, K$:\n a. The base learner outputs a probability distribution: $P(\\hat{y}_i \\vert \\mathbf{x}_i) = g_\\phi(\\mathbf{x}_i)$ and the loss can be cross-entropy or MSE: $\\mathcal{L}^\\text{task}_i = y'_i \\log g_\\phi(\\mathbf{x}'_i) + (1- y'_i) \\log (1 - g_\\phi(\\mathbf{x}'_i))$ b. Extract meta information (loss gradients) of the task and compute the example-level fast weights: $\\phi_i^+ = G_v(\\nabla_\\phi\\mathcal{L}^\\text{task}_i)$ Then store $\\phi^+_i$ into $i$-th location of the \u0026ldquo;value\u0026rdquo; memory $\\mathbf{M}$. d. Encode the support sample into a task-specific input representation using both slow and fast weights: $r'_i = f_{\\theta, \\theta^+}(\\mathbf{x}'_i)$ Then store $r'_i$ into $i$-th location of the \u0026ldquo;key\u0026rdquo; memory $\\mathbf{R}$. Finally it is the time to construct the training loss using the test set $U=\\{\\mathbf{x}_i, y_i\\}_{i=1}^L$. Starts with $\\mathcal{L}_\\text{train}=0$: for $j=1, \\dots, L$:\n a. Encode the test sample into a task-specific input representation: $r_j = f_{\\theta, \\theta^+}(\\mathbf{x}_j)$ b. The fast weights are computed by attending to representations of support set samples in memory $\\mathbf{R}$. The attention function is of your choice. Here MetaNet uses cosine similarity: $$ \\begin{aligned} a_j \u0026= \\text{cosine}(\\mathbf{R}, r_j) = [\\frac{r'_1\\cdot r_j}{\\|r'_1\\|\\cdot\\|r_j\\|}, \\dots, \\frac{r'_N\\cdot r_j}{\\|r'_N\\|\\cdot\\|r_j\\|}]\\\\ \\phi^+_j \u0026= \\text{softmax}(a_j)^\\top \\mathbf{M} \\end{aligned} $$ c. Update the training loss: $\\mathcal{L}_\\text{train} \\leftarrow \\mathcal{L}_\\text{train} + \\mathcal{L}^\\text{task}(g_{\\phi, \\phi^+}(\\mathbf{x}_i), y_i) $ Update all the parameters $(\\theta, \\phi, w, v)$ using $\\mathcal{L}_\\text{train}$.\n Optimization-Based Deep learning models learn through backpropagation of gradients. However, the gradient-based optimization is neither designed to cope with a small number of training samples, nor to converge within a small number of optimization steps. Is there a way to adjust the optimization algorithm so that the model can be good at learning with a few examples? This is what optimization-based approach meta-learning algorithms intend for.\nLSTM Meta-Learner The optimization algorithm can be explicitly modeled. Ravi \u0026amp; Larochelle (2017) did so and named it \u0026ldquo;meta-learner\u0026rdquo;, while the original model for handling the task is called \u0026ldquo;learner\u0026rdquo;. The goal of the meta-learner is to efficiently update the learner\u0026rsquo;s parameters using a small support set so that the learner can adapt to the new task quickly.\nLet\u0026rsquo;s denote the learner model as $M_\\theta$ parameterized by $\\theta$, the meta-learner as $R_\\Theta$ with parameters $\\Theta$, and the loss function $\\mathcal{L}$.\nWhy LSTM? The meta-learner is modeled as a LSTM, because:\n There is similarity between the gradient-based update in backpropagation and the cell-state update in LSTM. Knowing a history of gradients benefits the gradient update; think about how momentum works. The update for the learner\u0026rsquo;s parameters at time step t with a learning rate $\\alpha_t$ is:\n $$ \\theta_t = \\theta_{t-1} - \\alpha_t \\nabla_{\\theta_{t-1}}\\mathcal{L}_t $$ It has the same form as the cell state update in LSTM, if we set forget gate $f_t=1$, input gate $i_t = \\alpha_t$, cell state $c_t = \\theta_t$, and new cell state $\\tilde{c}_t = -\\nabla_{\\theta_{t-1}}\\mathcal{L}_t$:\n $$ \\begin{aligned} c_t \u0026= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c}_t\\\\ \u0026= \\theta_{t-1} - \\alpha_t\\nabla_{\\theta_{t-1}}\\mathcal{L}_t \\end{aligned} $$ While fixing $f_t=1$ and $i_t=\\alpha_t$ might not be the optimal, both of them can be learnable and adaptable to different datasets.\n $$ \\begin{aligned} f_t \u0026= \\sigma(\\mathbf{W}_f \\cdot [\\nabla_{\\theta_{t-1}}\\mathcal{L}_t, \\mathcal{L}_t, \\theta_{t-1}, f_{t-1}] + \\mathbf{b}_f) \u0026 \\scriptstyle{\\text{; how much to forget the old value of parameters.}}\\\\ i_t \u0026= \\sigma(\\mathbf{W}_i \\cdot [\\nabla_{\\theta_{t-1}}\\mathcal{L}_t, \\mathcal{L}_t, \\theta_{t-1}, i_{t-1}] + \\mathbf{b}_i) \u0026 \\scriptstyle{\\text{; corresponding to the learning rate at time step t.}}\\\\ \\tilde{\\theta}_t \u0026= -\\nabla_{\\theta_{t-1}}\\mathcal{L}_t \u0026\\\\ \\theta_t \u0026= f_t \\odot \\theta_{t-1} + i_t \\odot \\tilde{\\theta}_t \u0026\\\\ \\end{aligned} $$ Model Setup Fig. 10. How the learner $M\\_\\theta$ and the meta-learner $R\\_\\Theta$ are trained. (Image source: original paper with more annotations) The training process mimics what happens during test, since it has been proved to be beneficial in Matching Networks. During each training epoch, we first sample a dataset $\\mathcal{D} = (\\mathcal{D}_\\text{train}, \\mathcal{D}_\\text{test}) \\in \\hat{\\mathcal{D}}_\\text{meta-train}$ and then sample mini-batches out of $\\mathcal{D}_\\text{train}$ to update $\\theta$ for $T$ rounds. The final state of the learner parameter $\\theta_T$ is used to train the meta-learner on the test data $\\mathcal{D}_\\text{test}$.\nTwo implementation details to pay extra attention to:\n How to compress the parameter space in LSTM meta-learner? As the meta-learner is modeling parameters of another neural network, it would have hundreds of thousands of variables to learn. Following the idea of sharing parameters across coordinates, To simplify the training process, the meta-learner assumes that the loss $\\mathcal{L}_t$ and the gradient $\\nabla_{\\theta_{t-1}} \\mathcal{L}_t$ are independent. MAML MAML, short for Model-Agnostic Meta-Learning (Finn, et al. 2017) is a fairly general optimization algorithm, compatible with any model that learns through gradient descent.\nLet\u0026rsquo;s say our model is $f_\\theta$ with parameters $\\theta$. Given a task $\\tau_i$ and its associated dataset $(\\mathcal{D}^{(i)}_\\text{train}, \\mathcal{D}^{(i)}_\\text{test})$, we can update the model parameters by one or more gradient descent steps (the following example only contains one step):\n $$ \\theta'_i = \\theta - \\alpha \\nabla_\\theta\\mathcal{L}^{(0)}_{\\tau_i}(f_\\theta) $$ where $\\mathcal{L}^{(0)}$ is the loss computed using the mini data batch with id (0).\nFig. 11. Diagram of MAML. (Image source: original paper) Well, the above formula only optimizes for one task. To achieve a good generalization across a variety of tasks, we would like to find the optimal $\\theta^*$ so that the task-specific fine-tuning is more efficient. Now, we sample a new data batch with id (1) for updating the meta-objective. The loss, denoted as $\\mathcal{L}^{(1)}$, depends on the mini batch (1). The superscripts in $\\mathcal{L}^{(0)}$ and $\\mathcal{L}^{(1)}$ only indicate different data batches, and they refer to the same loss objective for the same task.\n $$ \\begin{aligned} \\theta^* \u0026= \\arg\\min_\\theta \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta'_i}) = \\arg\\min_\\theta \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta - \\alpha\\nabla_\\theta \\mathcal{L}_{\\tau_i}^{(0)}(f_\\theta)}) \u0026 \\\\ \\theta \u0026\\leftarrow \\theta - \\beta \\nabla_{\\theta} \\sum_{\\tau_i \\sim p(\\tau)} \\mathcal{L}_{\\tau_i}^{(1)} (f_{\\theta - \\alpha\\nabla_\\theta \\mathcal{L}_{\\tau_i}^{(0)}(f_\\theta)}) \u0026 \\scriptstyle{\\text{; updating rule}} \\end{aligned} $$ Fig. 12. The general form of MAML algorithm. (Image source: original paper) First-Order MAML The meta-optimization step above relies on second derivatives. To make the computation less expensive, a modified version of MAML omits second derivatives, resulting in a simplified and cheaper implementation, known as First-Order MAML (FOMAML).\nLet\u0026rsquo;s consider the case of performing $k$ inner gradient steps, $k\\geq1$. Starting with the initial model parameter $\\theta_\\text{meta}$:\n $$ \\begin{aligned} \\theta_0 \u0026= \\theta_\\text{meta}\\\\ \\theta_1 \u0026= \\theta_0 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_0)\\\\ \\theta_2 \u0026= \\theta_1 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_1)\\\\ \u0026\\dots\\\\ \\theta_k \u0026= \\theta_{k-1} - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{k-1}) \\end{aligned} $$ Then in the outer loop, we sample a new data batch for updating the meta-objective.\n $$ \\begin{aligned} \\theta_\\text{meta} \u0026\\leftarrow \\theta_\\text{meta} - \\beta g_\\text{MAML} \u0026 \\scriptstyle{\\text{; update for meta-objective}} \\\\[2mm] \\text{where } g_\\text{MAML} \u0026= \\nabla_{\\theta} \\mathcal{L}^{(1)}(\\theta_k) \u0026\\\\[2mm] \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot (\\nabla_{\\theta_{k-1}} \\theta_k) \\dots (\\nabla_{\\theta_0} \\theta_1) \\cdot (\\nabla_{\\theta} \\theta_0) \u0026 \\scriptstyle{\\text{; following the chain rule}} \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\Big( \\prod_{i=1}^k \\nabla_{\\theta_{i-1}} \\theta_i \\Big) \\cdot I \u0026 \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k \\nabla_{\\theta_{i-1}} (\\theta_{i-1} - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1})) \u0026 \\\\ \u0026= \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k (I - \\alpha\\nabla_{\\theta_{i-1}}(\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1}))) \u0026 \\end{aligned} $$ The MAML gradient is:\n $$ g_\\text{MAML} = \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) \\cdot \\prod_{i=1}^k (I - \\alpha \\color{red}{\\nabla_{\\theta_{i-1}}(\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_{i-1}))}) $$ The First-Order MAML ignores the second derivative part in red. It is simplified as follows, equivalent to the derivative of the last inner gradient update result.\n $$ g_\\text{FOMAML} = \\nabla_{\\theta_k} \\mathcal{L}^{(1)}(\\theta_k) $$ Reptile Reptile (Nichol, Achiam \u0026amp; Schulman, 2018) is a remarkably simple meta-learning optimization algorithm. It is similar to MAML in many ways, given that both rely on meta-optimization through gradient descent and both are model-agnostic.\nThe Reptile works by repeatedly:\n sampling a task, training on it by multiple gradient descent steps, and then moving the model weights towards the new parameters. See the algorithm below: $\\text{SGD}(\\mathcal{L}_{\\tau_i}, \\theta, k)$ performs stochastic gradient update for k steps on the loss $\\mathcal{L}_{\\tau_i}$ starting with initial parameter $\\theta$ and returns the final parameter vector. The batch version samples multiple tasks instead of one within each iteration. The reptile gradient is defined as $(\\theta - W)/\\alpha$, where $\\alpha$ is the stepsize used by the SGD operation.\nFig. 13. The batched version of Reptile algorithm. (Image source: original paper) At a glance, the algorithm looks a lot like an ordinary SGD. However, because the task-specific optimization can take more than one step. it eventually makes $$\\text{SGD}(\\mathbb{E} \\tau[\\mathcal{L}{\\tau}], \\theta, k)$ diverge from $\\mathbb{E}\\tau [\\text{SGD}(\\mathcal{L}{\\tau}, \\theta, k)]$$ when k \u0026gt; 1.\nThe Optimization Assumption Assuming that a task $\\tau \\sim p(\\tau)$ has a manifold of optimal network configuration, $\\mathcal{W}_{\\tau}^*$. The model $f_\\theta$ achieves the best performance for task $\\tau$ when $\\theta$ lays on the surface of $\\mathcal{W}_{\\tau}^*$. To find a solution that is good across tasks, we would like to find a parameter close to all the optimal manifolds of all tasks:\n $$ \\theta^* = \\arg\\min_\\theta \\mathbb{E}_{\\tau \\sim p(\\tau)} [\\frac{1}{2} \\text{dist}(\\theta, \\mathcal{W}_\\tau^*)^2] $$ Fig. 14. The Reptile algorithm updates the parameter alternatively to be closer to the optimal manifolds of different tasks. (Image source: original paper) Let\u0026rsquo;s use the L2 distance as $\\text{dist}(.)$ and the distance between a point $\\theta$ and a set $\\mathcal{W}_\\tau^*$ equals to the distance between $\\theta$ and a point $W_{\\tau}^*(\\theta)$ on the manifold that is closest to $\\theta$:\n $$ \\text{dist}(\\theta, \\mathcal{W}_{\\tau}^*) = \\text{dist}(\\theta, W_{\\tau}^*(\\theta)) \\text{, where }W_{\\tau}^*(\\theta) = \\arg\\min_{W\\in\\mathcal{W}_{\\tau}^*} \\text{dist}(\\theta, W) $$ The gradient of the squared euclidean distance is:\n $$ \\begin{aligned} \\nabla_\\theta[\\frac{1}{2}\\text{dist}(\\theta, \\mathcal{W}_{\\tau_i}^*)^2] \u0026= \\nabla_\\theta[\\frac{1}{2}\\text{dist}(\\theta, W_{\\tau_i}^*(\\theta))^2] \u0026 \\\\ \u0026= \\nabla_\\theta[\\frac{1}{2}(\\theta - W_{\\tau_i}^*(\\theta))^2] \u0026 \\\\ \u0026= \\theta - W_{\\tau_i}^*(\\theta) \u0026 \\scriptstyle{\\text{; See notes.}} \\end{aligned} $$ Notes: According to the Reptile paper, \u0026ldquo;the gradient of the squared euclidean distance between a point Θ and a set S is the vector 2(Θ − p), where p is the closest point in S to Θ\u0026rdquo;. Technically the closest point in S is also a function of Θ, but I\u0026rsquo;m not sure why the gradient does not need to worry about the derivative of p. (Please feel free to leave me a comment or send me an email about this if you have ideas.)\nThus the update rule for one stochastic gradient step is:\n $$ \\theta = \\theta - \\alpha \\nabla_\\theta[\\frac{1}{2} \\text{dist}(\\theta, \\mathcal{W}_{\\tau_i}^*)^2] = \\theta - \\alpha(\\theta - W_{\\tau_i}^*(\\theta)) = (1-\\alpha)\\theta + \\alpha W_{\\tau_i}^*(\\theta) $$ The closest point on the optimal task manifold $W_{\\tau_i}^*(\\theta)$ cannot be computed exactly, but Reptile approximates it using $\\text{SGD}(\\mathcal{L}_\\tau, \\theta, k)$.\nReptile vs FOMAML To demonstrate the deeper connection between Reptile and MAML, let\u0026rsquo;s expand the update formula with an example performing two gradient steps, k=2 in $\\text{SGD}(.)$. Same as defined above, $\\mathcal{L}^{(0)}$ and $\\mathcal{L}^{(1)}$ are losses using different mini-batches of data. For ease of reading, we adopt two simplified annotations: $g^{(i)}_j = \\nabla_{\\theta} \\mathcal{L}^{(i)}(\\theta_j)$ and $H^{(i)}_j = \\nabla^2_{\\theta} \\mathcal{L}^{(i)}(\\theta_j)$.\n $$ \\begin{aligned} \\theta_0 \u0026= \\theta_\\text{meta}\\\\ \\theta_1 \u0026= \\theta_0 - \\alpha\\nabla_\\theta\\mathcal{L}^{(0)}(\\theta_0)= \\theta_0 - \\alpha g^{(0)}_0 \\\\ \\theta_2 \u0026= \\theta_1 - \\alpha\\nabla_\\theta\\mathcal{L}^{(1)}(\\theta_1) = \\theta_0 - \\alpha g^{(0)}_0 - \\alpha g^{(1)}_1 \\end{aligned} $$ According to the early section, the gradient of FOMAML is the last inner gradient update result. Therefore, when k=1:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= \\nabla_{\\theta_1} \\mathcal{L}^{(1)}(\\theta_1) = g^{(1)}_1 \\\\ g_\\text{MAML} \u0026= \\nabla_{\\theta_1} \\mathcal{L}^{(1)}(\\theta_1) \\cdot (I - \\alpha\\nabla^2_{\\theta} \\mathcal{L}^{(0)}(\\theta_0)) = g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\end{aligned} $$ The Reptile gradient is defined as:\n $$ g_\\text{Reptile} = (\\theta_0 - \\theta_2) / \\alpha = g^{(0)}_0 + g^{(1)}_1 $$ Up to now we have:\nFig. 15. Reptile versus FOMAML in one loop of meta-optimization. (Image source: slides on Reptile by Yoonho Lee.) $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g^{(1)}_1 \\\\ g_\\text{MAML} \u0026= g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\\\ g_\\text{Reptile} \u0026= g^{(0)}_0 + g^{(1)}_1 \\end{aligned} $$ Next let\u0026rsquo;s try further expand $g^{(1)}_1$ using Taylor expansion. Recall that Taylor expansion of a function $f(x)$ that is differentiable at a number $a$ is:\n $$ f(x) = f(a) + \\frac{f'(a)}{1!}(x-a) + \\frac{f''(a)}{2!}(x-a)^2 + \\dots = \\sum_{i=0}^\\infty \\frac{f^{(i)}(a)}{i!}(x-a)^i $$ We can consider $\\nabla_{\\theta}\\mathcal{L}^{(1)}(.)$ as a function and $\\theta_0$ as a value point. The Taylor expansion of $g_1^{(1)}$ at the value point $\\theta_0$ is:\n $$ \\begin{aligned} g_1^{(1)} \u0026= \\nabla_{\\theta}\\mathcal{L}^{(1)}(\\theta_1) \\\\ \u0026= \\nabla_{\\theta}\\mathcal{L}^{(1)}(\\theta_0) + \\nabla^2_\\theta\\mathcal{L}^{(1)}(\\theta_0)(\\theta_1 - \\theta_0) + \\frac{1}{2}\\nabla^3_\\theta\\mathcal{L}^{(1)}(\\theta_0)(\\theta_1 - \\theta_0)^2 + \\dots \u0026 \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + \\frac{\\alpha^2}{2}\\nabla^3_\\theta\\mathcal{L}^{(1)}(\\theta_0) (g_0^{(0)})^2 + \\dots \u0026 \\scriptstyle{\\text{; because }\\theta_1-\\theta_0=-\\alpha g_0^{(0)}} \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ Plug in the expanded form of $g_1^{(1)}$ into the MAML gradients with one step inner gradient update:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g^{(1)}_1 = g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ g_\\text{MAML} \u0026= g^{(1)}_1 - \\alpha H^{(0)}_0 g^{(1)}_1 \\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) - \\alpha H^{(0)}_0 (g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2))\\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + \\alpha^2 \\alpha H^{(0)}_0 H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + O(\\alpha^2) \\end{aligned} $$ The Reptile gradient becomes:\n $$ \\begin{aligned} g_\\text{Reptile} \u0026= g^{(0)}_0 + g^{(1)}_1 \\\\ \u0026= g^{(0)}_0 + g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ So far we have the formula of three types of gradients:\n $$ \\begin{aligned} g_\\text{FOMAML} \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2)\\\\ g_\\text{MAML} \u0026= g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} - \\alpha H^{(0)}_0 g_0^{(1)} + O(\\alpha^2)\\\\ g_\\text{Reptile} \u0026= g^{(0)}_0 + g_0^{(1)} - \\alpha H^{(1)}_0 g_0^{(0)} + O(\\alpha^2) \\end{aligned} $$ During training, we often average over multiple data batches. In our example, the mini batches (0) and (1) are interchangeable since both are drawn at random. The expectation $\\mathbb{E}_{\\tau,0,1}$ is averaged over two data batches, ids (0) and (1), for task $\\tau$.\nLet,\n $A = \\mathbb{E}_{\\tau,0,1} [g_0^{(0)}] = \\mathbb{E}_{\\tau,0,1} [g_0^{(1)}]$; it is the average gradient of task loss. We expect to improve the model parameter to achieve better task performance by following this direction pointed by $A$. $B = \\mathbb{E}_{\\tau,0,1} [H^{(1)}_0 g_0^{(0)}] = \\frac{1}{2}\\mathbb{E}_{\\tau,0,1} [H^{(1)}_0 g_0^{(0)} + H^{(0)}_0 g_0^{(1)}] = \\frac{1}{2}\\mathbb{E}_{\\tau,0,1} [\\nabla_\\theta(g^{(0)}_0 g_0^{(1)})]$; it is the direction (gradient) that increases the inner product of gradients of two different mini batches for the same task. We expect to improve the model parameter to achieve better generalization over different data by following this direction pointed by $B$. To conclude, both MAML and Reptile aim to optimize for the same goal, better task performance (guided by A) and better generalization (guided by B), when the gradient update is approximated by first three leading terms.\n $$ \\begin{aligned} \\mathbb{E}_{\\tau,1,2}[g_\\text{FOMAML}] \u0026= A - \\alpha B + O(\\alpha^2)\\\\ \\mathbb{E}_{\\tau,1,2}[g_\\text{MAML}] \u0026= A - 2\\alpha B + O(\\alpha^2)\\\\ \\mathbb{E}_{\\tau,1,2}[g_\\text{Reptile}] \u0026= 2A - \\alpha B + O(\\alpha^2) \\end{aligned} $$ It is not clear to me whether the ignored term $O(\\alpha^2)$ might play a big impact on the parameter learning. But given that FOMAML is able to obtain a similar performance as the full version of MAML, it might be safe to say higher-level derivatives would not be critical during gradient descent update.\n Cited as:\n@article{weng2018metalearning, title = \u0026quot;Meta-Learning: Learning to Learn Fast\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-11-30-meta-learning/\u0026quot; } Reference [1] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. \u0026ldquo;Human-level concept learning through probabilistic program induction.\u0026quot; Science 350.6266 (2015): 1332-1338.\n[2] Oriol Vinyals' talk on \u0026ldquo;Model vs Optimization Meta Learning\u0026rdquo;\n[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. \u0026ldquo;Siamese neural networks for one-shot image recognition.\u0026quot; ICML Deep Learning Workshop. 2015.\n[4] Oriol Vinyals, et al. \u0026ldquo;Matching networks for one shot learning.\u0026quot; NIPS. 2016.\n[5] Flood Sung, et al. \u0026ldquo;Learning to compare: Relation network for few-shot learning.\u0026quot; CVPR. 2018.\n[6] Jake Snell, Kevin Swersky, and Richard Zemel. \u0026ldquo;Prototypical Networks for Few-shot Learning.\u0026quot; CVPR. 2018.\n[7] Adam Santoro, et al. \u0026ldquo;Meta-learning with memory-augmented neural networks.\u0026quot; ICML. 2016.\n[8] Alex Graves, Greg Wayne, and Ivo Danihelka. \u0026ldquo;Neural turing machines.\u0026quot; arXiv preprint arXiv:1410.5401 (2014).\n[9] Tsendsuren Munkhdalai and Hong Yu. \u0026ldquo;Meta Networks.\u0026quot; ICML. 2017.\n[10] Sachin Ravi and Hugo Larochelle. \u0026ldquo;Optimization as a Model for Few-Shot Learning.\u0026quot; ICLR. 2017.\n[11] Chelsea Finn\u0026rsquo;s BAIR blog on \u0026ldquo;Learning to Learn\u0026rdquo;.\n[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine. \u0026ldquo;Model-agnostic meta-learning for fast adaptation of deep networks.\u0026quot; ICML 2017.\n[13] Alex Nichol, Joshua Achiam, John Schulman. \u0026ldquo;On First-Order Meta-Learning Algorithms.\u0026quot; arXiv preprint arXiv:1803.02999 (2018).\n[14] Slides on Reptile by Yoonho Lee.\n","permalink":"https://lilianweng.github.io/posts/2018-11-30-meta-learning/","summary":"[Updated on 2019-10-01: thanks to Tianhao, we have this post translated in Chinese!]\nA good machine learning model often requires training with a large number of samples. Humans, in contrast, learn new concepts and skills much faster and more efficiently. Kids who have seen cats and birds only a few times can quickly tell them apart. People who know how to ride a bike are likely to discover the way to ride a motorcycle fast with little or even no demonstration.","title":"Meta-Learning: Learning to Learn Fast"},{"content":"So far, I\u0026rsquo;ve written about two types of generative models, GAN and VAE. Neither of them explicitly learns the probability density function of real data, $p(\\mathbf{x})$ (where $\\mathbf{x} \\in \\mathcal{D}$) \u0026mdash; because it is really hard! Taking the generative model with latent variables as an example, $p(\\mathbf{x}) = \\int p(\\mathbf{x}\\vert\\mathbf{z})p(\\mathbf{z})d\\mathbf{z}$ can hardly be calculated as it is intractable to go through all possible values of the latent code $\\mathbf{z}$.\nFlow-based deep generative models conquer this hard problem with the help of normalizing flows, a powerful statistics tool for density estimation. A good estimation of $p(\\mathbf{x})$ makes it possible to efficiently complete many downstream tasks: sample unobserved but realistic new data points (data generation), predict the rareness of future events (density estimation), infer latent variables, fill in incomplete data samples, etc.\nTypes of Generative Models Here is a quick summary of the difference between GAN, VAE, and flow-based generative models:\n Generative adversarial networks: GAN provides a smart solution to model the data generation, an unsupervised learning problem, as a supervised one. The discriminator model learns to distinguish the real data from the fake samples that are produced by the generator model. Two models are trained as they are playing a minimax game. Variational autoencoders: VAE inexplicitly optimizes the log-likelihood of the data by maximizing the evidence lower bound (ELBO). Flow-based generative models: A flow-based generative model is constructed by a sequence of invertible transformations. Unlike other two, the model explicitly learns the data distribution $p(\\mathbf{x})$ and therefore the loss function is simply the negative log-likelihood. Fig. 1. Comparison of three categories of generative models. Linear Algebra Basics Recap We should understand two key concepts before getting into the flow-based generative model: the Jacobian determinant and the change of variable rule. Pretty basic, so feel free to skip.\nJacobian Matrix and Determinant Given a function of mapping a $n$-dimensional input vector $\\mathbf{x}$ to a $m$-dimensional output vector, $\\mathbf{f}: \\mathbb{R}^n \\mapsto \\mathbb{R}^m$, the matrix of all first-order partial derivatives of this function is called the Jacobian matrix, $\\mathbf{J}$ where one entry on the i-th row and j-th column is $\\mathbf{J}_{ij} = \\frac{\\partial f_i}{\\partial x_j}$.\n $$ \\mathbf{J} = \\begin{bmatrix} \\frac{\\partial f_1}{\\partial x_1} \u0026 \\dots \u0026 \\frac{\\partial f_1}{\\partial x_n} \\\\[6pt] \\vdots \u0026 \\ddots \u0026 \\vdots \\\\[6pt] \\frac{\\partial f_m}{\\partial x_1} \u0026 \\dots \u0026 \\frac{\\partial f_m}{\\partial x_n} \\\\[6pt] \\end{bmatrix} $$ The determinant is one real number computed as a function of all the elements in a squared matrix. Note that the determinant only exists for square matrices. The absolute value of the determinant can be thought of as a measure of \u0026ldquo;how much multiplication by the matrix expands or contracts space\u0026rdquo;.\nThe determinant of a nxn matrix $M$ is:\n $$ \\det M = \\det \\begin{bmatrix} a_{11} \u0026 a_{12} \u0026 \\dots \u0026 a_{1n} \\\\ a_{21} \u0026 a_{22} \u0026 \\dots \u0026 a_{2n} \\\\ \\vdots \u0026 \\vdots \u0026 \u0026 \\vdots \\\\ a_{n1} \u0026 a_{n2} \u0026 \\dots \u0026 a_{nn} \\\\ \\end{bmatrix} = \\sum_{j_1 j_2 \\dots j_n} (-1)^{\\tau(j_1 j_2 \\dots j_n)} a_{1j_1} a_{2j_2} \\dots a_{nj_n} $$ where the subscript under the summation $j_1 j_2 \\dots j_n$ are all permutations of the set {1, 2, \u0026hellip;, n}, so there are $n!$ items in total; $\\tau(.)$ indicates the signature of a permutation.\nThe determinant of a square matrix $M$ detects whether it is invertible: If $\\det(M)=0$ then $M$ is not invertible (a singular matrix with linearly dependent rows or columns; or any row or column is all 0); otherwise, if $\\det(M)\\neq 0$, $M$ is invertible.\nThe determinant of the product is equivalent to the product of the determinants: $\\det(AB) = \\det(A)\\det(B)$. (proof)\nChange of Variable Theorem Let\u0026rsquo;s review the change of variable theorem specifically in the context of probability density estimation, starting with a single variable case.\nGiven a random variable $z$ and its known probability density function $z \\sim \\pi(z)$, we would like to construct a new random variable using a 1-1 mapping function $x = f(z)$. The function $f$ is invertible, so $z=f^{-1}(x)$. Now the question is how to infer the unknown probability density function of the new variable, $p(x)$?\n $$ \\begin{aligned} \u0026 \\int p(x)dx = \\int \\pi(z)dz = 1 \\scriptstyle{\\text{ ; Definition of probability distribution.}}\\\\ \u0026 p(x) = \\pi(z) \\left\\vert\\frac{dz}{dx}\\right\\vert = \\pi(f^{-1}(x)) \\left\\vert\\frac{d f^{-1}}{dx}\\right\\vert = \\pi(f^{-1}(x)) \\vert (f^{-1})'(x) \\vert \\end{aligned} $$ By definition, the integral $\\int \\pi(z)dz$ is the sum of an infinite number of rectangles of infinitesimal width $\\Delta z$. The height of such a rectangle at position $z$ is the value of the density function $\\pi(z)$. When we substitute the variable, $z = f^{-1}(x)$ yields $\\frac{\\Delta z}{\\Delta x} = (f^{-1}(x))'$ and $\\Delta z = (f^{-1}(x))' \\Delta x$. Here $\\vert(f^{-1}(x))'\\vert$ indicates the ratio between the area of rectangles defined in two different coordinate of variables $z$ and $x$ respectively.\nThe multivariable version has a similar format:\n $$ \\begin{aligned} \\mathbf{z} \u0026\\sim \\pi(\\mathbf{z}), \\mathbf{x} = f(\\mathbf{z}), \\mathbf{z} = f^{-1}(\\mathbf{x}) \\\\ p(\\mathbf{x}) \u0026= \\pi(\\mathbf{z}) \\left\\vert \\det \\dfrac{d \\mathbf{z}}{d \\mathbf{x}} \\right\\vert = \\pi(f^{-1}(\\mathbf{x})) \\left\\vert \\det \\dfrac{d f^{-1}}{d \\mathbf{x}} \\right\\vert \\end{aligned} $$ where $\\det \\frac{\\partial f}{\\partial\\mathbf{z}}$ is the Jacobian determinant of the function $f$. The full proof of the multivariate version is out of the scope of this post; ask Google if interested ;)\nWhat is Normalizing Flows? Being able to do good density estimation has direct applications in many machine learning problems, but it is very hard. For example, since we need to run backward propagation in deep learning models, the embedded probability distribution (i.e. posterior $p(\\mathbf{z}\\vert\\mathbf{x})$) is expected to be simple enough to calculate the derivative easily and efficiently. That is why Gaussian distribution is often used in latent variable generative models, even though most of real world distributions are much more complicated than Gaussian.\nHere comes a Normalizing Flow (NF) model for better and more powerful distribution approximation. A normalizing flow transforms a simple distribution into a complex one by applying a sequence of invertible transformation functions. Flowing through a chain of transformations, we repeatedly substitute the variable for the new one according to the change of variables theorem and eventually obtain a probability distribution of the final target variable.\nFig. 2. Illustration of a normalizing flow model, transforming a simple distribution $p\\_0(\\mathbf{z}\\_0)$ to a complex one $p\\_K(\\mathbf{z}\\_K)$ step by step. As defined in Fig. 2,\n $$ \\begin{aligned} \\mathbf{z}_{i-1} \u0026\\sim p_{i-1}(\\mathbf{z}_{i-1}) \\\\ \\mathbf{z}_i \u0026= f_i(\\mathbf{z}_{i-1})\\text{, thus }\\mathbf{z}_{i-1} = f_i^{-1}(\\mathbf{z}_i) \\\\ p_i(\\mathbf{z}_i) \u0026= p_{i-1}(f_i^{-1}(\\mathbf{z}_i)) \\left\\vert \\det\\dfrac{d f_i^{-1}}{d \\mathbf{z}_i} \\right\\vert \\end{aligned} $$ Then let\u0026rsquo;s convert the equation to be a function of $\\mathbf{z}_i$ so that we can do inference with the base distribution.\n $$ \\begin{aligned} p_i(\\mathbf{z}_i) \u0026= p_{i-1}(f_i^{-1}(\\mathbf{z}_i)) \\left\\vert \\det\\dfrac{d f_i^{-1}}{d \\mathbf{z}_i} \\right\\vert \\\\ \u0026= p_{i-1}(\\mathbf{z}_{i-1}) \\left\\vert \\det \\color{red}{\\Big(\\dfrac{d f_i}{d\\mathbf{z}_{i-1}}\\Big)^{-1}} \\right\\vert \u0026 \\scriptstyle{\\text{; According to the inverse func theorem.}} \\\\ \u0026= p_{i-1}(\\mathbf{z}_{i-1}) \\color{red}{\\left\\vert \\det \\dfrac{d f_i}{d\\mathbf{z}_{i-1}} \\right\\vert^{-1}} \u0026 \\scriptstyle{\\text{; According to a property of Jacobians of invertible func.}} \\\\ \\log p_i(\\mathbf{z}_i) \u0026= \\log p_{i-1}(\\mathbf{z}_{i-1}) - \\log \\left\\vert \\det \\dfrac{d f_i}{d\\mathbf{z}_{i-1}} \\right\\vert \\end{aligned} $$ (*) A note on the \u0026ldquo;inverse function theorem\u0026rdquo;: If $y=f(x)$ and $x=f^{-1}(y)$, we have:\n $$ \\dfrac{df^{-1}(y)}{dy} = \\dfrac{dx}{dy} = (\\dfrac{dy}{dx})^{-1} = (\\dfrac{df(x)}{dx})^{-1} $$ (*) A note on \u0026ldquo;Jacobians of invertible function\u0026rdquo;: The determinant of the inverse of an invertible matrix is the inverse of the determinant: $\\det(M^{-1}) = (\\det(M))^{-1}$, because $\\det(M)\\det(M^{-1}) = \\det(M \\cdot M^{-1}) = \\det(I) = 1$.\nGiven such a chain of probability density functions, we know the relationship between each pair of consecutive variables. We can expand the equation of the output $\\mathbf{x}$ step by step until tracing back to the initial distribution $\\mathbf{z}_0$.\n $$ \\begin{aligned} \\mathbf{x} = \\mathbf{z}_K \u0026= f_K \\circ f_{K-1} \\circ \\dots \\circ f_1 (\\mathbf{z}_0) \\\\ \\log p(\\mathbf{x}) = \\log \\pi_K(\\mathbf{z}_K) \u0026= \\log \\pi_{K-1}(\\mathbf{z}_{K-1}) - \\log\\left\\vert\\det\\dfrac{d f_K}{d \\mathbf{z}_{K-1}}\\right\\vert \\\\ \u0026= \\log \\pi_{K-2}(\\mathbf{z}_{K-2}) - \\log\\left\\vert\\det\\dfrac{d f_{K-1}}{d\\mathbf{z}_{K-2}}\\right\\vert - \\log\\left\\vert\\det\\dfrac{d f_K}{d\\mathbf{z}_{K-1}}\\right\\vert \\\\ \u0026= \\dots \\\\ \u0026= \\log \\pi_0(\\mathbf{z}_0) - \\sum_{i=1}^K \\log\\left\\vert\\det\\dfrac{d f_i}{d\\mathbf{z}_{i-1}}\\right\\vert \\end{aligned} $$ The path traversed by the random variables $\\mathbf{z}_i = f_i(\\mathbf{z}_{i-1})$ is the flow and the full chain formed by the successive distributions $\\pi_i$ is called a normalizing flow. Required by the computation in the equation, a transformation function $f_i$ should satisfy two properties:\n It is easily invertible. Its Jacobian determinant is easy to compute. Models with Normalizing Flows With normalizing flows in our toolbox, the exact log-likelihood of input data $\\log p(\\mathbf{x})$ becomes tractable. As a result, the training criterion of flow-based generative model is simply the negative log-likelihood (NLL) over the training dataset $\\mathcal{D}$:\n $$ \\mathcal{L}(\\mathcal{D}) = - \\frac{1}{\\vert\\mathcal{D}\\vert}\\sum_{\\mathbf{x} \\in \\mathcal{D}} \\log p(\\mathbf{x}) $$ RealNVP The RealNVP (Real-valued Non-Volume Preserving; Dinh et al., 2017) model implements a normalizing flow by stacking a sequence of invertible bijective transformation functions. In each bijection $f: \\mathbf{x} \\mapsto \\mathbf{y}$, known as affine coupling layer, the input dimensions are split into two parts:\n The first $d$ dimensions stay same; The second part, $d+1$ to $D$ dimensions, undergo an affine transformation (\u0026ldquo;scale-and-shift\u0026rdquo;) and both the scale and shift parameters are functions of the first $d$ dimensions. $$ \\begin{aligned} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} \\odot \\exp({s(\\mathbf{x}_{1:d})}) + t(\\mathbf{x}_{1:d}) \\end{aligned} $$ where $s(.)$ and $t(.)$ are scale and translation functions and both map $\\mathbb{R}^d \\mapsto \\mathbb{R}^{D-d}$. The $\\odot$ operation is the element-wise product.\nNow let\u0026rsquo;s check whether this transformation satisfy two basic properties for a flow transformation.\nCondition 1: \u0026ldquo;It is easily invertible.\u0026rdquo;\nYes and it is fairly straightforward.\n $$ \\begin{cases} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} \\odot \\exp({s(\\mathbf{x}_{1:d})}) + t(\\mathbf{x}_{1:d}) \\end{cases} \\Leftrightarrow \\begin{cases} \\mathbf{x}_{1:d} \u0026= \\mathbf{y}_{1:d} \\\\ \\mathbf{x}_{d+1:D} \u0026= (\\mathbf{y}_{d+1:D} - t(\\mathbf{y}_{1:d})) \\odot \\exp(-s(\\mathbf{y}_{1:d})) \\end{cases} $$ Condition 2: \u0026ldquo;Its Jacobian determinant is easy to compute.\u0026rdquo;\nYes. It is not hard to get the Jacobian matrix and determinant of this transformation. The Jacobian is a lower triangular matrix.\n $$ \\mathbf{J} = \\begin{bmatrix} \\mathbb{I}_d \u0026 \\mathbf{0}_{d\\times(D-d)} \\\\[5pt] \\frac{\\partial \\mathbf{y}_{d+1:D}}{\\partial \\mathbf{x}_{1:d}} \u0026 \\text{diag}(\\exp(s(\\mathbf{x}_{1:d}))) \\end{bmatrix} $$ Hence the determinant is simply the product of terms on the diagonal.\n $$ \\det(\\mathbf{J}) = \\prod_{j=1}^{D-d}\\exp(s(\\mathbf{x}_{1:d}))_j = \\exp(\\sum_{j=1}^{D-d} s(\\mathbf{x}_{1:d})_j) $$ So far, the affine coupling layer looks perfect for constructing a normalizing flow :)\nEven better, since (i) computing $f^-1$ does not require computing the inverse of $s$ or $t$ and (ii) computing the Jacobian determinant does not involve computing the Jacobian of $s$ or $t$, those functions can be arbitrarily complex; i.e. both $s$ and $t$ can be modeled by deep neural networks.\nIn one affine coupling layer, some dimensions (channels) remain unchanged. To make sure all the inputs have a chance to be altered, the model reverses the ordering in each layer so that different components are left unchanged. Following such an alternating pattern, the set of units which remain identical in one transformation layer are always modified in the next. Batch normalization is found to help training models with a very deep stack of coupling layers.\nFurthermore, RealNVP can work in a multi-scale architecture to build a more efficient model for large inputs. The multi-scale architecture applies several \u0026ldquo;sampling\u0026rdquo; operations to normal affine layers, including spatial checkerboard pattern masking, squeezing operation, and channel-wise masking. Read the paper for more details on the multi-scale architecture.\nNICE The NICE (Non-linear Independent Component Estimation; Dinh, et al. 2015) model is a predecessor of RealNVP. The transformation in NICE is the affine coupling layer without the scale term, known as additive coupling layer.\n $$ \\begin{cases} \\mathbf{y}_{1:d} \u0026= \\mathbf{x}_{1:d} \\\\ \\mathbf{y}_{d+1:D} \u0026= \\mathbf{x}_{d+1:D} + m(\\mathbf{x}_{1:d}) \\end{cases} \\Leftrightarrow \\begin{cases} \\mathbf{x}_{1:d} \u0026= \\mathbf{y}_{1:d} \\\\ \\mathbf{x}_{d+1:D} \u0026= \\mathbf{y}_{d+1:D} - m(\\mathbf{y}_{1:d}) \\end{cases} $$ Glow The Glow (Kingma and Dhariwal, 2018) model extends the previous reversible generative models, NICE and RealNVP, and simplifies the architecture by replacing the reverse permutation operation on the channel ordering with invertible 1x1 convolutions.\nFig. 3. One step of flow in the Glow model. (Image source: Kingma and Dhariwal, 2018) There are three substeps in one step of flow in Glow.\nSubstep 1: Activation normalization (short for \u0026ldquo;actnorm\u0026rdquo;)\nIt performs an affine transformation using a scale and bias parameter per channel, similar to batch normalization, but works for mini-batch size 1. The parameters are trainable but initialized so that the first minibatch of data have mean 0 and standard deviation 1 after actnorm.\nSubstep 2: Invertible 1x1 conv\nBetween layers of the RealNVP flow, the ordering of channels is reversed so that all the data dimensions have a chance to be altered. A 1×1 convolution with equal number of input and output channels is a generalization of any permutation of the channel ordering.\nSay, we have an invertible 1x1 convolution of an input $h \\times w \\times c$ tensor $\\mathbf{h}$ with a weight matrix $\\mathbf{W}$ of size $c \\times c$. The output is a $h \\times w \\times c$ tensor, labeled as $f = \\texttt{conv2d}(\\mathbf{h}; \\mathbf{W})$. In order to apply the change of variable rule, we need to compute the Jacobian determinant $\\vert \\det\\partial f / \\partial\\mathbf{h}\\vert$.\nBoth the input and output of 1x1 convolution here can be viewed as a matrix of size $h \\times w$. Each entry $\\mathbf{x}_{ij}$ ($i=1,\\dots,h, j=1,\\dots,w$) in $\\mathbf{h}$ is a vector of $c$ channels and each entry is multiplied by the weight matrix $\\mathbf{W}$ to obtain the corresponding entry $\\mathbf{y}_{ij}$ in the output matrix respectively. The derivative of each entry is $\\partial \\mathbf{x}_{ij} \\mathbf{W} / \\partial\\mathbf{x}_{ij} = \\mathbf{W}$ and there are $h \\times w$ such entries in total:\n $$ \\log \\left\\vert\\det \\frac{\\partial\\texttt{conv2d}(\\mathbf{h}; \\mathbf{W})}{\\partial\\mathbf{h}}\\right\\vert = \\log (\\vert\\det\\mathbf{W}\\vert^{h \\cdot w}\\vert) = h \\cdot w \\cdot \\log \\vert\\det\\mathbf{W}\\vert $$ The inverse 1x1 convolution depends on the inverse matrix $\\mathbf{W}^{-1}$. Since the weight matrix is relatively small, the amount of computation for the matrix determinant (tf.linalg.det) and inversion (tf.linalg.inv) is still under control.\nSubstep 3: Affine coupling layer\nThe design is same as in RealNVP.\nFig. 4. Three substeps in one step of flow in Glow. (Image source: Kingma and Dhariwal, 2018) Models with Autoregressive Flows The autoregressive constraint is a way to model sequential data, $\\mathbf{x} = [x_1, \\dots, x_D]$: each output only depends on the data observed in the past, but not on the future ones. In other words, the probability of observing $x_i$ is conditioned on $x_1, \\dots, x_{i-1}$ and the product of these conditional probabilities gives us the probability of observing the full sequence:\n $$ p(\\mathbf{x}) = \\prod_{i=1}^{D} p(x_i\\vert x_1, \\dots, x_{i-1}) = \\prod_{i=1}^{D} p(x_i\\vert x_{1:i-1}) $$ How to model the conditional density is of your choice. It can be a univariate Gaussian with mean and standard deviation computed as a function of $x_{1:i-1}$, or a multilayer neural network with $x_{1:i-1}$ as the input.\nIf a flow transformation in a normalizing flow is framed as an autoregressive model \u0026mdash; each dimension in a vector variable is conditioned on the previous dimensions \u0026mdash; this is an autoregressive flow.\nThis section starts with several classic autoregressive models (MADE, PixelRNN, WaveNet) and then we dive into autoregressive flow models (MAF and IAF).\nMADE MADE (Masked Autoencoder for Distribution Estimation; Germain et al., 2015) is a specially designed architecture to enforce the autoregressive property in the autoencoder efficiently. When using an autoencoder to predict the conditional probabilities, rather than feeding the autoencoder with input of different observation windows $D$ times, MADE removes the contribution from certain hidden units by multiplying binary mask matrices so that each input dimension is reconstructed only from previous dimensions in a given ordering in a single pass.\nIn a multilayer fully-connected neural network, say, we have $L$ hidden layers with weight matrices $\\mathbf{W}^1, \\dots, \\mathbf{W}^L$ and an output layer with weight matrix $\\mathbf{V}$. The output $\\hat{\\mathbf{x}}$ has each dimension $\\hat{x}_i = p(x_i\\vert x_{1:i-1})$.\nWithout any mask, the computation through layers looks like the following:\n $$ \\begin{aligned} \\mathbf{h}^0 \u0026= \\mathbf{x} \\\\ \\mathbf{h}^l \u0026= \\text{activation}^l(\\mathbf{W}^l\\mathbf{h}^{l-1} + \\mathbf{b}^l) \\\\ \\hat{\\mathbf{x}} \u0026= \\sigma(\\mathbf{V}\\mathbf{h}^L + \\mathbf{c}) \\end{aligned} $$ Fig. 5. Demonstration of how MADE works in a three-layer feed-forward neural network. (Image source: Germain et al., 2015) To zero out some connections between layers, we can simply element-wise multiply every weight matrix by a binary mask matrix. Each hidden node is assigned with a random \u0026ldquo;connectivity integer\u0026rdquo; between $1$ and $D-1$; the assigned value for the $k$-th unit in the $l$-th layer is denoted by $m^l_k$. The binary mask matrix is determined by element-wise comparing values of two nodes in two layers.\n $$ \\begin{aligned} \\mathbf{h}^l \u0026= \\text{activation}^l((\\mathbf{W}^l \\color{red}{\\odot \\mathbf{M}^{\\mathbf{W}^l}}) \\mathbf{h}^{l-1} + \\mathbf{b}^l) \\\\ \\hat{\\mathbf{x}} \u0026= \\sigma((\\mathbf{V} \\color{red}{\\odot \\mathbf{M}^{\\mathbf{V}}}) \\mathbf{h}^L + \\mathbf{c}) \\\\ M^{\\mathbf{W}^l}_{k', k} \u0026= \\mathbf{1}_{m^l_{k'} \\geq m^{l-1}_k} = \\begin{cases} 1, \u0026 \\text{if } m^l_{k'} \\geq m^{l-1}_k\\\\ 0, \u0026 \\text{otherwise} \\end{cases} \\\\ M^{\\mathbf{V}}_{d, k} \u0026= \\mathbf{1}_{d \\geq m^L_k} = \\begin{cases} 1, \u0026 \\text{if } d m^L_k\\\\ 0, \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ A unit in the current layer can only be connected to other units with equal or smaller numbers in the previous layer and this type of dependency easily propagates through the network up to the output layer. Once the numbers are assigned to all the units and layers, the ordering of input dimensions is fixed and the conditional probability is produced with respect to it. See a great illustration in Fig. 5. To make sure all the hidden units are connected to the input and output layers through some paths, the $m^l_k$ is sampled to be equal or greater than the minimal connectivity integer in the previous layer, $\\min_{k'} m_{k'}^{l-1}$.\nMADE training can be further facilitated by:\n Order-agnostic training: shuffle the input dimensions, so that MADE is able to model any arbitrary ordering; can create an ensemble of autoregressive models at the runtime. Connectivity-agnostic training: to avoid a model being tied up to a specific connectivity pattern constraints, resample $m^l_k$ for each training minibatch. PixelRNN PixelRNN (Oord et al, 2016) is a deep generative model for images. The image is generated one pixel at a time and each new pixel is sampled conditional on the pixels that have been seen before.\nLet\u0026rsquo;s consider an image of size $n \\times n$, $\\mathbf{x} = \\{x_1, \\dots, x_{n^2}\\}$, the model starts generating pixels from the top left corner, from left to right and top to bottom (See Fig. 6).\nFig. 6. The context for generating one pixel in PixelRNN. (Image source: Oord et al, 2016) Every pixel $x_i$ is sampled from a probability distribution conditional over the the past context: pixels above it or on the left of it when in the same row. The definition of such context looks pretty arbitrary, because how visual attention is attended to an image is more flexible. Somehow magically a generative model with such a strong assumption works.\nOne implementation that could capture the entire context is the Diagonal BiLSTM. First, apply the skewing operation by offsetting each row of the input feature map by one position with respect to the previous row, so that computation for each row can be parallelized. Then the LSTM states are computed with respect to the current pixel and the pixels on the left.\nFig. 7. (a) PixelRNN with diagonal BiLSTM. (b) Skewing operation that offsets each row in the feature map by one with regards to the row above. (Image source: Oord et al, 2016) $$ \\begin{aligned} \\lbrack \\mathbf{o}_i, \\mathbf{f}_i, \\mathbf{i}_i, \\mathbf{g}_i \\rbrack \u0026= \\sigma(\\mathbf{K}^{ss} \\circledast \\mathbf{h}_{i-1} + \\mathbf{K}^{is} \\circledast \\mathbf{x}_i) \u0026 \\scriptstyle{\\text{; }\\sigma\\scriptstyle{\\text{ is tanh for g, but otherwise sigmoid; }}\\circledast\\scriptstyle{\\text{ is convolution operation.}}} \\\\ \\mathbf{c}_i \u0026= \\mathbf{f}_i \\odot \\mathbf{c}_{i-1} + \\mathbf{i}_i \\odot \\mathbf{g}_i \u0026 \\scriptstyle{\\text{; }}\\odot\\scriptstyle{\\text{ is elementwise product.}}\\\\ \\mathbf{h}_i \u0026= \\mathbf{o}_i \\odot \\tanh(\\mathbf{c}_i) \\end{aligned} $$ where $\\circledast$ denotes the convolution operation and $\\odot$ is the element-wise multiplication. The input-to-state component $\\mathbf{K}^{is}$ is a 1x1 convolution, while the state-to-state recurrent component is computed with a column-wise convolution $\\mathbf{K}^{ss}$ with a kernel of size 2x1.\nThe diagonal BiLSTM layers are capable of processing an unbounded context field, but expensive to compute due to the sequential dependency between states. A faster implementation uses multiple convolutional layers without pooling to define a bounded context box. The convolution kernel is masked so that the future context is not seen, similar to MADE. This convolution version is called PixelCNN.\nFig. 8. PixelCNN with masked convolution constructed by an elementwise product of a mask tensor and the convolution kernel before applying it. (Image source: http://slazebni.cs.illinois.edu/spring17/lec13_advanced.pdf) WaveNet WaveNet (Van Den Oord, et al. 2016) is very similar to PixelCNN but applied to 1-D audio signals. WaveNet consists of a stack of causal convolution which is a convolution operation designed to respect the ordering: the prediction at a certain timestamp can only consume the data observed in the past, no dependency on the future. In PixelCNN, the causal convolution is implemented by masked convolution kernel. The causal convolution in WaveNet is simply to shift the output by a number of timestamps to the future so that the output is aligned with the last input element.\nOne big drawback of convolution layer is a very limited size of receptive field. The output can hardly depend on the input hundreds or thousands of timesteps ago, which can be a crucial requirement for modeling long sequences. WaveNet therefore adopts dilated convolution (animation), where the kernel is applied to an evenly-distributed subset of samples in a much larger receptive field of the input.\nFig. 9. Visualization of WaveNet models with a stack of (top) causal convolution layers and (bottom) dilated convolution layers. (Image source: Van Den Oord, et al. 2016) WaveNet uses the gated activation unit as the non-linear layer, as it is found to work significantly better than ReLU for modeling 1-D audio data. The residual connection is applied after the gated activation.\n $$ \\mathbf{z} = \\tanh(\\mathbf{W}_{f,k}\\circledast\\mathbf{x})\\odot\\sigma(\\mathbf{W}_{g,k}\\circledast\\mathbf{x}) $$ where $\\mathbf{W}_{f,k}$ and $\\mathbf{W}_{g,k}$ are convolution filter and gate weight matrix of the $k$-th layer, respectively; both are learnable.\nMasked Autoregressive Flow Masked Autoregressive Flow (MAF; Papamakarios et al., 2017) is a type of normalizing flows, where the transformation layer is built as an autoregressive neural network. MAF is very similar to Inverse Autoregressive Flow (IAF) introduced later. See more discussion on the relationship between MAF and IAF in the next section.\nGiven two random variables, $\\mathbf{z} \\sim \\pi(\\mathbf{z})$ and $\\mathbf{x} \\sim p(\\mathbf{x})$ and the probability density function $\\pi(\\mathbf{z})$ is known, MAF aims to learn $p(\\mathbf{x})$. MAF generates each $x_i$ conditioned on the past dimensions $\\mathbf{x}_{1:i-1}$.\nPrecisely the conditional probability is an affine transformation of $\\mathbf{z}$, where the scale and shift terms are functions of the observed part of $\\mathbf{x}$.\n Data generation, producing a new $\\mathbf{x}$: $x_i \\sim p(x_i\\vert\\mathbf{x}_{1:i-1}) = z_i \\odot \\sigma_i(\\mathbf{x}_{1:i-1}) + \\mu_i(\\mathbf{x}_{1:i-1})\\text{, where }\\mathbf{z} \\sim \\pi(\\mathbf{z})$\n Density estimation, given a known $\\mathbf{x}$: $p(\\mathbf{x}) = \\prod_{i=1}^D p(x_i\\vert\\mathbf{x}_{1:i-1})$\nThe generation procedure is sequential, so it is slow by design. While density estimation only needs one pass the network using architecture like MADE. The transformation function is trivial to inverse and the Jacobian determinant is easy to compute too.\nInverse Autoregressive Flow Similar to MAF, Inverse autoregressive flow (IAF; Kingma et al., 2016) models the conditional probability of the target variable as an autoregressive model too, but with a reversed flow, thus achieving a much efficient sampling process.\nFirst, let\u0026rsquo;s reverse the affine transformation in MAF:\n $$ z_i = \\frac{x_i - \\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} = -\\frac{\\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} + x_i \\odot \\frac{1}{\\sigma_i(\\mathbf{x}_{1:i-1})} $$ If let:\n $$ \\begin{aligned} \u0026 \\tilde{\\mathbf{x}} = \\mathbf{z}\\text{, }\\tilde{p}(.) = \\pi(.)\\text{, }\\tilde{\\mathbf{x}} \\sim \\tilde{p}(\\tilde{\\mathbf{x}}) \\\\ \u0026 \\tilde{\\mathbf{z}} = \\mathbf{x} \\text{, }\\tilde{\\pi}(.) = p(.)\\text{, }\\tilde{\\mathbf{z}} \\sim \\tilde{\\pi}(\\tilde{\\mathbf{z}})\\\\ \u0026 \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1}) = \\tilde{\\mu}_i(\\mathbf{x}_{1:i-1}) = -\\frac{\\mu_i(\\mathbf{x}_{1:i-1})}{\\sigma_i(\\mathbf{x}_{1:i-1})} \\\\ \u0026 \\tilde{\\sigma}(\\tilde{\\mathbf{z}}_{1:i-1}) = \\tilde{\\sigma}(\\mathbf{x}_{1:i-1}) = \\frac{1}{\\sigma_i(\\mathbf{x}_{1:i-1})} \\end{aligned} $$ Then we would have,\n $$ \\tilde{x}_i \\sim p(\\tilde{x}_i\\vert\\tilde{\\mathbf{z}}_{1:i}) = \\tilde{z}_i \\odot \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1}) + \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1}) \\text{, where }\\tilde{\\mathbf{z}} \\sim \\tilde{\\pi}(\\tilde{\\mathbf{z}}) $$ IAF intends to estimate the probability density function of $\\tilde{\\mathbf{x}}$ given that $\\tilde{\\pi}(\\tilde{\\mathbf{z}})$ is already known. The inverse flow is an autoregressive affine transformation too, same as in MAF, but the scale and shift terms are autoregressive functions of observed variables from the known distribution $\\tilde{\\pi}(\\tilde{\\mathbf{z}})$. See the comparison between MAF and IAF in Fig. 10.\nFig. 10. Comparison of MAF and IAF. The variable with known density is in green while the unknown one is in red. Computations of the individual elements $\\tilde{x}_i$ do not depend on each other, so they are easily parallelizable (only one pass using MADE). The density estimation for a known $\\tilde{\\mathbf{x}}$ is not efficient, because we have to recover the value of $\\tilde{z}_i$ in a sequential order, $\\tilde{z}_i = (\\tilde{x}_i - \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1})) / \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1})$, thus D times in total.\n Base distribution Target distribution Model Data generation Density estimation MAF $\\mathbf{z}\\sim\\pi(\\mathbf{z})$ $\\mathbf{x}\\sim p(\\mathbf{x})$ $x_i = z_i \\odot \\sigma_i(\\mathbf{x}_{1:i-1}) + \\mu_i(\\mathbf{x}_{1:i-1})$ Sequential; slow One pass; fast IAF $\\tilde{\\mathbf{z}}\\sim\\tilde{\\pi}(\\tilde{\\mathbf{z}})$ $\\tilde{\\mathbf{x}}\\sim\\tilde{p}(\\tilde{\\mathbf{x}})$ $\\tilde{x}_i = \\tilde{z}_i \\odot \\tilde{\\sigma}_i(\\tilde{\\mathbf{z}}_{1:i-1}) + \\tilde{\\mu}_i(\\tilde{\\mathbf{z}}_{1:i-1})$ One pass; fast Sequential; slow \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- \u0026mdash;\u0026mdash;\u0026mdash;- VAE + Flows In Variational Autoencoder, if we want to model the posterior $p(\\mathbf{z}\\vert\\mathbf{x})$ as a more complicated distribution rather than simple Gaussian. Intuitively we can use normalizing flow to transform the base Gaussian for better density approximation. The encoder then would predict a set of scale and shift terms $(\\mu_i, \\sigma_i)$ which are all functions of input $\\mathbf{x}$. Read the paper for more details if interested.\n If you notice mistakes and errors in this post, don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be very happy to correct them right away!\nSee you in the next post :D\n Cited as:\n@article{weng2018flow, title = \u0026quot;Flow-based Deep Generative Models\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-10-13-flow-models/\u0026quot; } Reference [1] Danilo Jimenez Rezende, and Shakir Mohamed. \u0026ldquo;Variational inference with normalizing flows.\u0026quot; ICML 2015.\n[2] Normalizing Flows Tutorial, Part 1: Distributions and Determinants by Eric Jang.\n[3] Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows by Eric Jang.\n[4] Normalizing Flows by Adam Kosiorek.\n[5] Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. \u0026ldquo;Density estimation using Real NVP.\u0026quot; ICLR 2017.\n[6] Laurent Dinh, David Krueger, and Yoshua Bengio. \u0026ldquo;NICE: Non-linear independent components estimation.\u0026quot; ICLR 2015 Workshop track.\n[7] Diederik P. Kingma, and Prafulla Dhariwal. \u0026ldquo;Glow: Generative flow with invertible 1x1 convolutions.\u0026quot; arXiv:1807.03039 (2018).\n[8] Germain, Mathieu, Karol Gregor, Iain Murray, and Hugo Larochelle. \u0026ldquo;Made: Masked autoencoder for distribution estimation.\u0026quot; ICML 2015.\n[9] Aaron van den Oord, Nal Kalchbrenner, and Koray Kavukcuoglu. \u0026ldquo;Pixel recurrent neural networks.\u0026quot; ICML 2016.\n[10] Diederik P. Kingma, et al. \u0026ldquo;Improved variational inference with inverse autoregressive flow.\u0026quot; NIPS. 2016.\n[11] George Papamakarios, Iain Murray, and Theo Pavlakou. \u0026ldquo;Masked autoregressive flow for density estimation.\u0026quot; NIPS 2017.\n[12] Jianlin Su, and Guang Wu. \u0026ldquo;f-VAEs: Improve VAEs with Conditional Flows.\u0026quot; arXiv:1809.05861 (2018).\n[13] Van Den Oord, Aaron, et al. \u0026ldquo;WaveNet: A generative model for raw audio.\u0026quot; SSW. 2016.\n","permalink":"https://lilianweng.github.io/posts/2018-10-13-flow-models/","summary":"So far, I\u0026rsquo;ve written about two types of generative models, GAN and VAE. Neither of them explicitly learns the probability density function of real data, $p(\\mathbf{x})$ (where $\\mathbf{x} \\in \\mathcal{D}$) \u0026mdash; because it is really hard! Taking the generative model with latent variables as an example, $p(\\mathbf{x}) = \\int p(\\mathbf{x}\\vert\\mathbf{z})p(\\mathbf{z})d\\mathbf{z}$ can hardly be calculated as it is intractable to go through all possible values of the latent code $\\mathbf{z}$.\nFlow-based deep generative models conquer this hard problem with the help of normalizing flows, a powerful statistics tool for density estimation.","title":"Flow-based Deep Generative Models"},{"content":"[Updated on 2019-07-18: add a section on VQ-VAE \u0026amp; VQ-VAE-2.] [Updated on 2019-07-26: add a section on TD-VAE.] \nAutocoder is invented to reconstruct high-dimensional data using a neural network model with a narrow bottleneck layer in the middle (oops, this is probably not true for Variational Autoencoder, and we will investigate it in details in later sections). A nice byproduct is dimension reduction: the bottleneck layer captures a compressed latent encoding. Such a low-dimensional representation can be used as en embedding vector in various applications (i.e. search), help data compression, or reveal the underlying data generative factors.\nNotation Symbol Mean $\\mathcal{D}$ The dataset, $\\mathcal{D} = \\{ \\mathbf{x}^{(1)}, \\mathbf{x}^{(2)}, \\dots, \\mathbf{x}^{(n)} \\}$, contains $n$ data samples; $\\vert\\mathcal{D}\\vert =n $. $\\mathbf{x}^{(i)}$ Each data point is a vector of $d$ dimensions, $\\mathbf{x}^{(i)} = [x^{(i)}_1, x^{(i)}_2, \\dots, x^{(i)}_d]$. $\\mathbf{x}$ One data sample from the dataset, $\\mathbf{x} \\in \\mathcal{D}$. $\\mathbf{x}’$ The reconstructed version of $\\mathbf{x}$. $\\tilde{\\mathbf{x}}$ The corrupted version of $\\mathbf{x}$. $\\mathbf{z}$ The compressed code learned in the bottleneck layer. $a_j^{(l)}$ The activation function for the $j$-th neuron in the $l$-th hidden layer. $g_{\\phi}(.)$ The encoding function parameterized by $\\phi$. $f_{\\theta}(.)$ The decoding function parameterized by $\\theta$. $q_{\\phi}(\\mathbf{z}\\vert\\mathbf{x})$ Estimated posterior probability function, also known as probabilistic encoder. $p_{\\theta}(\\mathbf{x}\\vert\\mathbf{z})$ Likelihood of generating true data sample given the latent code, also known as probabilistic decoder. Autoencoder Autoencoder is a neural network designed to learn an identity function in an unsupervised way to reconstruct the original input while compressing the data in the process so as to discover a more efficient and compressed representation. The idea was originated in the 1980s, and later promoted by the seminal paper by Hinton \u0026amp; Salakhutdinov, 2006.\nIt consists of two networks:\n Encoder network: It translates the original high-dimension input into the latent low-dimensional code. The input size is larger than the output size. Decoder network: The decoder network recovers the data from the code, likely with larger and larger output layers. Fig. 1. Illustration of autoencoder model architecture. The encoder network essentially accomplishes the dimensionality reduction, just like how we would use Principal Component Analysis (PCA) or Matrix Factorization (MF) for. In addition, the autoencoder is explicitly optimized for the data reconstruction from the code. A good intermediate representation not only can capture latent variables, but also benefits a full decompression process.\nThe model contains an encoder function $g(.)$ parameterized by $\\phi$ and a decoder function $f(.)$ parameterized by $\\theta$. The low-dimensional code learned for input $\\mathbf{x}$ in the bottleneck layer is $\\mathbf{z} = g_\\phi(\\mathbf{x})$ and the reconstructed input is $\\mathbf{x}' = f_\\theta(g_\\phi(\\mathbf{x}))$.\nThe parameters $(\\theta, \\phi)$ are learned together to output a reconstructed data sample same as the original input, $\\mathbf{x} \\approx f_\\theta(g_\\phi(\\mathbf{x}))$, or in other words, to learn an identity function. There are various metrics to quantify the difference between two vectors, such as cross entropy when the activation function is sigmoid, or as simple as MSE loss:\n $$ L_\\text{AE}(\\theta, \\phi) = \\frac{1}{n}\\sum_{i=1}^n (\\mathbf{x}^{(i)} - f_\\theta(g_\\phi(\\mathbf{x}^{(i)})))^2 $$ Denoising Autoencoder Since the autoencoder learns the identity function, we are facing the risk of \u0026ldquo;overfitting\u0026rdquo; when there are more network parameters than the number of data points.\nTo avoid overfitting and improve the robustness, Denoising Autoencoder (Vincent et al. 2008) proposed a modification to the basic autoencoder. The input is partially corrupted by adding noises to or masking some values of the input vector in a stochastic manner, $\\tilde{\\mathbf{x}} \\sim \\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$. Then the model is trained to recover the original input (note: not the corrupt one).\n $$ \\begin{aligned} \\tilde{\\mathbf{x}}^{(i)} \u0026\\sim \\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}}^{(i)} \\vert \\mathbf{x}^{(i)})\\\\ L_\\text{DAE}(\\theta, \\phi) \u0026= \\frac{1}{n} \\sum_{i=1}^n (\\mathbf{x}^{(i)} - f_\\theta(g_\\phi(\\tilde{\\mathbf{x}}^{(i)})))^2 \\end{aligned} $$ where $\\mathcal{M}_\\mathcal{D}$ defines the mapping from the true data samples to the noisy or corrupted ones.\nFig. 2. Illustration of denoising autoencoder model architecture. This design is motivated by the fact that humans can easily recognize an object or a scene even the view is partially occluded or corrupted. To \u0026ldquo;repair\u0026rdquo; the partially destroyed input, the denoising autoencoder has to discover and capture relationship between dimensions of input in order to infer missing pieces.\nFor high dimensional input with high redundancy, like images, the model is likely to depend on evidence gathered from a combination of many input dimensions to recover the denoised version rather than to overfit one dimension. This builds up a good foundation for learning robust latent representation.\nThe noise is controlled by a stochastic mapping $\\mathcal{M}_\\mathcal{D}(\\tilde{\\mathbf{x}} \\vert \\mathbf{x})$, and it is not specific to a particular type of corruption process (i.e. masking noise, Gaussian noise, salt-and-pepper noise, etc.). Naturally the corruption process can be equipped with prior knowledge\nIn the experiment of the original DAE paper, the noise is applied in this way: a fixed proportion of input dimensions are selected at random and their values are forced to 0. Sounds a lot like dropout, right? Well, the denoising autoencoder was proposed in 2008, 4 years before the dropout paper (Hinton, et al. 2012) ;)\nFig. 3. Stacking denoising autoencoders. (Image source: Vincent et al., 2010) -- Sparse Autoencoder Sparse Autoencoder applies a \u0026ldquo;sparse\u0026rdquo; constraint on the hidden unit activation to avoid overfitting and improve robustness. It forces the model to only have a small number of hidden units being activated at the same time, or in other words, one hidden neuron should be inactivate most of time.\nRecall that common activation functions include sigmoid, tanh, relu, leaky relu, etc. A neuron is activated when the value is close to 1 and inactivate with a value close to 0.\nLet’s say there are $s_l$ neurons in the $l$-th hidden layer and the activation function for the $j$-th neuron in this layer is labelled as $a^{(l)}_j(.)$, $j=1, \\dots, s_l$. The fraction of activation of this neuron $\\hat{\\rho}_j$ is expected to be a small number $\\rho$, known as sparsity parameter; a common config is $\\rho = 0.05$.\n $$ \\hat{\\rho}_j^{(l)} = \\frac{1}{n} \\sum_{i=1}^n [a_j^{(l)}(\\mathbf{x}^{(i)})] \\approx \\rho $$ This constraint is achieved by adding a penalty term into the loss function. The KL-divergence $D_\\text{KL}$ measures the difference between two Bernoulli distributions, one with mean $\\rho$ and the other with mean $\\hat{\\rho}_j^{(l)}$. The hyperparameter $\\beta$ controls how strong the penalty we want to apply on the sparsity loss.\n $$ \\begin{aligned} L_\\text{SAE}(\\theta) \u0026= L(\\theta) + \\beta \\sum_{l=1}^L \\sum_{j=1}^{s_l} D_\\text{KL}(\\rho \\| \\hat{\\rho}_j^{(l)}) \\\\ \u0026= L(\\theta) + \\beta \\sum_{l=1}^L \\sum_{j=1}^{s_l} \\rho\\log\\frac{\\rho}{\\hat{\\rho}_j^{(l)}} + (1-\\rho)\\log\\frac{1-\\rho}{1-\\hat{\\rho}_j^{(l)}} \\end{aligned} $$ Fig. 4. The KL divergence between a Bernoulli distribution with mean $\\rho=0.25$ and a Bernoulli distribution with mean $0 \\leq \\hat{\\rho} \\leq 1$. $k$-Sparse Autoencoder\nIn $k$-Sparse Autoencoder (Makhzani and Frey, 2013), the sparsity is enforced by only keeping the top k highest activations in the bottleneck layer with linear activation function. First we run feedforward through the encoder network to get the compressed code: $\\mathbf{z} = g(\\mathbf{x})$. Sort the values in the code vector $\\mathbf{z}$. Only the k largest values are kept while other neurons are set to 0. This can be done in a ReLU layer with an adjustable threshold too. Now we have a sparsified code: $\\mathbf{z}’ = \\text{Sparsify}(\\mathbf{z})$. Compute the output and the loss from the sparsified code, $L = |\\mathbf{x} - f(\\mathbf{z}') |_2^2$. And, the back-propagation only goes through the top k activated hidden units!\nFig. 5. Filters of the k-sparse autoencoder for different sparsity levels k, learnt from MNIST with 1000 hidden units.. (Image source: Makhzani and Frey, 2013) Contractive Autoencoder Similar to sparse autoencoder, Contractive Autoencoder (Rifai, et al, 2011) encourages the learned representation to stay in a contractive space for better robustness.\nIt adds a term in the loss function to penalize the representation being too sensitive to the input, and thus improve the robustness to small perturbations around the training data points. The sensitivity is measured by the Frobenius norm of the Jacobian matrix of the encoder activations with respect to the input:\n $$ \\|J_f(\\mathbf{x})\\|_F^2 = \\sum_{ij} \\Big( \\frac{\\partial h_j(\\mathbf{x})}{\\partial x_i} \\Big)^2 $$ where $h_j$ is one unit output in the compressed code $\\mathbf{z} = f(x)$.\nThis penalty term is the sum of squares of all partial derivatives of the learned encoding with respect to input dimensions. The authors claimed that empirically this penalty was found to carve a representation that corresponds to a lower-dimensional non-linear manifold, while staying more invariant to majority directions orthogonal to the manifold.\nVAE: Variational Autoencoder The idea of Variational Autoencoder (Kingma \u0026amp; Welling, 2014), short for VAE, is actually less similar to all the autoencoder models above, but deeply rooted in the methods of variational bayesian and graphical model.\nInstead of mapping the input into a fixed vector, we want to map it into a distribution. Let’s label this distribution as $p_\\theta$, parameterized by $\\theta$. The relationship between the data input $\\mathbf{x}$ and the latent encoding vector $\\mathbf{z}$ can be fully defined by:\n Prior $p_\\theta(\\mathbf{z})$ Likelihood $p_\\theta(\\mathbf{x}\\vert\\mathbf{z})$ Posterior $p_\\theta(\\mathbf{z}\\vert\\mathbf{x})$ Assuming that we know the real parameter $\\theta^{*}$ for this distribution. In order to generate a sample that looks like a real data point $\\mathbf{x}^{(i)}$, we follow these steps:\n First, sample a $\\mathbf{z}^{(i)}$ from a prior distribution $p_{\\theta^*}(\\mathbf{z})$. Then a value $\\mathbf{x}^{(i)}$ is generated from a conditional distribution $p_{\\theta^*}(\\mathbf{x} \\vert \\mathbf{z} = \\mathbf{z}^{(i)})$. The optimal parameter $\\theta^{*}$ is the one that maximizes the probability of generating real data samples:\n $$ \\theta^{*} = \\arg\\max_\\theta \\prod_{i=1}^n p_\\theta(\\mathbf{x}^{(i)}) $$ Commonly we use the log probabilities to convert the product on RHS to a sum:\n $$ \\theta^{*} = \\arg\\max_\\theta \\sum_{i=1}^n \\log p_\\theta(\\mathbf{x}^{(i)}) $$ Now let’s update the equation to better demonstrate the data generation process so as to involve the encoding vector:\n $$ p_\\theta(\\mathbf{x}^{(i)}) = \\int p_\\theta(\\mathbf{x}^{(i)}\\vert\\mathbf{z}) p_\\theta(\\mathbf{z}) d\\mathbf{z} $$ Unfortunately it is not easy to compute $p_\\theta(\\mathbf{x}^{(i)})$ in this way, as it is very expensive to check all the possible values of $\\mathbf{z}$ and sum them up. To narrow down the value space to facilitate faster search, we would like to introduce a new approximation function to output what is a likely code given an input $\\mathbf{x}$, $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$, parameterized by $\\phi$.\nFig. 6. The graphical model involved in Variational Autoencoder. Solid lines denote the generative distribution $p\\_\\theta(.)$ and dashed lines denote the distribution $q\\_\\phi (\\mathbf{z}\\vert\\mathbf{x})$ to approximate the intractable posterior $p\\_\\theta (\\mathbf{z}\\vert\\mathbf{x})$. Now the structure looks a lot like an autoencoder:\n The conditional probability $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ defines a generative model, similar to the decoder $f_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ introduced above. $p_\\theta(\\mathbf{x} \\vert \\mathbf{z})$ is also known as probabilistic decoder. The approximation function $q_\\phi(\\mathbf{z} \\vert \\mathbf{x})$ is the probabilistic encoder, playing a similar role as $g_\\phi(\\mathbf{z} \\vert \\mathbf{x})$ above. Loss Function: ELBO The estimated posterior $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$ should be very close to the real one $p_\\theta(\\mathbf{z}\\vert\\mathbf{x})$. We can use Kullback-Leibler divergence to quantify the distance between these two distributions. KL divergence $D_\\text{KL}(X|Y)$ measures how much information is lost if the distribution Y is used to represent X.\nIn our case we want to minimize $D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) | p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) )$ with respect to $\\phi$.\nBut why use $D_\\text{KL}(q_\\phi | p_\\theta)$ (reversed KL) instead of $D_\\text{KL}(p_\\theta | q_\\phi)$ (forward KL)? Eric Jang has a great explanation in his post on Bayesian Variational methods. As a quick recap:\nFig. 7. Forward and reversed KL divergence have different demands on how to match two distributions. (Image source: blog.evjang.com/2016/08/variational-bayes.html) Forward KL divergence: $D_\\text{KL}(P|Q) = \\mathbb{E}_{z\\sim P(z)} \\log\\frac{P(z)}{Q(z)}$; we have to ensure that Q(z)\u0026gt;0 wherever P(z)\u0026gt;0. The optimized variational distribution $q(z)$ has to cover over the entire $p(z)$. Reversed KL divergence: $D_\\text{KL}(Q|P) = \\mathbb{E}_{z\\sim Q(z)} \\log\\frac{Q(z)}{P(z)}$; minimizing the reversed KL divergence squeezes the $Q(z)$ under $P(z)$. Let\u0026rsquo;s now expand the equation:\n $$ \\begin{aligned} \u0026 D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) \u0026 \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z} \\vert \\mathbf{x})} d\\mathbf{z} \u0026 \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})p_\\theta(\\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }p(z \\vert x) = p(z, x) / p(x)} \\\\ \u0026=\\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x}) \\big( \\log p_\\theta(\\mathbf{x}) + \\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} \\big) d\\mathbf{z} \u0026 \\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x})\\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z}, \\mathbf{x})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }\\int q(z \\vert x) dz = 1}\\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\int q_\\phi(\\mathbf{z} \\vert \\mathbf{x})\\log\\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{x}\\vert\\mathbf{z})p_\\theta(\\mathbf{z})} d\\mathbf{z} \u0026 \\scriptstyle{\\text{; Because }p(z, x) = p(x \\vert z) p(z)} \\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}[\\log \\frac{q_\\phi(\\mathbf{z} \\vert \\mathbf{x})}{p_\\theta(\\mathbf{z})} - \\log p_\\theta(\\mathbf{x} \\vert \\mathbf{z})] \u0026\\\\ \u0026=\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) - \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) \u0026 \\end{aligned} $$ So we have:\n $$ D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) =\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) - \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) $$ Once rearrange the left and right hand side of the equation,\n $$ \\log p_\\theta(\\mathbf{x}) - D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) = \\mathbb{E}_{\\mathbf{z}\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})}\\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z})) $$ The LHS of the equation is exactly what we want to maximize when learning the true distributions: we want to maximize the (log-)likelihood of generating real data (that is $\\log p_\\theta(\\mathbf{x})$) and also minimize the difference between the real and estimated posterior distributions (the term $D_\\text{KL}$ works like a regularizer). Note that $p_\\theta(\\mathbf{x})$ is fixed with respect to $q_\\phi$.\nThe negation of the above defines our loss function:\n $$ \\begin{aligned} L_\\text{VAE}(\\theta, \\phi) \u0026= -\\log p_\\theta(\\mathbf{x}) + D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) )\\\\ \u0026= - \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) + D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}) ) \\\\ \\theta^{*}, \\phi^{*} \u0026= \\arg\\min_{\\theta, \\phi} L_\\text{VAE} \\end{aligned} $$ In Variational Bayesian methods, this loss function is known as the variational lower bound, or evidence lower bound. The \u0026ldquo;lower bound\u0026rdquo; part in the name comes from the fact that KL divergence is always non-negative and thus $-L_\\text{VAE}$ is the lower bound of $\\log p_\\theta (\\mathbf{x})$.\n $$ -L_\\text{VAE} = \\log p_\\theta(\\mathbf{x}) - D_\\text{KL}( q_\\phi(\\mathbf{z}\\vert\\mathbf{x}) \\| p_\\theta(\\mathbf{z}\\vert\\mathbf{x}) ) \\leq \\log p_\\theta(\\mathbf{x}) $$ Therefore by minimizing the loss, we are maximizing the lower bound of the probability of generating real data samples.\nReparameterization Trick The expectation term in the loss function invokes generating samples from $\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$. Sampling is a stochastic process and therefore we cannot backpropagate the gradient. To make it trainable, the reparameterization trick is introduced: It is often possible to express the random variable $\\mathbf{z}$ as a deterministic variable $\\mathbf{z} = \\mathcal{T}_\\phi(\\mathbf{x}, \\boldsymbol{\\epsilon})$, where $\\boldsymbol{\\epsilon}$ is an auxiliary independent random variable, and the transformation function $\\mathcal{T}_\\phi$ parameterized by $\\phi$ converts $\\boldsymbol{\\epsilon}$ to $\\mathbf{z}$.\nFor example, a common choice of the form of $q_\\phi(\\mathbf{z}\\vert\\mathbf{x})$ is a multivariate Gaussian with a diagonal covariance structure:\n $$ \\begin{aligned} \\mathbf{z} \u0026\\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x}^{(i)}) = \\mathcal{N}(\\mathbf{z}; \\boldsymbol{\\mu}^{(i)}, \\boldsymbol{\\sigma}^{2(i)}\\boldsymbol{I}) \u0026 \\\\ \\mathbf{z} \u0026= \\boldsymbol{\\mu} + \\boldsymbol{\\sigma} \\odot \\boldsymbol{\\epsilon} \\text{, where } \\boldsymbol{\\epsilon} \\sim \\mathcal{N}(0, \\boldsymbol{I}) \u0026 \\scriptstyle{\\text{; Reparameterization trick.}} \\end{aligned} $$ where $\\odot$ refers to element-wise product.\nFig. 8. Illustration of how the reparameterization trick makes the $\\mathbf{z}$ sampling process trainable.(Image source: Slide 12 in Kingma’s NIPS 2015 workshop talk) The reparameterization trick works for other types of distributions too, not only Gaussian. In the multivariate Gaussian case, we make the model trainable by learning the mean and variance of the distribution, $\\mu$ and $\\sigma$, explicitly using the reparameterization trick, while the stochasticity remains in the random variable $\\boldsymbol{\\epsilon} \\sim \\mathcal{N}(0, \\boldsymbol{I})$.\nFig. 9. Illustration of variational autoencoder model with the multivariate Gaussian assumption. Beta-VAE If each variable in the inferred latent representation $\\mathbf{z}$ is only sensitive to one single generative factor and relatively invariant to other factors, we will say this representation is disentangled or factorized. One benefit that often comes with disentangled representation is good interpretability and easy generalization to a variety of tasks.\nFor example, a model trained on photos of human faces might capture the gentle, skin color, hair color, hair length, emotion, whether wearing a pair of glasses and many other relatively independent factors in separate dimensions. Such a disentangled representation is very beneficial to facial image generation.\nβ-VAE (Higgins et al., 2017) is a modification of Variational Autoencoder with a special emphasis to discover disentangled latent factors. Following the same incentive in VAE, we want to maximize the probability of generating real data, while keeping the distance between the real and estimated posterior distributions small (say, under a small constant $\\delta$):\n $$ \\begin{aligned} \u0026\\max_{\\phi, \\theta} \\mathbb{E}_{\\mathbf{x}\\sim\\mathcal{D}}[\\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z})]\\\\ \u0026\\text{subject to } D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) We can rewrite it as a Lagrangian with a Lagrangian multiplier $\\beta$ under the KKT condition. The above optimization problem with only one inequality constraint is equivalent to maximizing the following equation $\\mathcal{F}(\\theta, \\phi, \\beta)$:\n $$ \\begin{aligned} \\mathcal{F}(\\theta, \\phi, \\beta) \u0026= \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta(D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) - \\delta) \u0026 \\\\ \u0026 = \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) + \\beta \\delta \u0026 \\\\ \u0026 \\geq \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) - \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) \u0026 \\scriptstyle{\\text{; Because }\\beta,\\delta\\geq 0} \\end{aligned} $$ The loss function of $\\beta$-VAE is defined as:\n $$ L_\\text{BETA}(\\phi, \\beta) = - \\mathbb{E}_{\\mathbf{z} \\sim q_\\phi(\\mathbf{z}\\vert\\mathbf{x})} \\log p_\\theta(\\mathbf{x}\\vert\\mathbf{z}) + \\beta D_\\text{KL}(q_\\phi(\\mathbf{z}\\vert\\mathbf{x})\\|p_\\theta(\\mathbf{z})) $$ where the Lagrangian multiplier $\\beta$ is considered as a hyperparameter.\nSince the negation of $L_\\text{BETA}(\\phi, \\beta)$ is the lower bound of the Lagrangian $\\mathcal{F}(\\theta, \\phi, \\beta)$. Minimizing the loss is equivalent to maximizing the Lagrangian and thus works for our initial optimization problem.\nWhen $\\beta=1$, it is same as VAE. When $\\beta \u0026gt; 1$, it applies a stronger constraint on the latent bottleneck and limits the representation capacity of $\\mathbf{z}$. For some conditionally independent generative factors, keeping them disentangled is the most efficient representation. Therefore a higher $\\beta$ encourages more efficient latent encoding and further encourages the disentanglement. Meanwhile, a higher $\\beta$ may create a trade-off between reconstruction quality and the extent of disentanglement.\nBurgess, et al. (2017) discussed the distentangling in $\\beta$-VAE in depth with an inspiration by the information bottleneck theory and further proposed a modification to $\\beta$-VAE to better control the encoding representation capacity.\nVQ-VAE and VQ-VAE-2 The VQ-VAE (“Vector Quantised-Variational AutoEncoder”; van den Oord, et al. 2017) model learns a discrete latent variable by the encoder, since discrete representations may be a more natural fit for problems like language, speech, reasoning, etc.\nVector quantisation (VQ) is a method to map $K$-dimensional vectors into a finite set of “code” vectors. The process is very much similar to KNN algorithm. The optimal centroid code vector that a sample should be mapped to is the one with minimum euclidean distance.\nLet $\\mathbf{e} \\in \\mathbb{R}^{K \\times D}, i=1, \\dots, K$ be the latent embedding space (also known as \u0026ldquo;codebook\u0026rdquo;) in VQ-VAE, where $K$ is the number of latent variable categories and $D$ is the embedding size. An individual embedding vector is $\\mathbf{e}_i \\in \\mathbb{R}^{D}, i=1, \\dots, K$.\nThe encoder output $E(\\mathbf{x}) = \\mathbf{z}_e$ goes through a nearest-neighbor lookup to match to one of $K$ embedding vectors and then this matched code vector becomes the input for the decoder $D(.)$:\n $$ \\mathbf{z}_q(\\mathbf{x}) = \\text{Quantize}(E(\\mathbf{x})) = \\mathbf{e}_k \\text{ where } k = \\arg\\min_i \\|E(\\mathbf{x}) - \\mathbf{e}_i \\|_2 $$ Note that the discrete latent variables can have different shapes in differnet applications; for example, 1D for speech, 2D for image and 3D for video.\nFig. 10. The architecture of VQ-VAE (Image source: van den Oord, et al. 2017) Because argmin() is non-differentiable on a discrete space, the gradients $\\nabla_z L$ from decoder input $\\mathbf{z}_q$ is copied to the encoder output $\\mathbf{z}_e$. Other than reconstruction loss, VQ-VAE also optimizes:\n VQ loss: The L2 error between the embedding space and the encoder outputs. Commitment loss: A measure to encourage the encoder output to stay close to the embedding space and to prevent it from fluctuating too frequently from one code vector to another. $$ L = \\underbrace{\\|\\mathbf{x} - D(\\mathbf{e}_k)\\|_2^2}_{\\textrm{reconstruction loss}} + \\underbrace{\\|\\text{sg}[E(\\mathbf{x})] - \\mathbf{e}_k\\|_2^2}_{\\textrm{VQ loss}} + \\underbrace{\\beta \\|E(\\mathbf{x}) - \\text{sg}[\\mathbf{e}_k]\\|_2^2}_{\\textrm{commitment loss}} $$ where $\\text{sq}[.]$ is the stop_gradient operator.\nThe embedding vectors in the codebook is updated through EMA (exponential moving average). Given a code vector $\\mathbf{e}_i$, say we have $n_i$ encoder output vectors, $\\{\\mathbf{z}_{i,j}\\}_{j=1}^{n_i}$, that are quantized to $\\mathbf{e}_i$:\n $$ N_i^{(t)} = \\gamma N_i^{(t-1)} + (1-\\gamma)n_i^{(t)}\\;\\;\\; \\mathbf{m}_i^{(t)} = \\gamma \\mathbf{m}_i^{(t-1)} + (1-\\gamma)\\sum_{j=1}^{n_i^{(t)}}\\mathbf{z}_{i,j}^{(t)}\\;\\;\\; \\mathbf{e}_i^{(t)} = \\mathbf{m}_i^{(t)} / N_i^{(t)} $$ where $(t)$ refers to batch sequence in time. $N_i$ and $\\mathbf{m}_i$ are accumulated vector count and volume, respectively.\nVQ-VAE-2 (Ali Razavi, et al. 2019) is a two-level hierarchical VQ-VAE combined with self-attention autoregressive model.\n Stage 1 is to train a hierarchical VQ-VAE: The design of hierarchical latent variables intends to separate local patterns (i.e., texture) from global information (i.e., object shapes). The training of the larger bottom level codebook is conditioned on the smaller top level code too, so that it does not have to learn everything from scratch. Stage 2 is to learn a prior over the latent discrete codebook so that we sample from it and generate images. In this way, the decoder can receive input vectors sampled from a similar distribution as the one in training. A powerful autoregressive model enhanced with multi-headed self-attention layers is used to capture the prior distribution (like PixelSNAIL; Chen et al 2017). Considering that VQ-VAE-2 depends on discrete latent variables configured in a simple hierarchical setting, the quality of its generated images are pretty amazing.\nFig. 11. Architecture of hierarchical VQ-VAE and multi-stage image generation. (Image source: Ali Razavi, et al. 2019) Fig. 12. The VQ-VAE-2 algorithm. (Image source: Ali Razavi, et al. 2019) TD-VAE TD-VAE (“Temporal Difference VAE”; Gregor et al., 2019) works with sequential data. It relies on three main ideas, described below.\nFig. 13. State-space model as a Markov Chain model. 1. State-Space Models In (latent) state-space models, a sequence of unobserved hidden states $\\mathbf{z} = (z_1, \\dots, z_T)$ determine the observation states $\\mathbf{x} = (x_1, \\dots, x_T)$. Each time step in the Markov chain model in Fig. 13 can be trained in a similar manner as in Fig. 6, where the intractable posterior $p(z \\vert x)$ is approximated by a function $q(z \\vert x)$.\n2. Belief State An agent should learn to encode all the past states to reason about the future, named as belief state, $b_t = belief(x_1, \\dots, x_t) = belief(b_{t-1}, x_t)$. Given this, the distribution of future states conditioned on the past can be written as $p(x_{t+1}, \\dots, x_T \\vert x_1, \\dots, x_t) \\approx p(x_{t+1}, \\dots, x_T \\vert b_t)$. The hidden states in a recurrent policy are used as the agent\u0026rsquo;s belief state in TD-VAE. Thus we have $b_t = \\text{RNN}(b_{t-1}, x_t)$.\n3. Jumpy Prediction Further, an agent is expected to imagine distant futures based on all the information gathered so far, suggesting the capability of making jumpy predictions, that is, predicting states several steps further into the future.\nRecall what we have learned from the variance lower bound above:\n $$ \\begin{aligned} \\log p(x) \u0026\\geq \\log p(x) - D_\\text{KL}(q(z|x)\\|p(z|x)) \\\\ \u0026= \\mathbb{E}_{z\\sim q} \\log p(x|z) - D_\\text{KL}(q(z|x)\\|p(z)) \\\\ \u0026= \\mathbb{E}_{z \\sim q} \\log p(x|z) - \\mathbb{E}_{z \\sim q} \\log \\frac{q(z|x)}{p(z)} \\\\ \u0026= \\mathbb{E}_{z \\sim q}[\\log p(x|z) -\\log q(z|x) + \\log p(z)] \\\\ \u0026= \\mathbb{E}_{z \\sim q}[\\log p(x, z) -\\log q(z|x)] \\\\ \\log p(x) \u0026\\geq \\mathbb{E}_{z \\sim q}[\\log p(x, z) -\\log q(z|x)] \\end{aligned} $$ Now let\u0026rsquo;s model the distribution of the state $x_t$ as a probability function conditioned on all the past states $x_{\u0026lt;t}$ and two latent variables, $z_t$ and $z_{t-1}$, at current time step and one step back:\n $$ \\log p(x_t|x_{Continue expanding the equation:\n $$ \\begin{aligned} \u0026 \\log p(x_t|x_{Notice two things:\n The red terms can be ignored according to Markov assumptions. The blue term is expanded according to Markov assumptions. The green term is expanded to include an one-step prediction back to the past as a smoothing distribution. Precisely, there are four types of distributions to learn:\n $p_D(.)$ is the decoder distribution: $p(x_t \\mid z_t)$ is the encoder by the common definition; $p(x_t \\mid z_t) \\to p_D(x_t \\mid z_t)$; $p_T(.)$ is the transition distribution: $p(z_t \\mid z_{t-1})$ captures the sequential dependency between latent variables; $p(z_t \\mid z_{t-1}) \\to p_T(z_t \\mid z_{t-1})$; $p_B(.)$ is the belief distribution: Both $p(z_{t-1} \\mid x_{\u0026lt;t})$ and $q(z_t \\mid x_{\\leq t})$ can use the belief states to predict the latent variables; $p(z_{t-1} \\mid x_{\u0026lt;t}) \\to p_B(z_{t-1} \\mid b_{t-1})$; $q(z_{t} \\mid x_{\\leq t}) \\to p_B(z_t \\mid b_t)$; $p_S(.)$ is the smoothing distribution: The back-to-past smoothing term $q(z_{t-1} \\mid z_t, x_{\\leq t})$ can be rewritten to be dependent of belief states too; $q(z_{t-1} \\mid z_t, x_{\\leq t}) \\to p_S(z_{t-1} \\mid z_t, b_{t-1}, b_t)$; To incorporate the idea of jumpy prediction, the sequential ELBO has to not only work on $t, t+1$, but also two distant timestamp $t_1 \u0026lt; t_2$. Here is the final TD-VAE objective function to maximize:\n $$ J_{t_1, t_2} = \\mathbb{E}[ \\log p_D(x_{t_2}|z_{t_2}) + \\log p_B(z_{t_1}|b_{t_1}) + \\log p_T(z_{t_2}|z_{t_1}) - \\log p_B(z_{t_2}|b_{t_2}) - \\log p_S(z_{t_1}|z_{t_2}, b_{t_1}, b_{t_2})] $$ Fig. 14. A detailed overview of TD-VAE architecture, very nicely done. (Image source: TD-VAE paper) Cited as:\n@article{weng2018VAE, title = \u0026quot;From Autoencoder to Beta-VAE\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-08-12-vae/\u0026quot; } References [1] Geoffrey E. Hinton, and Ruslan R. Salakhutdinov. \u0026ldquo;Reducing the dimensionality of data with neural networks.\u0026quot; Science 313.5786 (2006): 504-507.\n[2] Pascal Vincent, et al. \u0026ldquo;Extracting and composing robust features with denoising autoencoders.\u0026quot; ICML, 2008.\n[3] Pascal Vincent, et al. \u0026ldquo;Stacked denoising autoencoders: Learning useful representations in a deep network with a local denoising criterion.\u0026quot;. Journal of machine learning research 11.Dec (2010): 3371-3408.\n[4] Geoffrey E. Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R. Salakhutdinov. \u0026ldquo;Improving neural networks by preventing co-adaptation of feature detectors.\u0026rdquo; arXiv preprint arXiv:1207.0580 (2012).\n[5] Sparse Autoencoder by Andrew Ng.\n[6] Alireza Makhzani, Brendan Frey (2013). \u0026ldquo;k-sparse autoencoder\u0026rdquo;. ICLR 2014.\n[7] Salah Rifai, et al. \u0026ldquo;Contractive auto-encoders: Explicit invariance during feature extraction.\u0026quot; ICML, 2011.\n[8] Diederik P. Kingma, and Max Welling. \u0026ldquo;Auto-encoding variational bayes.\u0026quot; ICLR 2014.\n[9] Tutorial - What is a variational autoencoder? on jaan.io\n[10] Youtube tutorial: Variational Autoencoders by Arxiv Insights\n[11] \u0026ldquo;A Beginner\u0026rsquo;s Guide to Variational Methods: Mean-Field Approximation\u0026rdquo; by Eric Jang.\n[12] Carl Doersch. \u0026ldquo;Tutorial on variational autoencoders.\u0026quot; arXiv:1606.05908, 2016.\n[13] Irina Higgins, et al. \u0026quot;$\\beta$-VAE: Learning basic visual concepts with a constrained variational framework.\u0026quot; ICLR 2017.\n[14] Christopher P. Burgess, et al. \u0026ldquo;Understanding disentangling in beta-VAE.\u0026quot; NIPS 2017.\n[15] Aaron van den Oord, et al. \u0026ldquo;Neural Discrete Representation Learning\u0026rdquo; NIPS 2017.\n[16] Ali Razavi, et al. \u0026ldquo;Generating Diverse High-Fidelity Images with VQ-VAE-2\u0026rdquo;. arXiv preprint arXiv:1906.00446 (2019).\n[17] Xi Chen, et al. \u0026ldquo;PixelSNAIL: An Improved Autoregressive Generative Model.\u0026quot; arXiv preprint arXiv:1712.09763 (2017).\n[18] Karol Gregor, et al. \u0026ldquo;Temporal Difference Variational Auto-Encoder.\u0026quot; ICLR 2019.\n","permalink":"https://lilianweng.github.io/posts/2018-08-12-vae/","summary":"[Updated on 2019-07-18: add a section on VQ-VAE \u0026amp; VQ-VAE-2.] [Updated on 2019-07-26: add a section on TD-VAE.] \nAutocoder is invented to reconstruct high-dimensional data using a neural network model with a narrow bottleneck layer in the middle (oops, this is probably not true for Variational Autoencoder, and we will investigate it in details in later sections). A nice byproduct is dimension reduction: the bottleneck layer captures a compressed latent encoding.","title":"From Autoencoder to Beta-VAE"},{"content":"[Updated on 2018-10-28: Add Pointer Network and the link to my implementation of Transformer.] [Updated on 2018-11-06: Add a link to the implementation of Transformer model.] [Updated on 2018-11-18: Add Neural Turing Machines.] [Updated on 2019-07-18: Correct the mistake on using the term \u0026ldquo;self-attention\u0026rdquo; when introducing the show-attention-tell paper; moved it to Self-Attention section.] [Updated on 2020-04-07: A follow-up post on improved Transformer models is here.]\nAttention is, to some extent, motivated by how we pay visual attention to different regions of an image or correlate words in one sentence. Take the picture of a Shiba Inu in Fig. 1 as an example.\nFig. 1. A Shiba Inu in a men’s outfit. The credit of the original photo goes to Instagram @mensweardog. Human visual attention allows us to focus on a certain region with \u0026ldquo;high resolution\u0026rdquo; (i.e. look at the pointy ear in the yellow box) while perceiving the surrounding image in \u0026ldquo;low resolution\u0026rdquo; (i.e. now how about the snowy background and the outfit?), and then adjust the focal point or do the inference accordingly. Given a small patch of an image, pixels in the rest provide clues what should be displayed there. We expect to see a pointy ear in the yellow box because we have seen a dog’s nose, another pointy ear on the right, and Shiba\u0026rsquo;s mystery eyes (stuff in the red boxes). However, the sweater and blanket at the bottom would not be as helpful as those doggy features.\nSimilarly, we can explain the relationship between words in one sentence or close context. When we see \u0026ldquo;eating\u0026rdquo;, we expect to encounter a food word very soon. The color term describes the food, but probably not so much with \u0026ldquo;eating\u0026rdquo; directly.\nFig. 2. One word \"attends\" to other words in the same sentence differently. In a nutshell, attention in deep learning can be broadly interpreted as a vector of importance weights: in order to predict or infer one element, such as a pixel in an image or a word in a sentence, we estimate using the attention vector how strongly it is correlated with (or \u0026ldquo;attends to\u0026rdquo; as you may have read in many papers) other elements and take the sum of their values weighted by the attention vector as the approximation of the target.\nWhat’s Wrong with Seq2Seq Model? The seq2seq model was born in the field of language modeling (Sutskever, et al. 2014). Broadly speaking, it aims to transform an input sequence (source) to a new one (target) and both sequences can be of arbitrary lengths. Examples of transformation tasks include machine translation between multiple languages in either text or audio, question-answer dialog generation, or even parsing sentences into grammar trees.\nThe seq2seq model normally has an encoder-decoder architecture, composed of:\n An encoder processes the input sequence and compresses the information into a context vector (also known as sentence embedding or \u0026ldquo;thought\u0026rdquo; vector) of a fixed length. This representation is expected to be a good summary of the meaning of the whole source sequence. A decoder is initialized with the context vector to emit the transformed output. The early work only used the last state of the encoder network as the decoder initial state. Both the encoder and decoder are recurrent neural networks, i.e. using LSTM or GRU units.\nFig. 3. The encoder-decoder model, translating the sentence \"she is eating a green apple\" to Chinese. The visualization of both encoder and decoder is unrolled in time. A critical and apparent disadvantage of this fixed-length context vector design is incapability of remembering long sentences. Often it has forgotten the first part once it completes processing the whole input. The attention mechanism was born (Bahdanau et al., 2015) to resolve this problem.\nBorn for Translation The attention mechanism was born to help memorize long source sentences in neural machine translation (NMT). Rather than building a single context vector out of the encoder\u0026rsquo;s last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.\nWhile the context vector has access to the entire input sequence, we don’t need to worry about forgetting. The alignment between the source and target is learned and controlled by the context vector. Essentially the context vector consumes three pieces of information:\n encoder hidden states; decoder hidden states; alignment between source and target. Fig. 4. The encoder-decoder model with additive attention mechanism in Bahdanau et al., 2015. Definition Now let’s define the attention mechanism introduced in NMT in a scientific way. Say, we have a source sequence $\\mathbf{x}$ of length $n$ and try to output a target sequence $\\mathbf{y}$ of length $m$:\n $$ \\begin{aligned} \\mathbf{x} \u0026= [x_1, x_2, \\dots, x_n] \\\\ \\mathbf{y} \u0026= [y_1, y_2, \\dots, y_m] \\end{aligned} $$ (Variables in bold indicate that they are vectors; same for everything else in this post.)\nThe encoder is a bidirectional RNN (or other recurrent network setting of your choice) with a forward hidden state $\\overrightarrow{\\boldsymbol{h}}_i$ and a backward one $\\overleftarrow{\\boldsymbol{h}}_i$. A simple concatenation of two represents the encoder state. The motivation is to include both the preceding and following words in the annotation of one word.\n $$ \\boldsymbol{h}_i = [\\overrightarrow{\\boldsymbol{h}}_i^\\top; \\overleftarrow{\\boldsymbol{h}}_i^\\top]^\\top, i=1,\\dots,n $$ The decoder network has hidden state $\\boldsymbol{s}_t=f(\\boldsymbol{s}_{t-1}, y_{t-1}, \\mathbf{c}_t)$ for the output word at position t, $t=1,\\dots,m$, where the context vector $\\mathbf{c}_t$ is a sum of hidden states of the input sequence, weighted by alignment scores:\n $$ \\begin{aligned} \\mathbf{c}_t \u0026= \\sum_{i=1}^n \\alpha_{t,i} \\boldsymbol{h}_i \u0026 \\small{\\text{; Context vector for output }y_t}\\\\ \\alpha_{t,i} \u0026= \\text{align}(y_t, x_i) \u0026 \\small{\\text{; How well two words }y_t\\text{ and }x_i\\text{ are aligned.}}\\\\ \u0026= \\frac{\\exp(\\text{score}(\\boldsymbol{s}_{t-1}, \\boldsymbol{h}_i))}{\\sum_{i'=1}^n \\exp(\\text{score}(\\boldsymbol{s}_{t-1}, \\boldsymbol{h}_{i'}))} \u0026 \\small{\\text{; Softmax of some predefined alignment score.}}. \\end{aligned} $$ The alignment model assigns a score $\\alpha_{t,i}$ to the pair of input at position i and output at position t, $(y_t, x_i)$, based on how well they match. The set of $\\{\\alpha_{t, i}\\}$ are weights defining how much of each source hidden state should be considered for each output. In Bahdanau\u0026rsquo;s paper, the alignment score $\\alpha$ is parametrized by a feed-forward network with a single hidden layer and this network is jointly trained with other parts of the model. The score function is therefore in the following form, given that tanh is used as the non-linear activation function:\n $$ \\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_t; \\boldsymbol{h}_i]) $$ where both $\\mathbf{v}_a$ and $\\mathbf{W}_a$ are weight matrices to be learned in the alignment model.\nThe matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words.\nFig. 5. Alignment matrix of \"L'accord sur l'Espace économique européen a été signé en août 1992\" (French) and its English translation \"The agreement on the European Economic Area was signed in August 1992\". (Image source: Fig 3 in Bahdanau et al., 2015) Check out this nice tutorial by Tensorflow team for more implementation instructions.\nA Family of Attention Mechanisms With the help of the attention, the dependencies between source and target sequences are not restricted by the in-between distance anymore! Given the big improvement by attention in machine translation, it soon got extended into the computer vision field (Xu et al. 2015) and people started exploring various other forms of attention mechanisms (Luong, et al., 2015; Britz et al., 2017; Vaswani, et al., 2017).\nSummary Below is a summary table of several popular attention mechanisms and corresponding alignment score functions:\n Name Alignment score function Citation Content-base attention $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\text{cosine}[\\boldsymbol{s}_t, \\boldsymbol{h}_i]$ Graves2014 Additive(*) $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_{t-1}; \\boldsymbol{h}_i])$ Bahdanau2015 Location-Base $\\alpha_{t,i} = \\text{softmax}(\\mathbf{W}_a \\boldsymbol{s}_t)$Note: This simplifies the softmax alignment to only depend on the target position. Luong2015 General $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\boldsymbol{s}_t^\\top\\mathbf{W}_a\\boldsymbol{h}_i$where $\\mathbf{W}_a$ is a trainable weight matrix in the attention layer. Luong2015 Dot-Product $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\boldsymbol{s}_t^\\top\\boldsymbol{h}_i$ Luong2015 Scaled Dot-Product(^) $\\text{score}(\\boldsymbol{s}_t, \\boldsymbol{h}_i) = \\frac{\\boldsymbol{s}_t^\\top\\boldsymbol{h}_i}{\\sqrt{n}}$Note: very similar to the dot-product attention except for a scaling factor; where n is the dimension of the source hidden state. Vaswani2017 (*) Referred to as \u0026ldquo;concat\u0026rdquo; in Luong, et al., 2015 and as \u0026ldquo;additive attention\u0026rdquo; in Vaswani, et al., 2017. (^) It adds a scaling factor $1/\\sqrt{n}$, motivated by the concern when the input is large, the softmax function may have an extremely small gradient, hard for efficient learning.\nHere are a summary of broader categories of attention mechanisms:\n Name Definition Citation Self-Attention(\u0026amp;) Relating different positions of the same input sequence. Theoretically the self-attention can adopt any score functions above, but just replace the target sequence with the same input sequence. Cheng2016 Global/Soft Attending to the entire input state space. Xu2015 Local/Hard Attending to the part of input state space; i.e. a patch of the input image. Xu2015; Luong2015 (\u0026amp;) Also, referred to as \u0026ldquo;intra-attention\u0026rdquo; in Cheng et al., 2016 and some other papers.\nSelf-Attention Self-attention, also known as intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the same sequence. It has been shown to be very useful in machine reading, abstractive summarization, or image description generation.\nThe long short-term memory network paper used self-attention to do machine reading. In the example below, the self-attention mechanism enables us to learn the correlation between the current words and the previous part of the sentence.\nFig. 6. The current word is in red and the size of the blue shade indicates the activation level. (Image source: Cheng et al., 2016) Soft vs Hard Attention In the show, attend and tell paper, attention mechanism is applied to images to generate captions. The image is first encoded by a CNN to extract features. Then a LSTM decoder consumes the convolution features to produce descriptive words one by one, where the weights are learned through attention. The visualization of the attention weights clearly demonstrates which regions of the image the model is paying attention to so as to output a certain word.\nFig. 7. \"A woman is throwing a frisbee in a park.\" (Image source: Fig. 6(b) in Xu et al. 2015) This paper first proposed the distinction between \u0026ldquo;soft\u0026rdquo; vs \u0026ldquo;hard\u0026rdquo; attention, based on whether the attention has access to the entire image or only a patch:\n Soft Attention: the alignment weights are learned and placed \u0026ldquo;softly\u0026rdquo; over all patches in the source image; essentially the same type of attention as in Bahdanau et al., 2015. Pro: the model is smooth and differentiable. Con: expensive when the source input is large. Hard Attention: only selects one patch of the image to attend to at a time. Pro: less calculation at the inference time. Con: the model is non-differentiable and requires more complicated techniques such as variance reduction or reinforcement learning to train. (Luong, et al., 2015) Global vs Local Attention Luong, et al., 2015 proposed the \u0026ldquo;global\u0026rdquo; and \u0026ldquo;local\u0026rdquo; attention. The global attention is similar to the soft attention, while the local one is an interesting blend between hard and soft, an improvement over the hard attention to make it differentiable: the model first predicts a single aligned position for the current target word and a window centered around the source position is then used to compute a context vector.\nFig. 8. Global vs local attention (Image source: Fig 2 \u0026 3 in Luong, et al., 2015) Neural Turing Machines Alan Turing in 1936 proposed a minimalistic model of computation. It is composed of a infinitely long tape and a head to interact with the tape. The tape has countless cells on it, each filled with a symbol: 0, 1 or blank (\u0026quot; \u0026ldquo;). The operation head can read symbols, edit symbols and move left/right on the tape. Theoretically a Turing machine can simulate any computer algorithm, irrespective of how complex or expensive the procedure might be. The infinite memory gives a Turing machine an edge to be mathematically limitless. However, infinite memory is not feasible in real modern computers and then we only consider Turing machine as a mathematical model of computation.\nFig. 9. How a Turing machine looks like: a tape + a head that handles the tape. (Image source: http://aturingmachine.com/) Neural Turing Machine (NTM, Graves, Wayne \u0026amp; Danihelka, 2014) is a model architecture for coupling a neural network with external memory storage. The memory mimics the Turing machine tape and the neural network controls the operation heads to read from or write to the tape. However, the memory in NTM is finite, and thus it probably looks more like a “Neural von Neumann Machine”.\nNTM contains two major components, a controller neural network and a memory bank. Controller: is in charge of executing operations on the memory. It can be any type of neural network, feed-forward or recurrent. Memory: stores processed information. It is a matrix of size $N \\times M$, containing N vector rows and each has $M$ dimensions.\nIn one update iteration, the controller processes the input and interacts with the memory bank accordingly to generate output. The interaction is handled by a set of parallel read and write heads. Both read and write operations are “blurry” by softly attending to all the memory addresses.\nFig 10. Neural Turing Machine Architecture. Reading and Writing When reading from the memory at time t, an attention vector of size $N$, $\\mathbf{w}_t$ controls how much attention to assign to different memory locations (matrix rows). The read vector $\\mathbf{r}_t$ is a sum weighted by attention intensity:\n $$ \\mathbf{r}_t = \\sum_{i=1}^N w_t(i)\\mathbf{M}_t(i)\\text{, where }\\sum_{i=1}^N w_t(i)=1, \\forall i: 0 \\leq w_t(i) \\leq 1 $$ where $w_t(i)$ is the $i$-th element in $\\mathbf{w}_t$ and $\\mathbf{M}_t(i)$ is the $i$-th row vector in the memory.\nWhen writing into the memory at time t, as inspired by the input and forget gates in LSTM, a write head first wipes off some old content according to an erase vector $\\mathbf{e}_t$ and then adds new information by an add vector $\\mathbf{a}_t$.\n $$ \\begin{aligned} \\tilde{\\mathbf{M}}_t(i) \u0026= \\mathbf{M}_{t-1}(i) [\\mathbf{1} - w_t(i)\\mathbf{e}_t] \u0026\\scriptstyle{\\text{; erase}}\\\\ \\mathbf{M}_t(i) \u0026= \\tilde{\\mathbf{M}}_t(i) + w_t(i) \\mathbf{a}_t \u0026\\scriptstyle{\\text{; add}} \\end{aligned} $$ Attention Mechanisms In Neural Turing Machine, how to generate the attention distribution $\\mathbf{w}_t$ depends on the addressing mechanisms: NTM uses a mixture of content-based and location-based addressings.\nContent-based addressing\nThe content-addressing creates attention vectors based on the similarity between the key vector $\\mathbf{k}_t$ extracted by the controller from the input and memory rows. The content-based attention scores are computed as cosine similarity and then normalized by softmax. In addition, NTM adds a strength multiplier $\\beta_t$ to amplify or attenuate the focus of the distribution.\n $$ w_t^c(i) = \\text{softmax}(\\beta_t \\cdot \\text{cosine}[\\mathbf{k}_t, \\mathbf{M}_t(i)]) = \\frac{\\exp(\\beta_t \\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(i)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(i)\\|})}{\\sum_{j=1}^N \\exp(\\beta_t \\frac{\\mathbf{k}_t \\cdot \\mathbf{M}_t(j)}{\\|\\mathbf{k}_t\\| \\cdot \\|\\mathbf{M}_t(j)\\|})} $$ Interpolation\nThen an interpolation gate scalar $g_t$ is used to blend the newly generated content-based attention vector with the attention weights in the last time step:\n $$ \\mathbf{w}_t^g = g_t \\mathbf{w}_t^c + (1 - g_t) \\mathbf{w}_{t-1} $$ Location-based addressing\nThe location-based addressing sums up the values at different positions in the attention vector, weighted by a weighting distribution over allowable integer shifts. It is equivalent to a 1-d convolution with a kernel $\\mathbf{s}_t(.)$, a function of the position offset. There are multiple ways to define this distribution. See Fig. 11. for inspiration.\nFig. 11. Two ways to represent the shift weighting distribution $\\mathbf{s}\\_t$. Finally the attention distribution is enhanced by a sharpening scalar $\\gamma_t \\geq 1$.\n $$ \\begin{aligned} \\tilde{w}_t(i) \u0026= \\sum_{j=1}^N w_t^g(j) s_t(i-j) \u0026 \\scriptstyle{\\text{; circular convolution}}\\\\ w_t(i) \u0026= \\frac{\\tilde{w}_t(i)^{\\gamma_t}}{\\sum_{j=1}^N \\tilde{w}_t(j)^{\\gamma_t}} \u0026 \\scriptstyle{\\text{; sharpen}} \\end{aligned} $$ The complete process of generating the attention vector $\\mathbf{w}_t$ at time step t is illustrated in Fig. 12. All the parameters produced by the controller are unique for each head. If there are multiple read and write heads in parallel, the controller would output multiple sets.\nFig. 12. Flow diagram of the addressing mechanisms in Neural Turing Machine. (Image source: Graves, Wayne \u0026 Danihelka, 2014) Pointer Network In problems like sorting or travelling salesman, both input and output are sequential data. Unfortunately, they cannot be easily solved by classic seq-2-seq or NMT models, given that the discrete categories of output elements are not determined in advance, but depends on the variable input size. The Pointer Net (Ptr-Net; Vinyals, et al. 2015) is proposed to resolve this type of problems: When the output elements correspond to positions in an input sequence. Rather than using attention to blend hidden units of an encoder into a context vector (See Fig. 8), the Pointer Net applies attention over the input elements to pick one as the output at each decoder step.\nFig. 13. The architecture of a Pointer Network model. (Image source: Vinyals, et al. 2015) The Ptr-Net outputs a sequence of integer indices, $\\boldsymbol{c} = (c_1, \\dots, c_m)$ given a sequence of input vectors $\\boldsymbol{x} = (x_1, \\dots, x_n)$ and $1 \\leq c_i \\leq n$. The model still embraces an encoder-decoder framework. The encoder and decoder hidden states are denoted as $(\\boldsymbol{h}_1, \\dots, \\boldsymbol{h}_n)$ and $(\\boldsymbol{s}_1, \\dots, \\boldsymbol{s}_m)$, respectively. Note that $\\mathbf{s}_i$ is the output gate after cell activation in the decoder. The Ptr-Net applies additive attention between states and then normalizes it by softmax to model the output conditional probability:\n $$ \\begin{aligned} y_i \u0026= p(c_i \\vert c_1, \\dots, c_{i-1}, \\boldsymbol{x}) \\\\ \u0026= \\text{softmax}(\\text{score}(\\boldsymbol{s}_t; \\boldsymbol{h}_i)) = \\text{softmax}(\\mathbf{v}_a^\\top \\tanh(\\mathbf{W}_a[\\boldsymbol{s}_t; \\boldsymbol{h}_i])) \\end{aligned} $$ The attention mechanism is simplified, as Ptr-Net does not blend the encoder states into the output with attention weights. In this way, the output only responds to the positions but not the input content.\nTransformer \u0026ldquo;Attention is All you Need\u0026rdquo; (Vaswani, et al., 2017), without a doubt, is one of the most impactful and interesting paper in 2017. It presented a lot of improvements to the soft attention and make it possible to do seq2seq modeling without recurrent network units. The proposed \u0026ldquo;transformer\u0026rdquo; model is entirely built on the self-attention mechanisms without using sequence-aligned recurrent architecture.\nThe secret recipe is carried in its model architecture.\nKey, Value and Query The major component in the transformer is the unit of multi-head self-attention mechanism. The transformer views the encoded representation of the input as a set of key-value pairs, $(\\mathbf{K}, \\mathbf{V})$, both of dimension $n$ (input sequence length); in the context of NMT, both the keys and values are the encoder hidden states. In the decoder, the previous output is compressed into a query ($\\mathbf{Q}$ of dimension $m$) and the next output is produced by mapping this query and the set of keys and values.\nThe transformer adopts the scaled dot-product attention: the output is a weighted sum of the values, where the weight assigned to each value is determined by the dot-product of the query with all the keys:\n $$ \\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{n}})\\mathbf{V} $$ Multi-Head Self-Attention Fig. 14. Multi-head scaled dot-product attention mechanism. (Image source: Fig 2 in Vaswani, et al., 2017) Rather than only computing the attention once, the multi-head mechanism runs through the scaled dot-product attention multiple times in parallel. The independent attention outputs are simply concatenated and linearly transformed into the expected dimensions. I assume the motivation is because ensembling always helps? ;) According to the paper, \u0026ldquo;multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.\u0026quot;\n $$ \\begin{aligned} \\text{MultiHead}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) \u0026= [\\text{head}_1; \\dots; \\text{head}_h]\\mathbf{W}^O \\\\ \\text{where head}_i \u0026= \\text{Attention}(\\mathbf{Q}\\mathbf{W}^Q_i, \\mathbf{K}\\mathbf{W}^K_i, \\mathbf{V}\\mathbf{W}^V_i) \\end{aligned} $$ where $\\mathbf{W}^Q_i$, $\\mathbf{W}^K_i$, $\\mathbf{W}^V_i$, and $\\mathbf{W}^O$ are parameter matrices to be learned.\nEncoder Fig. 15. The transformer’s encoder. (Image source: Vaswani, et al., 2017) The encoder generates an attention-based representation with capability to locate a specific piece of information from a potentially infinitely-large context.\n A stack of N=6 identical layers. Each layer has a multi-head self-attention layer and a simple position-wise fully connected feed-forward network. Each sub-layer adopts a residual connection and a layer normalization. All the sub-layers output data of the same dimension $d_\\text{model} = 512$. Decoder Fig. 16. The transformer’s decoder. (Image source: Vaswani, et al., 2017) The decoder is able to retrieval from the encoded representation.\n A stack of N = 6 identical layers Each layer has two sub-layers of multi-head attention mechanisms and one sub-layer of fully-connected feed-forward network. Similar to the encoder, each sub-layer adopts a residual connection and a layer normalization. The first multi-head attention sub-layer is modified to prevent positions from attending to subsequent positions, as we don’t want to look into the future of the target sequence when predicting the current position. Full Architecture Finally here is the complete view of the transformer\u0026rsquo;s architecture:\n Both the source and target sequences first go through embedding layers to produce data of the same dimension $d_\\text{model} =512$. To preserve the position information, a sinusoid-wave-based positional encoding is applied and summed with the embedding output. A softmax and linear layer are added to the final decoder output. Fig. 17. The full model architecture of the transformer. (Image source: Fig 1 \u0026 2 in Vaswani, et al., 2017.) Try to implement the transformer model is an interesting experience, here is mine: lilianweng/transformer-tensorflow. Read the comments in the code if you are interested.\nSNAIL The transformer has no recurrent or convolutional structure, even with the positional encoding added to the embedding vector, the sequential order is only weakly incorporated. For problems sensitive to the positional dependency like reinforcement learning, this can be a big problem.\nThe Simple Neural Attention Meta-Learner (SNAIL) (Mishra et al., 2017) was developed partially to resolve the problem with positioning in the transformer model by combining the self-attention mechanism in transformer with temporal convolutions. It has been demonstrated to be good at both supervised learning and reinforcement learning tasks.\nFig. 18. SNAIL model architecture (Image source: Mishra et al., 2017) SNAIL was born in the field of meta-learning, which is another big topic worthy of a post by itself. But in simple words, the meta-learning model is expected to be generalizable to novel, unseen tasks in the similar distribution. Read this nice introduction if interested.\nSelf-Attention GAN Self-Attention GAN (SAGAN; Zhang et al., 2018) adds self-attention layers into GAN to enable both the generator and the discriminator to better model relationships between spatial regions.\nThe classic DCGAN (Deep Convolutional GAN) represents both discriminator and generator as multi-layer convolutional networks. However, the representation capacity of the network is restrained by the filter size, as the feature of one pixel is limited to a small local region. In order to connect regions far apart, the features have to be dilute through layers of convolutional operations and the dependencies are not guaranteed to be maintained.\nAs the (soft) self-attention in the vision context is designed to explicitly learn the relationship between one pixel and all other positions, even regions far apart, it can easily capture global dependencies. Hence GAN equipped with self-attention is expected to handle details better, hooray!\nFig. 19. Convolution operation and self-attention have access to regions of very different sizes. The SAGAN adopts the non-local neural network to apply the attention computation. The convolutional image feature maps $\\mathbf{x}$ is branched out into three copies, corresponding to the concepts of key, value, and query in the transformer:\n Key: $f(\\mathbf{x}) = \\mathbf{W}_f \\mathbf{x}$ Query: $g(\\mathbf{x}) = \\mathbf{W}_g \\mathbf{x}$ Value: $h(\\mathbf{x}) = \\mathbf{W}_h \\mathbf{x}$ Then we apply the dot-product attention to output the self-attention feature maps:\n $$ \\begin{aligned} \\alpha_{i,j} \u0026= \\text{softmax}(f(\\mathbf{x}_i)^\\top g(\\mathbf{x}_j)) \\\\ \\mathbf{o}_j \u0026= \\mathbf{W}_v \\Big( \\sum_{i=1}^N \\alpha_{i,j} h(\\mathbf{x}_i) \\Big) \\end{aligned} $$ Fig. 20. The self-attention mechanism in SAGAN. (Image source: Fig. 2 in Zhang et al., 2018) Note that $\\alpha_{i,j}$ is one entry in the attention map, indicating how much attention the model should pay to the $i$-th position when synthesizing the $j$-th location. $\\mathbf{W}_f$, $\\mathbf{W}_g$, and $\\mathbf{W}_h$ are all 1x1 convolution filters. If you feel that 1x1 conv sounds like a weird concept (i.e., isn\u0026rsquo;t it just to multiply the whole feature map with one number?), watch this short tutorial by Andrew Ng. The output $\\mathbf{o}_j$ is a column vector of the final output $\\mathbf{o}= (\\mathbf{o}_1, \\mathbf{o}_2, \\dots, \\mathbf{o}_j, \\dots, \\mathbf{o}_N)$.\nFurthermore, the output of the attention layer is multiplied by a scale parameter and added back to the original input feature map:\n $$ \\mathbf{y} = \\mathbf{x}_i + \\gamma \\mathbf{o}_i $$ While the scaling parameter $\\gamma$ is increased gradually from 0 during the training, the network is configured to first rely on the cues in the local regions and then gradually learn to assign more weight to the regions that are further away.\nFig. 21. 128×128 example images generated by SAGAN for different classes. (Image source: Partial Fig. 6 in Zhang et al., 2018) Cited as:\n@article{weng2018attention, title = \u0026quot;Attention? Attention!\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-06-24-attention/\u0026quot; } References [1] \u0026ldquo;Attention and Memory in Deep Learning and NLP.\u0026quot; - Jan 3, 2016 by Denny Britz\n[2] \u0026ldquo;Neural Machine Translation (seq2seq) Tutorial\u0026rdquo;\n[3] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. \u0026ldquo;Neural machine translation by jointly learning to align and translate.\u0026quot; ICLR 2015.\n[4] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhudinov, Rich Zemel, and Yoshua Bengio. \u0026ldquo;Show, attend and tell: Neural image caption generation with visual attention.\u0026quot; ICML, 2015.\n[5] Ilya Sutskever, Oriol Vinyals, and Quoc V. Le. \u0026ldquo;Sequence to sequence learning with neural networks.\u0026quot; NIPS 2014.\n[6] Thang Luong, Hieu Pham, Christopher D. Manning. \u0026ldquo;Effective Approaches to Attention-based Neural Machine Translation.\u0026quot; EMNLP 2015.\n[7] Denny Britz, Anna Goldie, Thang Luong, and Quoc Le. \u0026ldquo;Massive exploration of neural machine translation architectures.\u0026quot; ACL 2017.\n[8] Ashish Vaswani, et al. \u0026ldquo;Attention is all you need.\u0026quot; NIPS 2017.\n[9] Jianpeng Cheng, Li Dong, and Mirella Lapata. \u0026ldquo;Long short-term memory-networks for machine reading.\u0026quot; EMNLP 2016.\n[10] Xiaolong Wang, et al. \u0026ldquo;Non-local Neural Networks.\u0026quot; CVPR 2018\n[11] Han Zhang, Ian Goodfellow, Dimitris Metaxas, and Augustus Odena. \u0026ldquo;Self-Attention Generative Adversarial Networks.\u0026quot; arXiv preprint arXiv:1805.08318 (2018).\n[12] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. \u0026ldquo;A simple neural attentive meta-learner.\u0026quot; ICLR 2018.\n[13] \u0026ldquo;WaveNet: A Generative Model for Raw Audio\u0026rdquo; - Sep 8, 2016 by DeepMind.\n[14] Oriol Vinyals, Meire Fortunato, and Navdeep Jaitly. \u0026ldquo;Pointer networks.\u0026quot; NIPS 2015.\n[15] Alex Graves, Greg Wayne, and Ivo Danihelka. \u0026ldquo;Neural turing machines.\u0026quot; arXiv preprint arXiv:1410.5401 (2014).\n","permalink":"https://lilianweng.github.io/posts/2018-06-24-attention/","summary":"[Updated on 2018-10-28: Add Pointer Network and the link to my implementation of Transformer.] [Updated on 2018-11-06: Add a link to the implementation of Transformer model.] [Updated on 2018-11-18: Add Neural Turing Machines.] [Updated on 2019-07-18: Correct the mistake on using the term \u0026ldquo;self-attention\u0026rdquo; when introducing the show-attention-tell paper; moved it to Self-Attention section.] [Updated on 2020-04-07: A follow-up post on improved Transformer models is here.]\nAttention is, to some extent, motivated by how we pay visual attention to different regions of an image or correlate words in one sentence.","title":"Attention? Attention!"},{"content":"The full implementation is available in lilianweng/deep-reinforcement-learning-gym\nIn the previous two posts, I have introduced the algorithms of many deep reinforcement learning models. Now it is the time to get our hands dirty and practice how to implement the models in the wild. The implementation is gonna be built in Tensorflow and OpenAI gym environment. The full version of the code in this tutorial is available in [lilian/deep-reinforcement-learning-gym].\nEnvironment Setup Make sure you have Homebrew installed: /usr/bin/ruby -e \u0026#34;$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)\u0026#34; I would suggest starting a virtualenv for your development. It makes life so much easier when you have multiple projects with conflicting requirements; i.e. one works in Python 2.7 while the other is only compatible with Python 3.5+. # Install python virtualenv brew install pyenv-virtualenv # Create a virtual environment of any name you like with Python 3.6.4 support pyenv virtualenv 3.6.4 workspace # Activate the virtualenv named \u0026#34;workspace\u0026#34; pyenv activate workspace [*] For every new installation below, please make sure you are in the virtualenv.\nInstall OpenAI gym according to the instruction. For a minimal installation, run: git clone https://github.com/openai/gym.git cd gym pip install -e . If you are interested in playing with Atari games or other advanced packages, please continue to get a couple of system packages installed.\nbrew install cmake boost boost-python sdl2 swig wget For Atari, go to the gym directory and pip install it. This post is pretty helpful if you have troubles with ALE (arcade learning environment) installation.\npip install -e \u0026#39;.[atari]\u0026#39; Finally clone the \u0026ldquo;playground\u0026rdquo; code and install the requirements. git clone git@github.com:lilianweng/deep-reinforcement-learning-gym.git cd deep-reinforcement-learning-gym pip install -e . # install the \u0026#34;playground\u0026#34; project. pip install -r requirements.txt # install required packages. Gym Environment The OpenAI Gym toolkit provides a set of physical simulation environments, games, and robot simulators that we can play with and design reinforcement learning agents for. An environment object can be initialized by gym.make(\u0026quot;{environment name}\u0026quot;:\nimport gym env = gym.make(\u0026#34;MsPacman-v0\u0026#34;) The formats of action and observation of an environment are defined by env.action_space and env.observation_space, respectively.\nTypes of gym spaces:\n gym.spaces.Discrete(n): discrete values from 0 to n-1. gym.spaces.Box: a multi-dimensional vector of numeric values, the upper and lower bounds of each dimension are defined by Box.low and Box.high. We interact with the env through two major api calls:\nob = env.reset()\n Resets the env to the original setting. Returns the initial observation. ob_next, reward, done, info = env.step(action)\n Applies one action in the env which should be compatible with env.action_space. Gets back the new observation ob_next (env.observation_space), a reward (float), a done flag (bool), and other meta information (dict). If done=True, the episode is complete and we should reset the env to restart. Read more here. Naive Q-Learning Q-learning (Watkins \u0026amp; Dayan, 1992) learns the action value (\u0026ldquo;Q-value\u0026rdquo;) and update it according to the Bellman equation. The key point is while estimating what is the next action, it does not follow the current policy but rather adopt the best Q value (the part in red) independently.\n $$ Q(s, a) \\leftarrow (1 - \\alpha) Q(s, a) + \\alpha (r + \\gamma \\color{red}{\\max_{a' \\in \\mathcal{A}} Q(s', a')}) $$ In a naive implementation, the Q value for all (s, a) pairs can be simply tracked in a dict. No complicated machine learning model is involved yet.\nfrom collections import defaultdict Q = defaultdict(float) gamma = 0.99 # Discounting factor alpha = 0.5 # soft update param env = gym.make(\u0026#34;CartPole-v0\u0026#34;) actions = range(env.action_space) def update_Q(s, r, a, s_next, done): max_q_next = max([Q[s_next, a] for a in actions]) # Do not include the next state\u0026#39;s value if currently at the terminal state. Q[s, a] += alpha * (r + gamma * max_q_next * (1.0 - done) - Q[s, a]) Most gym environments have a multi-dimensional continuous observation space (gym.spaces.Box). To make sure our Q dictionary will not explode by trying to memorize an infinite number of keys, we apply a wrapper to discretize the observation. The concept of wrappers is very powerful, with which we are capable to customize observation, action, step function, etc. of an env. No matter how many wrappers are applied, env.unwrapped always gives back the internal original environment object.\nimport gym class DiscretizedObservationWrapper(gym.ObservationWrapper): \u0026#34;\u0026#34;\u0026#34;This wrapper converts a Box observation into a single integer. \u0026#34;\u0026#34;\u0026#34; def __init__(self, env, n_bins=10, low=None, high=None): super().__init__(env) assert isinstance(env.observation_space, Box) low = self.observation_space.low if low is None else low high = self.observation_space.high if high is None else high self.n_bins = n_bins self.val_bins = [np.linspace(l, h, n_bins + 1) for l, h in zip(low.flatten(), high.flatten())] self.observation_space = Discrete(n_bins ** low.flatten().shape[0]) def _convert_to_one_number(self, digits): return sum([d * ((self.n_bins + 1) ** i) for i, d in enumerate(digits)]) def observation(self, observation): digits = [np.digitize([x], bins)[0] for x, bins in zip(observation.flatten(), self.val_bins)] return self._convert_to_one_number(digits) env = DiscretizedObservationWrapper( env, n_bins=8, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5] ) Let\u0026rsquo;s plug in the interaction with a gym env and update the Q function every time a new transition is generated. When picking the action, we use ε-greedy to force exploration.\nimport gym import numpy as np n_steps = 100000 epsilon = 0.1 # 10% chances to apply a random action def act(ob): if np.random.random() \u0026lt; epsilon: # action_space.sample() is a convenient function to get a random action # that is compatible with this given action space. return env.action_space.sample() # Pick the action with highest q value. qvals = {a: q[state, a] for a in actions} max_q = max(qvals.values()) # In case multiple actions have the same maximum q value. actions_with_max_q = [a for a, q in qvals.items() if q == max_q] return np.random.choice(actions_with_max_q) ob = env.reset() rewards = [] reward = 0.0 for step in range(n_steps): a = act(ob) ob_next, r, done, _ = env.step(a) update_Q(ob, r, a, ob_next, done) reward += r if done: rewards.append(reward) reward = 0.0 ob = env.reset() else: ob = ob_next Often we start with a high epsilon and gradually decrease it during the training, known as \u0026ldquo;epsilon annealing\u0026rdquo;. The full code of QLearningPolicy is available here.\nDeep Q-Network Deep Q-network is a seminal piece of work to make the training of Q-learning more stable and more data-efficient, when the Q value is approximated with a nonlinear function. Two key ingredients are experience replay and a separately updated target network.\nThe main loss function looks like the following,\n $$ \\begin{aligned} \u0026 Y(s, a, r, s') = r + \\gamma \\max_{a'} Q_{\\theta^{-}}(s', a') \\\\ \u0026 \\mathcal{L}(\\theta) = \\mathbb{E}_{(s, a, r, s') \\sim U(D)} \\Big[ \\big( Y(s, a, r, s') - Q_\\theta(s, a) \\big)^2 \\Big] \\end{aligned} $$ The Q network can be a multi-layer dense neural network, a convolutional network, or a recurrent network, depending on the problem. In the full implementation of the DQN policy, it is determined by the model_type parameter, one of (\u0026ldquo;dense\u0026rdquo;, \u0026ldquo;conv\u0026rdquo;, \u0026ldquo;lstm\u0026rdquo;).\nIn the following example, I\u0026rsquo;m using a 2-layer densely connected neural network to learn Q values for the cart pole balancing problem.\nimport gym env = gym.make(\u0026#39;CartPole-v1\u0026#39;) # The observation space is `Box(4,)`, a 4-element vector. observation_size = env.observation_space.shape[0] We have a helper function for creating the networks below:\nimport tensorflow as tf def dense_nn(inputs, layers_sizes, scope_name): \u0026#34;\u0026#34;\u0026#34;Creates a densely connected multi-layer neural network. inputs: the input tensor layers_sizes (list\u0026lt;int\u0026gt;): defines the number of units in each layer. The output layer has the size layers_sizes[-1]. \u0026#34;\u0026#34;\u0026#34; with tf.variable_scope(scope_name): for i, size in enumerate(layers_sizes): inputs = tf.layers.dense( inputs, size, # Add relu activation only for internal layers. activation=tf.nn.relu if i \u0026lt; len(layers_sizes) - 1 else None, kernel_initializer=tf.contrib.layers.xavier_initializer(), name=scope_name + \u0026#39;_l\u0026#39; + str(i) ) return inputs The Q-network and the target network are updated with a batch of transitions (state, action, reward, state_next, done_flag). The input tensors are:\nbatch_size = 32 # A tunable hyperparameter. states = tf.placeholder(tf.float32, shape=(batch_size, observation_size), name=\u0026#39;state\u0026#39;) states_next = tf.placeholder(tf.float32, shape=(batch_size, observation_size), name=\u0026#39;state_next\u0026#39;) actions = tf.placeholder(tf.int32, shape=(batch_size,), name=\u0026#39;action\u0026#39;) rewards = tf.placeholder(tf.float32, shape=(batch_size,), name=\u0026#39;reward\u0026#39;) done_flags = tf.placeholder(tf.float32, shape=(batch_size,), name=\u0026#39;done\u0026#39;) We have two networks of the same structure. Both have the same network architectures with the state observation as the inputs and Q values over all the actions as the outputs.\nq = dense(states, [32, 32, 2], name=\u0026#39;Q_primary\u0026#39;) q_target = dense(states_next, [32, 32, 2], name=\u0026#39;Q_target\u0026#39;) The target network \u0026ldquo;Q_target\u0026rdquo; takes the states_next tensor as the input, because we use its prediction to select the optimal next state in the Bellman equation.\n# The prediction by the primary Q network for the actual actions. action_one_hot = tf.one_hot(actions, act_size, 1.0, 0.0, name=\u0026#39;action_one_hot\u0026#39;) pred = tf.reduce_sum(q * action_one_hot, reduction_indices=-1, name=\u0026#39;q_acted\u0026#39;) # The optimization target defined by the Bellman equation and the target network. max_q_next_by_target = tf.reduce_max(q_target, axis=-1) y = rewards + (1. - done_flags) * gamma * max_q_next_by_target # The loss measures the mean squared error between prediction and target. loss = tf.reduce_mean(tf.square(pred - tf.stop_gradient(y)), name=\u0026#34;loss_mse_train\u0026#34;) optimizer = tf.train.AdamOptimizer(0.001).minimize(loss, name=\u0026#34;adam_optim\u0026#34;) Note that tf.stop_gradient() on the target y, because the target network should stay fixed during the loss-minimizing gradient update.\nThe target network is updated by copying the primary Q network parameters over every C number of steps (\u0026ldquo;hard update\u0026rdquo;) or polyak averaging towards the primary network (\u0026ldquo;soft update\u0026rdquo;)\n# Get all the variables in the Q primary network. q_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=\u0026#34;Q_primary\u0026#34;) # Get all the variables in the Q target network. q_target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=\u0026#34;Q_target\u0026#34;) assert len(q_vars) == len(q_target_vars) def update_target_q_net_hard(): # Hard update sess.run([v_t.assign(v) for v_t, v in zip(q_target_vars, q_vars)]) def update_target_q_net_soft(tau=0.05): # Soft update: polyak averaging. sess.run([v_t.assign(v_t * (1. - tau) + v * tau) for v_t, v in zip(q_target_vars, q_vars)]) Double Q-Learning If we look into the standard form of the Q value target, $Y(s, a) = r + \\gamma \\max_{a' \\in \\mathcal{A}} Q_\\theta (s', a')$, it is easy to notice that we use $Q_\\theta$ to select the best next action at state s' and then apply the action value predicted by the same $Q_\\theta$. This two-step reinforcing procedure could potentially lead to overestimation of an (already) overestimated value, further leading to training instability. The solution proposed by double Q-learning (Hasselt, 2010) is to decouple the action selection and action value estimation by using two Q networks, $Q_1$ and $Q_2$: when $Q_1$ is being updated, $Q_2$ decides the best next action, and vice versa.\n $$ Y_1(s, a, r, s') = r + \\gamma Q_1 (s', \\arg\\max_{a' \\in \\mathcal{A}}Q_2(s', a'))\\\\ Y_2(s, a, r, s') = r + \\gamma Q_2 (s', \\arg\\max_{a' \\in \\mathcal{A}}Q_1(s', a')) $$ To incorporate double Q-learning into DQN, the minimum modification (Hasselt, Guez, \u0026amp; Silver, 2016) is to use the primary Q network to select the action while the action value is estimated by the target network:\n $$ Y(s, a, r, s') = r + \\gamma Q_{\\theta^{-}}(s', \\arg\\max_{a' \\in \\mathcal{A}} Q_\\theta(s', a')) $$ In the code, we add a new tensor for getting the action selected by the primary Q network as the input and a tensor operation for selecting this action.\nactions_next = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action_next\u0026#39;) actions_selected_by_q = tf.argmax(q, axis=-1, name=\u0026#39;action_selected\u0026#39;) The prediction target y in the loss function becomes:\nactions_next_flatten = actions_next + tf.range(0, batch_size) * q_target.shape[1] max_q_next_target = tf.gather(tf.reshape(q_target, [-1]), actions_next_flatten) y = rewards + (1. - done_flags) * gamma * max_q_next_by_target Here I used tf.gather() to select the action values of interests.\n(Image source: tf.gather() docs) During the episode rollout, we compute the actions_next by feeding the next states' data into the actions_selected_by_q operation.\n# batch_data is a dict with keys, ‘s\u0026#39;, ‘a\u0026#39;, ‘r\u0026#39;, ‘s_next\u0026#39; and ‘done\u0026#39;, containing a batch of transitions. actions_next = sess.run(actions_selected_by_q, {states: batch_data[\u0026#39;s_next\u0026#39;]}) Dueling Q-Network The dueling Q-network (Wang et al., 2016) is equipped with an enhanced network architecture: the output layer branches out into two heads, one for predicting state value, V, and the other for advantage, A. The Q-value is then reconstructed, $Q(s, a) = V(s) + A(s, a)$.\n $$ \\begin{aligned} A(s, a) \u0026= Q(s, a) - V(s)\\\\ V(s) \u0026= \\sum_a Q(s, a) \\pi(a \\vert s) = \\sum_a (V(s) + A(s, a)) \\pi(a \\vert s) = V(s) + \\sum_a A(s, a)\\pi(a \\vert s)\\\\ \\text{Thus, }\u0026 \\sum_a A(s, a)\\pi(a \\vert s) = 0 \\end{aligned} $$ To make sure the estimated advantage values sum up to zero, $\\sum_a A(s, a)\\pi(a \\vert s) = 0$, we deduct the mean value from the prediction.\n $$ Q(s, a) = V(s) + (A(s, a) - \\frac{1}{|\\mathcal{A}|} \\sum_a A(s, a)) $$ The code change is straightforward:\nq_hidden = dense_nn(states, [32], name=\u0026#39;Q_primary_hidden\u0026#39;) adv = dense_nn(q_hidden, [32, env.action_space.n], name=\u0026#39;Q_primary_adv\u0026#39;) v = dense_nn(q_hidden, [32, 1], name=\u0026#39;Q_primary_v\u0026#39;) # Average dueling q = v + (adv - tf.reduce_mean(adv, reduction_indices=1, keepdims=True)) (Image source: Wang et al., 2016) Check the code for the complete flow.\nMonte-Carlo Policy Gradient I reviewed a number of popular policy gradient methods in my last post. Monte-Carlo policy gradient, also known as REINFORCE, is a classic on-policy method that learns the policy model explicitly. It uses the return estimated from a full on-policy trajectory and updates the policy parameters with policy gradient.\nThe returns are computed during rollouts and then fed into the Tensorflow graph as inputs.\n# Inputs states = tf.placeholder(tf.float32, shape=(None, obs_size), name=\u0026#39;state\u0026#39;) actions = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action\u0026#39;) returns = tf.placeholder(tf.float32, shape=(None,), name=\u0026#39;return\u0026#39;) The policy network is contructed. We update the policy parameters by minimizing the loss function, $\\mathcal{L} = - (G_t - V(s)) \\log \\pi(a \\vert s)$. tf.nn.sparse_softmax_cross_entropy_with_logits() asks for the raw logits as inputs, rather then the probabilities after softmax, and that\u0026rsquo;s why we do not have a softmax layer on top of the policy network.\n# Policy network pi = dense_nn(states, [32, 32, env.action_space.n], name=\u0026#39;pi_network\u0026#39;) sampled_actions = tf.squeeze(tf.multinomial(pi, 1)) # For sampling actions according to probabilities. with tf.variable_scope(\u0026#39;pi_optimize\u0026#39;): loss_pi = tf.reduce_mean( returns * tf.nn.sparse_softmax_cross_entropy_with_logits( logits=pi, labels=actions), name=\u0026#39;loss_pi\u0026#39;) optim_pi = tf.train.AdamOptimizer(0.001).minimize(loss_pi, name=\u0026#39;adam_optim_pi\u0026#39;) During the episode rollout, the return is calculated as follows:\n# env = gym.make(...) # gamma = 0.99 # sess = tf.Session(...) def act(ob): return sess.run(sampled_actions, {states: [ob]}) for _ in range(n_episodes): ob = env.reset() done = False obs = [] actions = [] rewards = [] returns = [] while not done: a = act(ob) new_ob, r, done, info = env.step(a) obs.append(ob) actions.append(a) rewards.append(r) ob = new_ob # Estimate returns backwards. return_so_far = 0.0 for r in rewards[::-1]: return_so_far = gamma * return_so_far + r returns.append(return_so_far) returns = returns[::-1] # Update the policy network with the data from one episode. sess.run([optim_pi], feed_dict={ states: np.array(obs), actions: np.array(actions), returns: np.array(returns), }) The full implementation of REINFORCE is here.\nActor-Critic The actor-critic algorithm learns two models at the same time, the actor for learning the best policy and the critic for estimating the state value.\n Initialize the actor network, $\\pi(a \\vert s)$ and the critic, $V(s)$ Collect a new transition (s, a, r, s'): Sample the action $a \\sim \\pi(a \\vert s)$ for the current state s, and get the reward r and the next state s'. Compute the TD target during episode rollout, $G_t = r + \\gamma V(s')$ and TD error, $\\delta_t = r + \\gamma V(s') - V(s)$. Update the critic network by minimizing the critic loss: $L_c = (V(s) - G_t)$. Update the actor network by minimizing the actor loss: $L_a = - \\delta_t \\log \\pi(a \\vert s)$. Set s' = s and repeat step 2.-5. Overall the implementation looks pretty similar to REINFORCE with an extra critic network. The full implementation is here.\n# Inputs states = tf.placeholder(tf.float32, shape=(None, observation_size), name=\u0026#39;state\u0026#39;) actions = tf.placeholder(tf.int32, shape=(None,), name=\u0026#39;action\u0026#39;) td_targets = tf.placeholder(tf.float32, shape=(None,), name=\u0026#39;td_target\u0026#39;) # Actor: action probabilities actor = dense_nn(states, [32, 32, env.action_space.n], name=\u0026#39;actor\u0026#39;) # Critic: action value (Q-value) critic = dense_nn(states, [32, 32, 1], name=\u0026#39;critic\u0026#39;) action_ohe = tf.one_hot(actions, act_size, 1.0, 0.0, name=\u0026#39;action_one_hot\u0026#39;) pred_value = tf.reduce_sum(critic * action_ohe, reduction_indices=-1, name=\u0026#39;q_acted\u0026#39;) td_errors = td_targets - tf.reshape(pred_value, [-1]) with tf.variable_scope(\u0026#39;critic_train\u0026#39;): loss_c = tf.reduce_mean(tf.square(td_errors)) optim_c = tf.train.AdamOptimizer(0.01).minimize(loss_c) with tf.variable_scope(\u0026#39;actor_train\u0026#39;): loss_a = tf.reduce_mean( tf.stop_gradient(td_errors) * tf.nn.sparse_softmax_cross_entropy_with_logits( logits=actor, labels=actions), name=\u0026#39;loss_actor\u0026#39;) optim_a = tf.train.AdamOptimizer(0.01).minimize(loss_a) train_ops = [optim_c, optim_a] The tensorboard graph is always helpful: References [1] Tensorflow API Docs\n[2] Christopher JCH Watkins, and Peter Dayan. \u0026ldquo;Q-learning.\u0026quot; Machine learning 8.3-4 (1992): 279-292.\n[3] Hado Van Hasselt, Arthur Guez, and David Silver. \u0026ldquo;Deep Reinforcement Learning with Double Q-Learning.\u0026quot; AAAI. Vol. 16. 2016.\n[4] Hado van Hasselt. \u0026ldquo;Double Q-learning.\u0026quot; NIPS, 23:2613–2621, 2010.\n[5] Ziyu Wang, et al. Dueling network architectures for deep reinforcement learning. ICML. 2016.\n","permalink":"https://lilianweng.github.io/posts/2018-05-05-drl-implementation/","summary":"The full implementation is available in lilianweng/deep-reinforcement-learning-gym\nIn the previous two posts, I have introduced the algorithms of many deep reinforcement learning models. Now it is the time to get our hands dirty and practice how to implement the models in the wild. The implementation is gonna be built in Tensorflow and OpenAI gym environment. The full version of the code in this tutorial is available in [lilian/deep-reinforcement-learning-gym].\nEnvironment Setup Make sure you have Homebrew installed: /usr/bin/ruby -e \u0026#34;$(curl -fsSL https://raw.","title":"Implementing Deep Reinforcement Learning Models with Tensorflow + OpenAI Gym"},{"content":"[Updated on 2018-06-30: add two new policy gradient methods, SAC and D4PG.] [Updated on 2018-09-30: add a new policy gradient method, TD3.] [Updated on 2019-02-09: add SAC with automatically adjusted temperature]. [Updated on 2019-06-26: Thanks to Chanseok, we have a version of this post in Korean]. [Updated on 2019-09-12: add a new policy gradient method SVPG.] [Updated on 2019-12-22: add a new policy gradient method IMPALA.] [Updated on 2020-10-15: add a new policy gradient method PPG \u0026amp; some new discussion in PPO.] [Updated on 2021-09-19: Thanks to Wenhao \u0026amp; 爱吃猫的鱼, we have this post in Chinese1 \u0026amp; Chinese2].\nWhat is Policy Gradient Policy gradient is an approach to solve reinforcement learning problems. If you haven\u0026rsquo;t looked into the field of reinforcement learning, please first read the section \u0026ldquo;A (Long) Peek into Reinforcement Learning \u0026raquo; Key Concepts\u0026rdquo; for the problem definition and key concepts.\nNotations Here is a list of notations to help you read through equations in the post easily.\n Symbol Meaning $s \\in \\mathcal{S}$ States. $a \\in \\mathcal{A}$ Actions. $r \\in \\mathcal{R}$ Rewards. $S_t, A_t, R_t$ State, action, and reward at time step $t$ of one trajectory. I may occasionally use $s_t, a_t, r_t$ as well. $\\gamma$ Discount factor; penalty to uncertainty of future rewards; $0\u0026lt;\\gamma \\leq 1$. $G_t$ Return; or discounted future reward; $G_t = \\sum_{k=0}^{\\infty} \\gamma^k R_{t+k+1}$. $P(s', r \\vert s, a)$ Transition probability of getting to the next state $s'$ from the current state $s$ with action $a$ and reward $r$. $\\pi(a \\vert s)$ Stochastic policy (agent behavior strategy); $\\pi_\\theta(.)$ is a policy parameterized by $\\theta$. $\\mu(s)$ Deterministic policy; we can also label this as $\\pi(s)$, but using a different letter gives better distinction so that we can easily tell when the policy is stochastic or deterministic without further explanation. Either $\\pi$ or $\\mu$ is what a reinforcement learning algorithm aims to learn. $V(s)$ State-value function measures the expected return of state $s$; $V_w(.)$ is a value function parameterized by $w$. $V^\\pi(s)$ The value of state $s$ when we follow a policy $\\pi$; $V^\\pi (s) = \\mathbb{E}_{a\\sim \\pi} [G_t \\vert S_t = s]$. $Q(s, a)$ Action-value function is similar to $V(s)$, but it assesses the expected return of a pair of state and action $(s, a)$; $Q_w(.)$ is a action value function parameterized by $w$. $Q^\\pi(s, a)$ Similar to $V^\\pi(.)$, the value of (state, action) pair when we follow a policy $\\pi$; $Q^\\pi(s, a) = \\mathbb{E}_{a\\sim \\pi} [G_t \\vert S_t = s, A_t = a]$. $A(s, a)$ Advantage function, $A(s, a) = Q(s, a) - V(s)$; it can be considered as another version of Q-value with lower variance by taking the state-value off as the baseline. Policy Gradient The goal of reinforcement learning is to find an optimal behavior strategy for the agent to obtain optimal rewards. The policy gradient methods target at modeling and optimizing the policy directly. The policy is usually modeled with a parameterized function respect to $\\theta$, $\\pi_\\theta(a \\vert s)$. The value of the reward (objective) function depends on this policy and then various algorithms can be applied to optimize $\\theta$ for the best reward.\nThe reward function is defined as:\n $$ J(\\theta) = \\sum_{s \\in \\mathcal{S}} d^\\pi(s) V^\\pi(s) = \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s) Q^\\pi(s, a) $$ where $d^\\pi(s)$ is the stationary distribution of Markov chain for $\\pi_\\theta$ (on-policy state distribution under $\\pi$). For simplicity, the parameter $\\theta$ would be omitted for the policy $\\pi_\\theta$ when the policy is present in the subscript of other functions; for example, $d^{\\pi}$ and $Q^\\pi$ should be $d^{\\pi_\\theta}$ and $Q^{\\pi_\\theta}$ if written in full.\nImagine that you can travel along the Markov chain\u0026rsquo;s states forever, and eventually, as the time progresses, the probability of you ending up with one state becomes unchanged \u0026mdash; this is the stationary probability for $\\pi_\\theta$. $d^\\pi(s) = \\lim_{t \\to \\infty} P(s_t = s \\vert s_0, \\pi_\\theta)$ is the probability that $s_t=s$ when starting from $s_0$ and following policy $\\pi_\\theta$ for t steps. Actually, the existence of the stationary distribution of Markov chain is one main reason for why PageRank algorithm works. If you want to read more, check this.\nIt is natural to expect policy-based methods are more useful in the continuous space. Because there is an infinite number of actions and (or) states to estimate the values for and hence value-based approaches are way too expensive computationally in the continuous space. For example, in generalized policy iteration, the policy improvement step $\\arg\\max_{a \\in \\mathcal{A}} Q^\\pi(s, a)$ requires a full scan of the action space, suffering from the curse of dimensionality.\nUsing gradient ascent, we can move $\\theta$ toward the direction suggested by the gradient $\\nabla_\\theta J(\\theta)$ to find the best $\\theta$ for $\\pi_\\theta$ that produces the highest return.\nPolicy Gradient Theorem Computing the gradient $\\nabla_\\theta J(\\theta)$ is tricky because it depends on both the action selection (directly determined by $\\pi_\\theta$) and the stationary distribution of states following the target selection behavior (indirectly determined by $\\pi_\\theta$). Given that the environment is generally unknown, it is difficult to estimate the effect on the state distribution by a policy update.\nLuckily, the policy gradient theorem comes to save the world! Woohoo! It provides a nice reformation of the derivative of the objective function to not involve the derivative of the state distribution $d^\\pi(.)$ and simplify the gradient computation $\\nabla_\\theta J(\\theta)$ a lot.\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\\\ \u0026\\propto \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \\end{aligned} $$ Proof of Policy Gradient Theorem This session is pretty dense, as it is the time for us to go through the proof (Sutton \u0026amp; Barto, 2017; Sec. 13.1) and figure out why the policy gradient theorem is correct.\nWe first start with the derivative of the state value function:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta V^\\pi(s) \\\\ =\u0026 \\nabla_\\theta \\Big(\\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s)Q^\\pi(s, a) \\Big) \u0026 \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\nabla_\\theta Q^\\pi(s, a)} \\Big) \u0026 \\scriptstyle{\\text{; Derivative product rule.}} \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\nabla_\\theta \\sum_{s', r} P(s',r \\vert s,a)(r + V^\\pi(s'))} \\Big) \u0026 \\scriptstyle{\\text{; Extend } Q^\\pi \\text{ with future state value.}} \\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\sum_{s', r} P(s',r \\vert s,a) \\nabla_\\theta V^\\pi(s')} \\Big) \u0026 \\scriptstyle{P(s',r \\vert s,a) \\text{ or } r \\text{ is not a func of }\\theta}\\\\ =\u0026 \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\color{red}{\\sum_{s'} P(s' \\vert s,a) \\nabla_\\theta V^\\pi(s')} \\Big) \u0026 \\scriptstyle{\\text{; Because } P(s' \\vert s, a) = \\sum_r P(s', r \\vert s, a)} \\end{aligned} $$ Now we have:\n $$ \\color{red}{\\nabla_\\theta V^\\pi(s)} = \\sum_{a \\in \\mathcal{A}} \\Big( \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) + \\pi_\\theta(a \\vert s) \\sum_{s'} P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\Big) $$ This equation has a nice recursive form (see the red parts!) and the future state value function $V^\\pi(s')$ can be repeated unrolled by following the same equation.\nLet\u0026rsquo;s consider the following visitation sequence and label the probability of transitioning from state s to state x with policy $\\pi_\\theta$ after k step as $\\rho^\\pi(s \\to x, k)$.\n $$ s \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s)} s' \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s')} s'' \\xrightarrow[]{a \\sim \\pi_\\theta(.\\vert s'')} \\dots $$ When k = 0: $\\rho^\\pi(s \\to s, k=0) = 1$. When k = 1, we scan through all possible actions and sum up the transition probabilities to the target state: $\\rho^\\pi(s \\to s', k=1) = \\sum_a \\pi_\\theta(a \\vert s) P(s' \\vert s, a)$. Imagine that the goal is to go from state s to x after k+1 steps while following policy $\\pi_\\theta$. We can first travel from s to a middle point s' (any state can be a middle point, $s' \\in \\mathcal{S}$) after k steps and then go to the final state x during the last step. In this way, we are able to update the visitation probability recursively: $\\rho^\\pi(s \\to x, k+1) = \\sum_{s'} \\rho^\\pi(s \\to s', k) \\rho^\\pi(s' \\to x, 1)$. Then we go back to unroll the recursive representation of $\\nabla_\\theta V^\\pi(s)$! Let $\\phi(s) = \\sum_{a \\in \\mathcal{A}} \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a)$ to simplify the maths. If we keep on extending $\\nabla_\\theta V^\\pi(.)$ infinitely, it is easy to find out that we can transition from the starting state s to any state after any number of steps in this unrolling process and by summing up all the visitation probabilities, we get $\\nabla_\\theta V^\\pi(s)$!\n $$ \\begin{aligned} \u0026 \\color{red}{\\nabla_\\theta V^\\pi(s)} \\\\ =\u0026 \\phi(s) + \\sum_a \\pi_\\theta(a \\vert s) \\sum_{s'} P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\sum_a \\pi_\\theta(a \\vert s) P(s' \\vert s,a) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{\\nabla_\\theta V^\\pi(s')} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\color{red}{[ \\phi(s') + \\sum_{s''} \\rho^\\pi(s' \\to s'', 1) \\nabla_\\theta V^\\pi(s'')]} \\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\phi(s') + \\sum_{s''} \\rho^\\pi(s \\to s'', 2)\\color{red}{\\nabla_\\theta V^\\pi(s'')} \\scriptstyle{\\text{ ; Consider }s'\\text{ as the middle point for }s \\to s''}\\\\ =\u0026 \\phi(s) + \\sum_{s'} \\rho^\\pi(s \\to s', 1) \\phi(s') + \\sum_{s''} \\rho^\\pi(s \\to s'', 2)\\phi(s'') + \\sum_{s'''} \\rho^\\pi(s \\to s''', 3)\\color{red}{\\nabla_\\theta V^\\pi(s''')} \\\\ =\u0026 \\dots \\scriptstyle{\\text{; Repeatedly unrolling the part of }\\nabla_\\theta V^\\pi(.)} \\\\ =\u0026 \\sum_{x\\in\\mathcal{S}}\\sum_{k=0}^\\infty \\rho^\\pi(s \\to x, k) \\phi(x) \\end{aligned} $$ The nice rewriting above allows us to exclude the derivative of Q-value function, $\\nabla_\\theta Q^\\pi(s, a)$. By plugging it into the objective function $J(\\theta)$, we are getting the following:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta V^\\pi(s_0) \u0026 \\scriptstyle{\\text{; Starting from a random state } s_0} \\\\ \u0026= \\sum_{s}\\color{blue}{\\sum_{k=0}^\\infty \\rho^\\pi(s_0 \\to s, k)} \\phi(s) \u0026\\scriptstyle{\\text{; Let }\\color{blue}{\\eta(s) = \\sum_{k=0}^\\infty \\rho^\\pi(s_0 \\to s, k)}} \\\\ \u0026= \\sum_{s}\\eta(s) \\phi(s) \u0026 \\\\ \u0026= \\Big( {\\sum_s \\eta(s)} \\Big)\\sum_{s}\\frac{\\eta(s)}{\\sum_s \\eta(s)} \\phi(s) \u0026 \\scriptstyle{\\text{; Normalize } \\eta(s), s\\in\\mathcal{S} \\text{ to be a probability distribution.}}\\\\ \u0026\\propto \\sum_s \\frac{\\eta(s)}{\\sum_s \\eta(s)} \\phi(s) \u0026 \\scriptstyle{\\sum_s \\eta(s)\\text{ is a constant}} \\\\ \u0026= \\sum_s d^\\pi(s) \\sum_a \\nabla_\\theta \\pi_\\theta(a \\vert s)Q^\\pi(s, a) \u0026 \\scriptstyle{d^\\pi(s) = \\frac{\\eta(s)}{\\sum_s \\eta(s)}\\text{ is stationary distribution.}} \\end{aligned} $$ In the episodic case, the constant of proportionality ($\\sum_s \\eta(s)$) is the average length of an episode; in the continuing case, it is 1 (Sutton \u0026amp; Barto, 2017; Sec. 13.2). The gradient can be further written as:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026\\propto \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \u0026\\\\ \u0026= \\sum_{s \\in \\mathcal{S}} d^\\pi(s) \\sum_{a \\in \\mathcal{A}} \\pi_\\theta(a \\vert s) Q^\\pi(s, a) \\frac{\\nabla_\\theta \\pi_\\theta(a \\vert s)}{\\pi_\\theta(a \\vert s)} \u0026\\\\ \u0026= \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] \u0026 \\scriptstyle{\\text{; Because } (\\ln x)' = 1/x} \\end{aligned} $$ Where $\\mathbb{E}_\\pi$ refers to $\\mathbb{E}_{s \\sim d_\\pi, a \\sim \\pi_\\theta}$ when both state and action distributions follow the policy $\\pi_\\theta$ (on policy).\nThe policy gradient theorem lays the theoretical foundation for various policy gradient algorithms. This vanilla policy gradient update has no bias but high variance. Many following algorithms were proposed to reduce the variance while keeping the bias unchanged.\n $$ \\nabla_\\theta J(\\theta) = \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] $$ Here is a nice summary of a general form of policy gradient methods borrowed from the GAE (general advantage estimation) paper (Schulman et al., 2016) and this post thoroughly discussed several components in GAE , highly recommended.\nFig. 1. A general form of policy gradient methods. (Image source: Schulman et al., 2016) Policy Gradient Algorithms Tons of policy gradient algorithms have been proposed during recent years and there is no way for me to exhaust them. I\u0026rsquo;m introducing some of them that I happened to know and read about.\nREINFORCE REINFORCE (Monte-Carlo policy gradient) relies on an estimated return by Monte-Carlo methods using episode samples to update the policy parameter $\\theta$. REINFORCE works because the expectation of the sample gradient is equal to the actual gradient:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\mathbb{E}_\\pi [Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)] \u0026 \\\\ \u0026= \\mathbb{E}_\\pi [G_t \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)] \u0026 \\scriptstyle{\\text{; Because } Q^\\pi(S_t, A_t) = \\mathbb{E}_\\pi[G_t \\vert S_t, A_t]} \\end{aligned} $$ Therefore we are able to measure $G_t$ from real sample trajectories and use that to update our policy gradient. It relies on a full trajectory and that\u0026rsquo;s why it is a Monte-Carlo method.\nThe process is pretty straightforward:\n Initialize the policy parameter $\\theta$ at random. Generate one trajectory on policy $\\pi_\\theta$: $S_1, A_1, R_2, S_2, A_2, \\dots, S_T$. For t=1, 2, \u0026hellip; , T: Estimate the the return $G_t$; Update policy parameters: $\\theta \\leftarrow \\theta + \\alpha \\gamma^t G_t \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)$ A widely used variation of REINFORCE is to subtract a baseline value from the return $G_t$ to reduce the variance of gradient estimation while keeping the bias unchanged (Remember we always want to do this when possible). For example, a common baseline is to subtract state-value from action-value, and if applied, we would use advantage $A(s, a) = Q(s, a) - V(s)$ in the gradient ascent update. This post nicely explained why a baseline works for reducing the variance, in addition to a set of fundamentals of policy gradient.\nActor-Critic Two main components in policy gradient are the policy model and the value function. It makes a lot of sense to learn the value function in addition to the policy, since knowing the value function can assist the policy update, such as by reducing gradient variance in vanilla policy gradients, and that is exactly what the Actor-Critic method does.\nActor-critic methods consist of two models, which may optionally share parameters:\n Critic updates the value function parameters w and depending on the algorithm it could be action-value $Q_w(a \\vert s)$ or state-value $V_w(s)$. Actor updates the policy parameters $\\theta$ for $\\pi_\\theta(a \\vert s)$, in the direction suggested by the critic. Let\u0026rsquo;s see how it works in a simple action-value actor-critic algorithm.\n Initialize $s, \\theta, w$ at random; sample $a \\sim \\pi_\\theta(a \\vert s)$. For $t = 1 \\dots T$: Sample reward $r_t \\sim R(s, a)$ and next state $s' \\sim P(s' \\vert s, a)$; Then sample the next action $a' \\sim \\pi_\\theta(a' \\vert s')$; Update the policy parameters: $\\theta \\leftarrow \\theta + \\alpha_\\theta Q_w(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s)$; Compute the correction (TD error) for action-value at time t: $\\delta_t = r_t + \\gamma Q_w(s', a') - Q_w(s, a)$ and use it to update the parameters of action-value function: $w \\leftarrow w + \\alpha_w \\delta_t \\nabla_w Q_w(s, a)$ Update $a \\leftarrow a'$ and $s \\leftarrow s'$. Two learning rates, $\\alpha_\\theta$ and $\\alpha_w$, are predefined for policy and value function parameter updates respectively.\nOff-Policy Policy Gradient Both REINFORCE and the vanilla version of actor-critic method are on-policy: training samples are collected according to the target policy \u0026mdash; the very same policy that we try to optimize for. Off policy methods, however, result in several additional advantages:\n The off-policy approach does not require full trajectories and can reuse any past episodes (“experience replay”) for much better sample efficiency. The sample collection follows a behavior policy different from the target policy, bringing better exploration. Now let\u0026rsquo;s see how off-policy policy gradient is computed. The behavior policy for collecting samples is a known policy (predefined just like a hyperparameter), labelled as $\\beta(a \\vert s)$. The objective function sums up the reward over the state distribution defined by this behavior policy:\n $$ J(\\theta) = \\sum_{s \\in \\mathcal{S}} d^\\beta(s) \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) = \\mathbb{E}_{s \\sim d^\\beta} \\big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\big] $$ where $d^\\beta(s)$ is the stationary distribution of the behavior policy $\\beta$; recall that $d^\\beta(s) = \\lim_{t \\to \\infty} P(S_t = s \\vert S_0, \\beta)$; and $Q^\\pi$ is the action-value function estimated with regard to the target policy $\\pi$ (not the behavior policy!).\nGiven that the training observations are sampled by $a \\sim \\beta(a \\vert s)$, we can rewrite the gradient as:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\nabla_\\theta \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\pi_\\theta(a \\vert s) \\Big] \u0026 \\\\ \u0026= \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} \\big( Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) + \\color{red}{\\pi_\\theta(a \\vert s) \\nabla_\\theta Q^\\pi(s, a)} \\big) \\Big] \u0026 \\scriptstyle{\\text{; Derivative product rule.}}\\\\ \u0026\\stackrel{(i)}{\\approx} \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} Q^\\pi(s, a) \\nabla_\\theta \\pi_\\theta(a \\vert s) \\Big] \u0026 \\scriptstyle{\\text{; Ignore the red part: } \\color{red}{\\pi_\\theta(a \\vert s) \\nabla_\\theta Q^\\pi(s, a)}}. \\\\ \u0026= \\mathbb{E}_{s \\sim d^\\beta} \\Big[ \\sum_{a \\in \\mathcal{A}} \\beta(a \\vert s) \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} Q^\\pi(s, a) \\frac{\\nabla_\\theta \\pi_\\theta(a \\vert s)}{\\pi_\\theta(a \\vert s)} \\Big] \u0026 \\\\ \u0026= \\mathbb{E}_\\beta \\Big[\\frac{\\color{blue}{\\pi_\\theta(a \\vert s)}}{\\color{blue}{\\beta(a \\vert s)}} Q^\\pi(s, a) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert s) \\Big] \u0026 \\scriptstyle{\\text{; The blue part is the importance weight.}} \\end{aligned} $$ where $\\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)}$ is the importance weight. Because $Q^\\pi$ is a function of the target policy and thus a function of policy parameter $\\theta$, we should take the derivative of $\\nabla_\\theta Q^\\pi(s, a)$ as well according to the product rule. However, it is super hard to compute $\\nabla_\\theta Q^\\pi(s, a)$ in reality. Fortunately if we use an approximated gradient with the gradient of Q ignored, we still guarantee the policy improvement and eventually achieve the true local minimum. This is justified in the proof here (Degris, White \u0026amp; Sutton, 2012).\nIn summary, when applying policy gradient in the off-policy setting, we can simple adjust it with a weighted sum and the weight is the ratio of the target policy to the behavior policy, $\\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)}$.\nA3C [paper|code]\nAsynchronous Advantage Actor-Critic (Mnih et al., 2016), short for A3C, is a classic policy gradient method with a special focus on parallel training.\nIn A3C, the critics learn the value function while multiple actors are trained in parallel and get synced with global parameters from time to time. Hence, A3C is designed to work well for parallel training.\nLet\u0026rsquo;s use the state-value function as an example. The loss function for state value is to minimize the mean squared error, $J_v(w) = (G_t - V_w(s))^2$ and gradient descent can be applied to find the optimal w. This state-value function is used as the baseline in the policy gradient update.\nHere is the algorithm outline:\n We have global parameters, $\\theta$ and $w$; similar thread-specific parameters, $\\theta'$ and $w'$.\n Initialize the time step $t = 1$\n While $T \\leq T_\\text{MAX}$:\n Reset gradient: $\\mathrm{d}\\theta = 0$ and $\\mathrm{d}w = 0$. Synchronize thread-specific parameters with global ones: $\\theta' = \\theta$ and $w' = w$. $t_\\text{start}$ = t and sample a starting state $s_t$. While ($s_t$ != TERMINAL) and $t - t_\\text{start} \\leq t_\\text{max}$: Pick the action $A_t \\sim \\pi_{\\theta'}(A_t \\vert S_t)$ and receive a new reward $R_t$ and a new state $s_{t+1}$. Update $t = t + 1$ and $T = T + 1$ Initialize the variable that holds the return estimation $$ R = \\begin{cases} 0 \u0026 \\text{if } s_t \\text{ is TERMINAL} \\\\ V_{w'}(s_t) \u0026 \\text{otherwise} \\end{cases} $$ 6. For $i = t-1, \\dots, t\\_\\text{start}$: 1. $R \\leftarrow \\gamma R + R\\_i$; here R is a MC measure of $G\\_i$. 2. Accumulate gradients w.r.t. $\\theta'$: $d\\theta \\leftarrow d\\theta + \\nabla\\_{\\theta'} \\log \\pi\\_{\\theta'}(a\\_i \\vert s\\_i)(R - V\\_{w'}(s\\_i))$;Accumulate gradients w.r.t. w': $dw \\leftarrow dw + 2 (R - V\\_{w'}(s\\_i)) \\nabla\\_{w'} (R - V\\_{w'}(s\\_i))$. Update asynchronously $\\theta$ using $\\mathrm{d}\\theta$, and $w$ using $\\mathrm{d}w$. A3C enables the parallelism in multiple agent training. The gradient accumulation step (6.2) can be considered as a parallelized reformation of minibatch-based stochastic gradient update: the values of $w$ or $\\theta$ get corrected by a little bit in the direction of each training thread independently.\nA2C [paper|code]\nA2C is a synchronous, deterministic version of A3C; that\u0026rsquo;s why it is named as “A2C” with the first “A” (“asynchronous”) removed. In A3C each agent talks to the global parameters independently, so it is possible sometimes the thread-specific agents would be playing with policies of different versions and therefore the aggregated update would not be optimal. To resolve the inconsistency, a coordinator in A2C waits for all the parallel actors to finish their work before updating the global parameters and then in the next iteration parallel actors starts from the same policy. The synchronized gradient update keeps the training more cohesive and potentially to make convergence faster.\nA2C has been shown to be able to utilize GPUs more efficiently and work better with large batch sizes while achieving same or better performance than A3C.\nFig. 2. The architecture of A3C versus A2C. DPG [paper|code]\nIn methods described above, the policy function $\\pi(. \\vert s)$ is always modeled as a probability distribution over actions $\\mathcal{A}$ given the current state and thus it is stochastic. Deterministic policy gradient (DPG) instead models the policy as a deterministic decision: $a = \\mu(s)$. It may look bizarre \u0026mdash; how can you calculate the gradient of the action probability when it outputs a single action? Let\u0026rsquo;s look into it step by step.\nRefresh on a few notations to facilitate the discussion:\n $\\rho_0(s)$: The initial distribution over states $\\rho^\\mu(s \\to s', k)$: Starting from state s, the visitation probability density at state s' after moving k steps by policy $\\mu$. $\\rho^\\mu(s')$: Discounted state distribution, defined as $\\rho^\\mu(s') = \\int_\\mathcal{S} \\sum_{k=1}^\\infty \\gamma^{k-1} \\rho_0(s) \\rho^\\mu(s \\to s', k) ds$. The objective function to optimize for is listed as follows:\n $$ J(\\theta) = \\int_\\mathcal{S} \\rho^\\mu(s) Q(s, \\mu_\\theta(s)) ds $$ Deterministic policy gradient theorem: Now it is the time to compute the gradient! According to the chain rule, we first take the gradient of Q w.r.t. the action a and then take the gradient of the deterministic policy function $\\mu$ w.r.t. $\\theta$:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026= \\int_\\mathcal{S} \\rho^\\mu(s) \\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)} ds \\\\ \u0026= \\mathbb{E}_{s \\sim \\rho^\\mu} [\\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \\end{aligned} $$ We can consider the deterministic policy as a special case of the stochastic one, when the probability distribution contains only one extreme non-zero value over one action. Actually, in the DPG paper, the authors have shown that if the stochastic policy $\\pi_{\\mu_\\theta, \\sigma}$ is re-parameterized by a deterministic policy $\\mu_\\theta$ and a variation variable $\\sigma$, the stochastic policy is eventually equivalent to the deterministic case when $\\sigma=0$. Compared to the deterministic policy, we expect the stochastic policy to require more samples as it integrates the data over the whole state and action space.\nThe deterministic policy gradient theorem can be plugged into common policy gradient frameworks.\nLet\u0026rsquo;s consider an example of on-policy actor-critic algorithm to showcase the procedure. In each iteration of on-policy actor-critic, two actions are taken deterministically $a = \\mu_\\theta(s)$ and the SARSA update on policy parameters relies on the new gradient that we just computed above:\n $$ \\begin{aligned} \\delta_t \u0026= R_t + \\gamma Q_w(s_{t+1}, a_{t+1}) - Q_w(s_t, a_t) \u0026 \\small{\\text{; TD error in SARSA}}\\\\ w_{t+1} \u0026= w_t + \\alpha_w \\delta_t \\nabla_w Q_w(s_t, a_t) \u0026 \\\\ \\theta_{t+1} \u0026= \\theta_t + \\alpha_\\theta \\color{red}{\\nabla_a Q_w(s_t, a_t) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}} \u0026 \\small{\\text{; Deterministic policy gradient theorem}} \\end{aligned} $$ However, unless there is sufficient noise in the environment, it is very hard to guarantee enough exploration due to the determinacy of the policy. We can either add noise into the policy (ironically this makes it nondeterministic!) or learn it off-policy-ly by following a different stochastic behavior policy to collect samples.\nSay, in the off-policy approach, the training trajectories are generated by a stochastic policy $\\beta(a \\vert s)$ and thus the state distribution follows the corresponding discounted state density $\\rho^\\beta$:\n $$ \\begin{aligned} J_\\beta(\\theta) \u0026= \\int_\\mathcal{S} \\rho^\\beta Q^\\mu(s, \\mu_\\theta(s)) ds \\\\ \\nabla_\\theta J_\\beta(\\theta) \u0026= \\mathbb{E}_{s \\sim \\rho^\\beta} [\\nabla_a Q^\\mu(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)} ] \\end{aligned} $$ Note that because the policy is deterministic, we only need $Q^\\mu(s, \\mu_\\theta(s))$ rather than $\\sum_a \\pi(a \\vert s) Q^\\pi(s, a)$ as the estimated reward of a given state s. In the off-policy approach with a stochastic policy, importance sampling is often used to correct the mismatch between behavior and target policies, as what we have described above. However, because the deterministic policy gradient removes the integral over actions, we can avoid importance sampling.\nDDPG [paper|code]\nDDPG (Lillicrap, et al., 2015), short for Deep Deterministic Policy Gradient, is a model-free off-policy actor-critic algorithm, combining DPG with DQN. Recall that DQN (Deep Q-Network) stabilizes the learning of Q-function by experience replay and the frozen target network. The original DQN works in discrete space, and DDPG extends it to continuous space with the actor-critic framework while learning a deterministic policy.\nIn order to do better exploration, an exploration policy $\\mu'$ is constructed by adding noise $\\mathcal{N}$:\n $$ \\mu'(s) = \\mu_\\theta(s) + \\mathcal{N} $$ In addition, DDPG does soft updates (\u0026ldquo;conservative policy iteration\u0026rdquo;) on the parameters of both actor and critic, with $\\tau \\ll 1$: $\\theta' \\leftarrow \\tau \\theta + (1 - \\tau) \\theta'$. In this way, the target network values are constrained to change slowly, different from the design in DQN that the target network stays frozen for some period of time.\nOne detail in the paper that is particularly useful in robotics is on how to normalize the different physical units of low dimensional features. For example, a model is designed to learn a policy with the robot\u0026rsquo;s positions and velocities as input; these physical statistics are different by nature and even statistics of the same type may vary a lot across multiple robots. Batch normalization is applied to fix it by normalizing every dimension across samples in one minibatch.\nFig 3. DDPG Algorithm. (Image source: Lillicrap, et al., 2015) D4PG [paper|code (Search “github d4pg” and you will see a few.)]\nDistributed Distributional DDPG (D4PG) applies a set of improvements on DDPG to make it run in the distributional fashion.\n(1) Distributional Critic: The critic estimates the expected Q value as a random variable ~ a distribution $Z_w$ parameterized by $w$ and therefore $Q_w(s, a) = \\mathbb{E} Z_w(x, a)$. The loss for learning the distribution parameter is to minimize some measure of the distance between two distributions \u0026mdash; distributional TD error: $L(w) = \\mathbb{E}[d(\\mathcal{T}_{\\mu_\\theta}, Z_{w'}(s, a), Z_w(s, a)]$, where $\\mathcal{T}_{\\mu_\\theta}$ is the Bellman operator.\nThe deterministic policy gradient update becomes:\n $$ \\begin{aligned} \\nabla_\\theta J(\\theta) \u0026\\approx \\mathbb{E}_{\\rho^\\mu} [\\nabla_a Q_w(s, a) \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \u0026 \\scriptstyle{\\text{; gradient update in DPG}} \\\\ \u0026= \\mathbb{E}_{\\rho^\\mu} [\\mathbb{E}[\\nabla_a Z_w(s, a)] \\nabla_\\theta \\mu_\\theta(s) \\rvert_{a=\\mu_\\theta(s)}] \u0026 \\scriptstyle{\\text{; expectation of the Q-value distribution.}} \\end{aligned} $$ (2) $N$-step returns: When calculating the TD error, D4PG computes $N$-step TD target rather than one-step to incorporate rewards in more future steps. Thus the new TD target is:\n $$ r(s_0, a_0) + \\mathbb{E}[\\sum_{n=1}^{N-1} r(s_n, a_n) + \\gamma^N Q(s_N, \\mu_\\theta(s_N)) \\vert s_0, a_0 ] $$ (3) Multiple Distributed Parallel Actors: D4PG utilizes $K$ independent actors, gathering experience in parallel and feeding data into the same replay buffer.\n(4) Prioritized Experience Replay (PER): The last piece of modification is to do sampling from the replay buffer of size $R$ with an non-uniform probability $p_i$. In this way, a sample $i$ has the probability $(Rp_i)^{-1}$ to be selected and thus the importance weight is $(Rp_i)^{-1}$.\nFig. 4. D4PG algorithm (Image source: Barth-Maron, et al. 2018); Note that in the original paper, the variable letters are chosen slightly differently from what in the post; i.e. I use $\\mu(.)$ for representing a deterministic policy instead of $\\pi(.)$. MADDPG [paper|code]\nMulti-agent DDPG (MADDPG) (Lowe et al., 2017) extends DDPG to an environment where multiple agents are coordinating to complete tasks with only local information. In the viewpoint of one agent, the environment is non-stationary as policies of other agents are quickly upgraded and remain unknown. MADDPG is an actor-critic model redesigned particularly for handling such a changing environment and interactions between agents.\nThe problem can be formalized in the multi-agent version of MDP, also known as Markov games. MADDPG is proposed for partially observable Markov games. Say, there are N agents in total with a set of states $\\mathcal{S}$. Each agent owns a set of possible action, $\\mathcal{A}_1, \\dots, \\mathcal{A}_N$, and a set of observation, $\\mathcal{O}_1, \\dots, \\mathcal{O}_N$. The state transition function involves all states, action and observation spaces $\\mathcal{T}: \\mathcal{S} \\times \\mathcal{A}_1 \\times \\dots \\mathcal{A}_N \\mapsto \\mathcal{S}$. Each agent\u0026rsquo;s stochastic policy only involves its own state and action: $\\pi_{\\theta_i}: \\mathcal{O}_i \\times \\mathcal{A}_i \\mapsto [0, 1]$, a probability distribution over actions given its own observation, or a deterministic policy: $\\mu_{\\theta_i}: \\mathcal{O}_i \\mapsto \\mathcal{A}_i$.\nLet $\\vec{o} = {o_1, \\dots, o_N}$, $\\vec{\\mu} = {\\mu_1, \\dots, \\mu_N}$ and the policies are parameterized by $\\vec{\\theta} = {\\theta_1, \\dots, \\theta_N}$.\nThe critic in MADDPG learns a centralized action-value function $Q^\\vec{\\mu}_i(\\vec{o}, a_1, \\dots, a_N)$ for the i-th agent, where $a_1 \\in \\mathcal{A}_1, \\dots, a_N \\in \\mathcal{A}_N$ are actions of all agents. Each $Q^\\vec{\\mu}_i$ is learned separately for $i=1, \\dots, N$ and therefore multiple agents can have arbitrary reward structures, including conflicting rewards in a competitive setting. Meanwhile, multiple actors, one for each agent, are exploring and upgrading the policy parameters $\\theta_i$ on their own.\nActor update:\n $$ \\nabla_{\\theta_i} J(\\theta_i) = \\mathbb{E}_{\\vec{o}, a \\sim \\mathcal{D}} [\\nabla_{a_i} Q^{\\vec{\\mu}}_i (\\vec{o}, a_1, \\dots, a_N) \\nabla_{\\theta_i} \\mu_{\\theta_i}(o_i) \\rvert_{a_i=\\mu_{\\theta_i}(o_i)} ] $$ Where $\\mathcal{D}$ is the memory buffer for experience replay, containing multiple episode samples $(\\vec{o}, a_1, \\dots, a_N, r_1, \\dots, r_N, \\vec{o}')$ \u0026mdash; given current observation $\\vec{o}$, agents take action $a_1, \\dots, a_N$ and get rewards $r_1, \\dots, r_N$, leading to the new observation $\\vec{o}'$.\nCritic update:\n $$ \\begin{aligned} \\mathcal{L}(\\theta_i) \u0026= \\mathbb{E}_{\\vec{o}, a_1, \\dots, a_N, r_1, \\dots, r_N, \\vec{o}'}[ (Q^{\\vec{\\mu}}_i(\\vec{o}, a_1, \\dots, a_N) - y)^2 ] \u0026 \\\\ \\text{where } y \u0026= r_i + \\gamma Q^{\\vec{\\mu}'}_i (\\vec{o}', a'_1, \\dots, a'_N) \\rvert_{a'_j = \\mu'_{\\theta_j}} \u0026 \\scriptstyle{\\text{; TD target!}} \\end{aligned} $$ where $\\vec{\\mu}'$ are the target policies with delayed softly-updated parameters.\nIf the policies $\\vec{\\mu}$ are unknown during the critic update, we can ask each agent to learn and evolve its own approximation of others' policies. Using the approximated policies, MADDPG still can learn efficiently although the inferred policies might not be accurate.\nTo mitigate the high variance triggered by the interaction between competing or collaborating agents in the environment, MADDPG proposed one more element - policy ensembles:\n Train K policies for one agent; Pick a random policy for episode rollouts; Take an ensemble of these K policies to do gradient update. In summary, MADDPG added three additional ingredients on top of DDPG to make it adapt to the multi-agent environment:\n Centralized critic + decentralized actors; Actors are able to use estimated policies of other agents for learning; Policy ensembling is good for reducing variance. Fig. 5. The architecture design of MADDPG. (Image source: Lowe et al., 2017) TRPO [paper|code]\nTo improve training stability, we should avoid parameter updates that change the policy too much at one step. Trust region policy optimization (TRPO) (Schulman, et al., 2015) carries out this idea by enforcing a KL divergence constraint on the size of policy update at each iteration.\nConsider the case when we are doing off-policy RL, the policy $\\beta$ used for collecting trajectories on rollout workers is different from the policy $\\pi$ to optimize for. The objective function in an off-policy model measures the total advantage over the state visitation distribution and actions, while the mismatch between the training data distribution and the true policy state distribution is compensated by importance sampling estimator:\n $$ \\begin{aligned} J(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} \\rho^{\\pi_{\\theta_\\text{old}}} \\sum_{a \\in \\mathcal{A}} \\big( \\pi_\\theta(a \\vert s) \\hat{A}_{\\theta_\\text{old}}(s, a) \\big) \u0026 \\\\ \u0026= \\sum_{s \\in \\mathcal{S}} \\rho^{\\pi_{\\theta_\\text{old}}} \\sum_{a \\in \\mathcal{A}} \\big( \\beta(a \\vert s) \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big) \u0026 \\scriptstyle{\\text{; Importance sampling}} \\\\ \u0026= \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}, a \\sim \\beta} \\big[ \\frac{\\pi_\\theta(a \\vert s)}{\\beta(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big] \u0026 \\end{aligned} $$ where $\\theta_\\text{old}$ is the policy parameters before the update and thus known to us; $\\rho^{\\pi_{\\theta_\\text{old}}}$ is defined in the same way as above; $\\beta(a \\vert s)$ is the behavior policy for collecting trajectories. Noted that we use an estimated advantage $\\hat{A}(.)$ rather than the true advantage function $A(.)$ because the true rewards are usually unknown.\nWhen training on policy, theoretically the policy for collecting data is same as the policy that we want to optimize. However, when rollout workers and optimizers are running in parallel asynchronously, the behavior policy can get stale. TRPO considers this subtle difference: It labels the behavior policy as $\\pi_{\\theta_\\text{old}}(a \\vert s)$ and thus the objective function becomes:\n $$ J(\\theta) = \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}, a \\sim \\pi_{\\theta_\\text{old}}} \\big[ \\frac{\\pi_\\theta(a \\vert s)}{\\pi_{\\theta_\\text{old}}(a \\vert s)} \\hat{A}_{\\theta_\\text{old}}(s, a) \\big] $$ TRPO aims to maximize the objective function $J(\\theta)$ subject to, trust region constraint which enforces the distance between old and new policies measured by KL-divergence to be small enough, within a parameter δ:\n $$ \\mathbb{E}_{s \\sim \\rho^{\\pi_{\\theta_\\text{old}}}} [D_\\text{KL}(\\pi_{\\theta_\\text{old}}(.\\vert s) \\| \\pi_\\theta(.\\vert s)] \\leq \\delta $$ In this way, the old and new policies would not diverge too much when this hard constraint is met. While still, TRPO can guarantee a monotonic improvement over policy iteration (Neat, right?). Please read the proof in the paper if interested :)\nPPO [paper|code]\nGiven that TRPO is relatively complicated and we still want to implement a similar constraint, proximal policy optimization (PPO) simplifies it by using a clipped surrogate objective while retaining similar performance.\nFirst, let\u0026rsquo;s denote the probability ratio between old and new policies as:\n $$ r(\\theta) = \\frac{\\pi_\\theta(a \\vert s)}{\\pi_{\\theta_\\text{old}}(a \\vert s)} $$ Then, the objective function of TRPO (on policy) becomes:\n $$ J^\\text{TRPO} (\\theta) = \\mathbb{E} [ r(\\theta) \\hat{A}_{\\theta_\\text{old}}(s, a) ] $$ Without a limitation on the distance between $\\theta_\\text{old}$ and $\\theta$, to maximize $J^\\text{TRPO} (\\theta)$ would lead to instability with extremely large parameter updates and big policy ratios. PPO imposes the constraint by forcing $r(\\theta)$ to stay within a small interval around 1, precisely $[1-\\epsilon, 1+\\epsilon]$, where $\\epsilon$ is a hyperparameter.\n $$ J^\\text{CLIP} (\\theta) = \\mathbb{E} [ \\min( r(\\theta) \\hat{A}_{\\theta_\\text{old}}(s, a), \\text{clip}(r(\\theta), 1 - \\epsilon, 1 + \\epsilon) \\hat{A}_{\\theta_\\text{old}}(s, a))] $$ The function $\\text{clip}(r(\\theta), 1 - \\epsilon, 1 + \\epsilon)$ clips the ratio to be no more than $1+\\epsilon$ and no less than $1-\\epsilon$. The objective function of PPO takes the minimum one between the original value and the clipped version and therefore we lose the motivation for increasing the policy update to extremes for better rewards.\nWhen applying PPO on the network architecture with shared parameters for both policy (actor) and value (critic) functions, in addition to the clipped reward, the objective function is augmented with an error term on the value estimation (formula in red) and an entropy term (formula in blue) to encourage sufficient exploration.\n $$ J^\\text{CLIP'} (\\theta) = \\mathbb{E} [ J^\\text{CLIP} (\\theta) - \\color{red}{c_1 (V_\\theta(s) - V_\\text{target})^2} + \\color{blue}{c_2 H(s, \\pi_\\theta(.))} ] $$ where Both $c_1$ and $c_2$ are two hyperparameter constants.\nPPO has been tested on a set of benchmark tasks and proved to produce awesome results with much greater simplicity.\nIn a later paper by Hsu et al., 2020, two common design choices in PPO are revisited, precisely (1) clipped probability ratio for policy regularization and (2) parameterize policy action space by continuous Gaussian or discrete softmax distribution. They first identified three failure modes in PPO and proposed replacements for these two designs.\nThe failure modes are:\n On continuous action spaces, standard PPO is unstable when rewards vanish outside bounded support. On discrete action spaces with sparse high rewards, standard PPO often gets stuck at suboptimal actions. The policy is sensitive to initialization when there are locally optimal actions close to initialization. Discretizing the action space or use Beta distribution helps avoid failure mode 1\u0026amp;3 associated with Gaussian policy. Using KL regularization (same motivation as in TRPO) as an alternative surrogate model helps resolve failure mode 1\u0026amp;2.\nPPG [paper|code]\nSharing parameters between policy and value networks have pros and cons. It allows policy and value functions to share the learned features with each other, but it may cause conflicts between competing objectives and demands the same data for training two networks at the same time. Phasic policy gradient (PPG; Cobbe, et al 2020) modifies the traditional on-policy actor-critic policy gradient algorithm. precisely PPO, to have separate training phases for policy and value functions. In two alternating phases:\n The policy phase: updates the policy network by optimizing the PPO objective $L^\\text{CLIP} (\\theta)$; The auxiliary phase: optimizes an auxiliary objective alongside a behavioral cloning loss. In the paper, value function error is the sole auxiliary objective, but it can be quite general and includes any other additional auxiliary losses. $$ \\begin{aligned} L^\\text{joint} \u0026= L^\\text{aux} + \\beta_\\text{clone} \\cdot \\mathbb{E}_t[\\text{KL}[\\pi_{\\theta_\\text{old}}(\\cdot\\mid s_t), \\pi_\\theta(\\cdot\\mid s_t)]] \\\\ L^\\text{aux} \u0026= L^\\text{value} = \\mathbb{E}_t \\big[\\frac{1}{2}\\big( V_w(s_t) - \\hat{V}_t^\\text{targ} \\big)^2\\big] \\end{aligned} $$ where $\\beta_\\text{clone}$ is a hyperparameter for controlling how much we would like to keep the policy not diverge too much from its original behavior while optimizing the auxiliary objectives.\nFig. 6. The algorithm of PPG. (Image source: Cobbe, et al 2020) where\n $N_\\pi$ is the number of policy update iterations in the policy phase. Note that the policy phase performs multiple iterations of updates per single auxiliary phase. $E_\\pi$ and $E_V$ control the sample reuse (i.e. the number of training epochs performed across data in the reply buffer) for the policy and value functions, respectively. Note that this happens within the policy phase and thus $E_V$ affects the learning of true value function not the auxiliary value function. $E_\\text{aux}$ defines the sample reuse in the auxiliary phrase. In PPG, value function optimization can tolerate a much higher level sample reuse; for example, in the experiments of the paper, $E_\\text{aux} = 6$ while $E_\\pi = E_V = 1$. PPG leads to a significant improvement on sample efficiency compared to PPO.\nFig. 7. The mean normalized performance of PPG vs PPO on the Procgen benchmark. (Image source: Cobbe, et al 2020) ACER [paper|code]\nACER, short for actor-critic with experience replay (Wang, et al., 2017), is an off-policy actor-critic model with experience replay, greatly increasing the sample efficiency and decreasing the data correlation. A3C builds up the foundation for ACER, but it is on policy; ACER is A3C\u0026rsquo;s off-policy counterpart. The major obstacle to making A3C off policy is how to control the stability of the off-policy estimator. ACER proposes three designs to overcome it:\n Use Retrace Q-value estimation; Truncate the importance weights with bias correction; Apply efficient TRPO. Retrace Q-value Estimation\nRetrace is an off-policy return-based Q-value estimation algorithm with a nice guarantee for convergence for any target and behavior policy pair $(\\pi, \\beta)$, plus good data efficiency.\nRecall how TD learning works for prediction:\n Compute TD error: $\\delta_t = R_t + \\gamma \\mathbb{E}_{a \\sim \\pi} Q(S_{t+1}, a) - Q(S_t, A_t)$; the term $r_t + \\gamma \\mathbb{E}_{a \\sim \\pi} Q(s_{t+1}, a) $ is known as “TD target”. The expectation $\\mathbb{E}_{a \\sim \\pi}$ is used because for the future step the best estimation we can make is what the return would be if we follow the current policy $\\pi$. Update the value by correcting the error to move toward the goal: $Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha \\delta_t$. In other words, the incremental update on Q is proportional to the TD error: $\\Delta Q(S_t, A_t) = \\alpha \\delta_t$. When the rollout is off policy, we need to apply importance sampling on the Q update:\n $$ \\Delta Q^\\text{imp}(S_t, A_t) = \\gamma^t \\prod_{1 \\leq \\tau \\leq t} \\frac{\\pi(A_\\tau \\vert S_\\tau)}{\\beta(A_\\tau \\vert S_\\tau)} \\delta_t $$ The product of importance weights looks pretty scary when we start imagining how it can cause super high variance and even explode. Retrace Q-value estimation method modifies $\\Delta Q$ to have importance weights truncated by no more than a constant $c$:\n $$ \\Delta Q^\\text{ret}(S_t, A_t) = \\gamma^t \\prod_{1 \\leq \\tau \\leq t} \\min(c, \\frac{\\pi(A_\\tau \\vert S_\\tau)}{\\beta(A_\\tau \\vert S_\\tau)}) \\delta_t $$ ACER uses $Q^\\text{ret}$ as the target to train the critic by minimizing the L2 error term: $(Q^\\text{ret}(s, a) - Q(s, a))^2$.\nImportance weights truncation\nTo reduce the high variance of the policy gradient $\\hat{g}$, ACER truncates the importance weights by a constant c, plus a correction term. The label $\\hat{g}_t^\\text{acer}$ is the ACER policy gradient at time t.\n $$ \\begin{aligned} \\hat{g}_t^\\text{acer} = \u0026 \\omega_t \\big( Q^\\text{ret}(S_t, A_t) - V_{\\theta_v}(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t) \u0026 \\scriptstyle{\\text{; Let }\\omega_t=\\frac{\\pi(A_t \\vert S_t)}{\\beta(A_t \\vert S_t)}} \\\\ = \u0026 \\color{blue}{\\min(c, \\omega_t) \\big( Q^\\text{ret}(S_t, A_t) - V_w(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(A_t \\vert S_t)} \\\\ \u0026 + \\color{red}{\\mathbb{E}_{a \\sim \\pi} \\big[ \\max(0, \\frac{\\omega_t(a) - c}{\\omega_t(a)}) \\big( Q_w(S_t, a) - V_w(S_t) \\big) \\nabla_\\theta \\ln \\pi_\\theta(a \\vert S_t) \\big]} \u0026 \\scriptstyle{\\text{; Let }\\omega_t (a) =\\frac{\\pi(a \\vert S_t)}{\\beta(a \\vert S_t)}} \\end{aligned} $$ where $Q_w(.)$ and $V_w(.)$ are value functions predicted by the critic with parameter w. The first term (blue) contains the clipped important weight. The clipping helps reduce the variance, in addition to subtracting state value function $V_w(.)$ as a baseline. The second term (red) makes a correction to achieve unbiased estimation.\nEfficient TRPO\nFurthermore, ACER adopts the idea of TRPO but with a small adjustment to make it more computationally efficient: rather than measuring the KL divergence between policies before and after one update, ACER maintains a running average of past policies and forces the updated policy to not deviate far from this average.\nThe ACER paper is pretty dense with many equations. Hopefully, with the prior knowledge on TD learning, Q-learning, importance sampling and TRPO, you will find the paper slightly easier to follow :)\nACTKR [paper|code]\nACKTR (actor-critic using Kronecker-factored trust region) (Yuhuai Wu, et al., 2017) proposed to use Kronecker-factored approximation curvature (K-FAC) to do the gradient update for both the critic and actor. K-FAC made an improvement on the computation of natural gradient, which is quite different from our standard gradient. Here is a nice, intuitive explanation of natural gradient. One sentence summary is probably:\n “we first consider all combinations of parameters that result in a new network a constant KL divergence away from the old network. This constant value can be viewed as the step size or learning rate. Out of all these possible combinations, we choose the one that minimizes our loss function.”\n I listed ACTKR here mainly for the completeness of this post, but I would not dive into details, as it involves a lot of theoretical knowledge on natural gradient and optimization methods. If interested, check these papers/posts, before reading the ACKTR paper:\n Amari. Natural Gradient Works Efficiently in Learning. 1998 Kakade. A Natural Policy Gradient. 2002 A intuitive explanation of natural gradient descent Wiki: Kronecker product Martens \u0026amp; Grosse. Optimizing neural networks with kronecker-factored approximate curvature. 2015. Here is a high level summary from the K-FAC paper:\n \u0026ldquo;This approximation is built in two stages. In the first, the rows and columns of the Fisher are divided into groups, each of which corresponds to all the weights in a given layer, and this gives rise to a block-partitioning of the matrix. These blocks are then approximated as Kronecker products between much smaller matrices, which we show is equivalent to making certain approximating assumptions regarding the statistics of the network\u0026rsquo;s gradients.\n In the second stage, this matrix is further approximated as having an inverse which is either block-diagonal or block-tridiagonal. We justify this approximation through a careful examination of the relationships between inverse covariances, tree-structured graphical models, and linear regression. Notably, this justification doesn\u0026rsquo;t apply to the Fisher itself, and our experiments confirm that while the inverse Fisher does indeed possess this structure (approximately), the Fisher itself does not.\u0026rdquo;\n SAC [paper|code]\nSoft Actor-Critic (SAC) (Haarnoja et al. 2018) incorporates the entropy measure of the policy into the reward to encourage exploration: we expect to learn a policy that acts as randomly as possible while it is still able to succeed at the task. It is an off-policy actor-critic model following the maximum entropy reinforcement learning framework. A precedent work is Soft Q-learning.\nThree key components in SAC:\n An actor-critic architecture with separate policy and value function networks; An off-policy formulation that enables reuse of previously collected data for efficiency; Entropy maximization to enable stability and exploration. The policy is trained with the objective to maximize the expected return and the entropy at the same time:\n $$ J(\\theta) = \\sum_{t=1}^T \\mathbb{E}_{(s_t, a_t) \\sim \\rho_{\\pi_\\theta}} [r(s_t, a_t) + \\alpha \\mathcal{H}(\\pi_\\theta(.\\vert s_t))] $$ where $\\mathcal{H}(.)$ is the entropy measure and $\\alpha$ controls how important the entropy term is, known as temperature parameter. The entropy maximization leads to policies that can (1) explore more and (2) capture multiple modes of near-optimal strategies (i.e., if there exist multiple options that seem to be equally good, the policy should assign each with an equal probability to be chosen).\nPrecisely, SAC aims to learn three functions:\n The policy with parameter $\\theta$, $\\pi_\\theta$. Soft Q-value function parameterized by $w$, $Q_w$. Soft state value function parameterized by $\\psi$, $V_\\psi$; theoretically we can infer $V$ by knowing $Q$ and $\\pi$, but in practice, it helps stabilize the training. Soft Q-value and soft state value are defined as:\n $$ \\begin{aligned} Q(s_t, a_t) \u0026= r(s_t, a_t) + \\gamma \\mathbb{E}_{s_{t+1} \\sim \\rho_{\\pi}(s)} [V(s_{t+1})] \u0026 \\text{; according to Bellman equation.}\\\\ \\text{where }V(s_t) \u0026= \\mathbb{E}_{a_t \\sim \\pi} [Q(s_t, a_t) - \\alpha \\log \\pi(a_t \\vert s_t)] \u0026 \\text{; soft state value function.} \\end{aligned} $$ $$ \\text{Thus, } Q(s_t, a_t) = r(s_t, a_t) + \\gamma \\mathbb{E}_{(s_{t+1}, a_{t+1}) \\sim \\rho_{\\pi}} [Q(s_{t+1}, a_{t+1}) - \\alpha \\log \\pi(a_{t+1} \\vert s_{t+1})] $$ $\\rho_\\pi(s)$ and $\\rho_\\pi(s, a)$ denote the state and the state-action marginals of the state distribution induced by the policy $\\pi(a \\vert s)$; see the similar definitions in DPG section.\nThe soft state value function is trained to minimize the mean squared error:\n $$ \\begin{aligned} J_V(\\psi) \u0026= \\mathbb{E}_{s_t \\sim \\mathcal{D}} [\\frac{1}{2} \\big(V_\\psi(s_t) - \\mathbb{E}[Q_w(s_t, a_t) - \\log \\pi_\\theta(a_t \\vert s_t)] \\big)^2] \\\\ \\text{with gradient: }\\nabla_\\psi J_V(\\psi) \u0026= \\nabla_\\psi V_\\psi(s_t)\\big( V_\\psi(s_t) - Q_w(s_t, a_t) + \\log \\pi_\\theta (a_t \\vert s_t) \\big) \\end{aligned} $$ where $\\mathcal{D}$ is the replay buffer.\nThe soft Q function is trained to minimize the soft Bellman residual:\n $$ \\begin{aligned} J_Q(w) \u0026= \\mathbb{E}_{(s_t, a_t) \\sim \\mathcal{D}} [\\frac{1}{2}\\big( Q_w(s_t, a_t) - (r(s_t, a_t) + \\gamma \\mathbb{E}_{s_{t+1} \\sim \\rho_\\pi(s)}[V_{\\bar{\\psi}}(s_{t+1})]) \\big)^2] \\\\ \\text{with gradient: } \\nabla_w J_Q(w) \u0026= \\nabla_w Q_w(s_t, a_t) \\big( Q_w(s_t, a_t) - r(s_t, a_t) - \\gamma V_{\\bar{\\psi}}(s_{t+1})\\big) \\end{aligned} $$ where $\\bar{\\psi}$ is the target value function which is the exponential moving average (or only gets updated periodically in a “hard” way), just like how the parameter of the target Q network is treated in DQN to stabilize the training.\nSAC updates the policy to minimize the KL-divergence:\n $$ \\begin{aligned} \\pi_\\text{new} \u0026= \\arg\\min_{\\pi' \\in \\Pi} D_\\text{KL} \\Big( \\pi'(.\\vert s_t) \\| \\frac{\\exp(Q^{\\pi_\\text{old}}(s_t, .))}{Z^{\\pi_\\text{old}}(s_t)} \\Big) \\\\[6pt] \u0026= \\arg\\min_{\\pi' \\in \\Pi} D_\\text{KL} \\big( \\pi'(.\\vert s_t) \\| \\exp(Q^{\\pi_\\text{old}}(s_t, .) - \\log Z^{\\pi_\\text{old}}(s_t)) \\big) \\\\[6pt] \\text{objective for update: } J_\\pi(\\theta) \u0026= \\nabla_\\theta D_\\text{KL} \\big( \\pi_\\theta(. \\vert s_t) \\| \\exp(Q_w(s_t, .) - \\log Z_w(s_t)) \\big) \\\\[6pt] \u0026= \\mathbb{E}_{a_t\\sim\\pi} \\Big[ - \\log \\big( \\frac{\\exp(Q_w(s_t, a_t) - \\log Z_w(s_t))}{\\pi_\\theta(a_t \\vert s_t)} \\big) \\Big] \\\\[6pt] \u0026= \\mathbb{E}_{a_t\\sim\\pi} [ \\log \\pi_\\theta(a_t \\vert s_t) - Q_w(s_t, a_t) + \\log Z_w(s_t) ] \\end{aligned} $$ where $\\Pi$ is the set of potential policies that we can model our policy as to keep them tractable; for example, $\\Pi$ can be the family of Gaussian mixture distributions, expensive to model but highly expressive and still tractable. $Z^{\\pi_\\text{old}}(s_t)$ is the partition function to normalize the distribution. It is usually intractable but does not contribute to the gradient. How to minimize $J_\\pi(\\theta)$ depends our choice of $\\Pi$.\nThis update guarantees that $Q^{\\pi_\\text{new}}(s_t, a_t) \\geq Q^{\\pi_\\text{old}}(s_t, a_t)$, please check the proof on this lemma in the Appendix B.2 in the original paper.\nOnce we have defined the objective functions and gradients for soft action-state value, soft state value and the policy network, the soft actor-critic algorithm is straightforward:\nFig. 8. The soft actor-critic algorithm. (Image source: original paper) SAC with Automatically Adjusted Temperature [paper|code]\nSAC is brittle with respect to the temperature parameter. Unfortunately it is difficult to adjust temperature, because the entropy can vary unpredictably both across tasks and during training as the policy becomes better. An improvement on SAC formulates a constrained optimization problem: while maximizing the expected return, the policy should satisfy a minimum entropy constraint:\n $$ \\max_{\\pi_0, \\dots, \\pi_T} \\mathbb{E} \\Big[ \\sum_{t=0}^T r(s_t, a_t)\\Big] \\text{s.t. } \\forall t\\text{, } \\mathcal{H}(\\pi_t) \\geq \\mathcal{H}_0 $$ where $\\mathcal{H}_0$ is a predefined minimum policy entropy threshold.\nThe expected return $\\mathbb{E} \\Big[ \\sum_{t=0}^T r(s_t, a_t)\\Big]$ can be decomposed into a sum of rewards at all the time steps. Because the policy $\\pi_t$ at time t has no effect on the policy at the earlier time step, $\\pi_{t-1}$, we can maximize the return at different steps backward in time \u0026mdash; this is essentially DP.\n $$ \\underbrace{\\max_{\\pi_0} \\Big( \\mathbb{E}[r(s_0, a_0)]+ \\underbrace{\\max_{\\pi_1} \\Big(\\mathbb{E}[...] + \\underbrace{\\max_{\\pi_T} \\mathbb{E}[r(s_T, a_T)]}_\\text{1st maximization} \\Big)}_\\text{second but last maximization} \\Big)}_\\text{last maximization} $$ where we consider $\\gamma=1$.\nSo we start the optimization from the last timestep $T$:\n $$ \\text{maximize } \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ] \\text{ s.t. } \\mathcal{H}(\\pi_T) - \\mathcal{H}_0 \\geq 0 $$ First, let us define the following functions:\n $$ \\begin{aligned} h(\\pi_T) \u0026= \\mathcal{H}(\\pi_T) - \\mathcal{H}_0 = \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [-\\log \\pi_T(a_T\\vert s_T)] - \\mathcal{H}_0\\\\ f(\\pi_T) \u0026= \\begin{cases} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ], \u0026 \\text{if }h(\\pi_T) \\geq 0 \\\\ -\\infty, \u0026 \\text{otherwise} \\end{cases} \\end{aligned} $$ And the optimization becomes:\n $$ \\text{maximize } f(\\pi_T) \\text{ s.t. } h(\\pi_T) \\geq 0 $$ To solve the maximization optimization with inequality constraint, we can construct a Lagrangian expression with a Lagrange multiplier (also known as \u0026ldquo;dual variable\u0026rdquo;), $\\alpha_T$:\n $$ L(\\pi_T, \\alpha_T) = f(\\pi_T) + \\alpha_T h(\\pi_T) $$ Considering the case when we try to minimize $L(\\pi_T, \\alpha_T)$ with respect to $\\alpha_T$ - given a particular value $\\pi_T$,\n If the constraint is satisfied, $h(\\pi_T) \\geq 0$, at best we can set $\\alpha_T=0$ since we have no control over the value of $f(\\pi_T)$. Thus, $L(\\pi_T, 0) = f(\\pi_T)$. If the constraint is invalidated, $h(\\pi_T) \u0026lt; 0$, we can achieve $L(\\pi_T, \\alpha_T) \\to -\\infty$ by taking $\\alpha_T \\to \\infty$. Thus, $L(\\pi_T, \\infty) = -\\infty = f(\\pi_T)$. In either case, we can recover the following equation,\n $$ f(\\pi_T) = \\min_{\\alpha_T \\geq 0} L(\\pi_T, \\alpha_T) $$ At the same time, we want to maximize $f(\\pi_T)$,\n $$ \\max_{\\pi_T} f(\\pi_T) = \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} L(\\pi_T, \\alpha_T) $$ Therefore, to maximize $f(\\pi_T)$, the dual problem is listed as below. Note that to make sure $\\max_{\\pi_T} f(\\pi_T)$ is properly maximized and would not become $-\\infty$, the constraint has to be satisfied.\n $$ \\begin{aligned} \\max_{\\pi_T} \\mathbb{E}[ r(s_T, a_T) ] \u0026= \\max_{\\pi_T} f(\\pi_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} L(\\pi_T, \\alpha_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} f(\\pi_T) + \\alpha_T h(\\pi_T) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) ] + \\alpha_T ( \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [-\\log \\pi_T(a_T\\vert s_T)] - \\mathcal{H}_0) \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) - \\alpha_T \\log \\pi_T(a_T\\vert s_T)] - \\alpha_T \\mathcal{H}_0 \\\\ \u0026= \\min_{\\alpha_T \\geq 0} \\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) + \\alpha_T \\mathcal{H}(\\pi_T) - \\alpha_T \\mathcal{H}_0 ] \\end{aligned} $$ We could compute the optimal $\\pi_T$ and $\\alpha_T$ iteratively. First given the current $\\alpha_T$, get the best policy $\\pi_T^{*}$ that maximizes $L(\\pi_T^{*}, \\alpha_T)$. Then plug in $\\pi_T^{*}$ and compute $\\alpha_T^{*}$ that minimizes $L(\\pi_T^{*}, \\alpha_T)$. Assuming we have one neural network for policy and one network for temperature parameter, the iterative update process is more aligned with how we update network parameters during training.\n $$ \\begin{aligned} \\pi^{*}_T \u0026= \\arg\\max_{\\pi_T} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi}} [ r(s_T, a_T) + \\alpha_T \\mathcal{H}(\\pi_T) - \\alpha_T \\mathcal{H}_0 ] \\\\ \\color{blue}{\\alpha^{*}_T} \u0026\\color{blue}{=} \\color{blue}{\\arg\\min_{\\alpha_T \\geq 0} \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi^{*}}} [\\alpha_T \\mathcal{H}(\\pi^{*}_T) - \\alpha_T \\mathcal{H}_0 ]} \\end{aligned} $$ $$ \\text{Thus, }\\max_{\\pi_T} \\mathbb{E} [ r(s_T, a_T) ] = \\mathbb{E}_{(s_T, a_T) \\sim \\rho_{\\pi^{*}}} [ r(s_T, a_T) + \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) - \\alpha^{*}_T \\mathcal{H}_0 ] $$ Now let\u0026rsquo;s go back to the soft Q value function:\n $$ \\begin{aligned} Q_{T-1}(s_{T-1}, a_{T-1}) \u0026= r(s_{T-1}, a_{T-1}) + \\mathbb{E} [Q(s_T, a_T) - \\alpha_T \\log \\pi(a_T \\vert s_T)] \\\\ \u0026= r(s_{T-1}, a_{T-1}) + \\mathbb{E} [r(s_T, a_T)] + \\alpha_T \\mathcal{H}(\\pi_T) \\\\ Q_{T-1}^{*}(s_{T-1}, a_{T-1}) \u0026= r(s_{T-1}, a_{T-1}) + \\max_{\\pi_T} \\mathbb{E} [r(s_T, a_T)] + \\alpha_T \\mathcal{H}(\\pi^{*}_T) \u0026 \\text{; plug in the optimal }\\pi_T^{*} \\end{aligned} $$ Therefore the expected return is as follows, when we take one step further back to the time step $T-1$:\n $$ \\begin{aligned} \u0026\\max_{\\pi_{T-1}}\\Big(\\mathbb{E}[r(s_{T-1}, a_{T-1})] + \\max_{\\pi_T} \\mathbb{E}[r(s_T, a_T] \\Big) \\\\ \u0026= \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) \\Big) \u0026 \\text{; should s.t. } \\mathcal{H}(\\pi_{T-1}) - \\mathcal{H}_0 \\geq 0 \\\\ \u0026= \\min_{\\alpha_{T-1} \\geq 0} \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) + \\alpha_{T-1} \\big( \\mathcal{H}(\\pi_{T-1}) - \\mathcal{H}_0 \\big) \\Big) \u0026 \\text{; dual problem w/ Lagrangian.} \\\\ \u0026= \\min_{\\alpha_{T-1} \\geq 0} \\max_{\\pi_{T-1}} \\Big( Q^{*}_{T-1}(s_{T-1}, a_{T-1}) + \\alpha_{T-1} \\mathcal{H}(\\pi_{T-1}) - \\alpha_{T-1}\\mathcal{H}_0 \\Big) - \\alpha^{*}_T \\mathcal{H}(\\pi^{*}_T) \\end{aligned} $$ Similar to the previous step,\n $$ \\begin{aligned} \\pi^{*}_{T-1} \u0026= \\arg\\max_{\\pi_{T-1}} \\mathbb{E}_{(s_{T-1}, a_{T-1}) \\sim \\rho_\\pi} [Q^{*}_{T-1}(s_{T-1}, a_{T-1}) + \\alpha_{T-1} \\mathcal{H}(\\pi_{T-1}) - \\alpha_{T-1} \\mathcal{H}_0 ] \\\\ \\color{green}{\\alpha^{*}_{T-1}} \u0026\\color{green}{=} \\color{green}{\\arg\\min_{\\alpha_{T-1} \\geq 0} \\mathbb{E}_{(s_{T-1}, a_{T-1}) \\sim \\rho_{\\pi^{*}}} [ \\alpha_{T-1} \\mathcal{H}(\\pi^{*}_{T-1}) - \\alpha_{T-1}\\mathcal{H}_0 ]} \\end{aligned} $$ The equation for updating $\\alpha_{T-1}$ in green has the same format as the equation for updating $\\alpha_{T-1}$ in blue above. By repeating this process, we can learn the optimal temperature parameter in every step by minimizing the same objective function:\n $$ J(\\alpha) = \\mathbb{E}_{a_t \\sim \\pi_t} [-\\alpha \\log \\pi_t(a_t \\mid s_t) - \\alpha \\mathcal{H}_0] $$ The final algorithm is same as SAC except for learning $\\alpha$ explicitly with respect to the objective $J(\\alpha)$ (see Fig. 7):\nFig. 9. The soft actor-critic algorithm with automatically adjusted temperature. (Image source: original paper) TD3 [paper|code]\nThe Q-learning algorithm is commonly known to suffer from the overestimation of the value function. This overestimation can propagate through the training iterations and negatively affect the policy. This property directly motivated Double Q-learning and Double DQN: the action selection and Q-value update are decoupled by using two value networks.\nTwin Delayed Deep Deterministic (short for TD3; Fujimoto et al., 2018) applied a couple of tricks on DDPG to prevent the overestimation of the value function:\n(1) Clipped Double Q-learning: In Double Q-Learning, the action selection and Q-value estimation are made by two networks separately. In the DDPG setting, given two deterministic actors $(\\mu_{\\theta_1}, \\mu_{\\theta_2})$ with two corresponding critics $(Q_{w_1}, Q_{w_2})$, the Double Q-learning Bellman targets look like:\n $$ \\begin{aligned} y_1 \u0026= r + \\gamma Q_{w_2}(s', \\mu_{\\theta_1}(s'))\\\\ y_2 \u0026= r + \\gamma Q_{w_1}(s', \\mu_{\\theta_2}(s')) \\end{aligned} $$ However, due to the slow changing policy, these two networks could be too similar to make independent decisions. The Clipped Double Q-learning instead uses the minimum estimation among two so as to favor underestimation bias which is hard to propagate through training:\n $$ \\begin{aligned} y_1 \u0026= r + \\gamma \\min_{i=1,2}Q_{w_i}(s', \\mu_{\\theta_1}(s'))\\\\ y_2 \u0026= r + \\gamma \\min_{i=1,2} Q_{w_i}(s', \\mu_{\\theta_2}(s')) \\end{aligned} $$ (2) Delayed update of Target and Policy Networks: In the actor-critic model, policy and value updates are deeply coupled: Value estimates diverge through overestimation when the policy is poor, and the policy will become poor if the value estimate itself is inaccurate.\nTo reduce the variance, TD3 updates the policy at a lower frequency than the Q-function. The policy network stays the same until the value error is small enough after several updates. The idea is similar to how the periodically-updated target network stay as a stable objective in DQN.\n(3) Target Policy Smoothing: Given a concern with deterministic policies that they can overfit to narrow peaks in the value function, TD3 introduced a smoothing regularization strategy on the value function: adding a small amount of clipped random noises to the selected action and averaging over mini-batches.\n $$ \\begin{aligned} y \u0026= r + \\gamma Q_w (s', \\mu_{\\theta}(s') + \\epsilon) \u0026 \\\\ \\epsilon \u0026\\sim \\text{clip}(\\mathcal{N}(0, \\sigma), -c, +c) \u0026 \\scriptstyle{\\text{ ; clipped random noises.}} \\end{aligned} $$ This approach mimics the idea of SARSA update and enforces that similar actions should have similar values.\nHere is the final algorithm:\nFig. 10. TD3 Algorithm. (Image source: Fujimoto et al., 2018) SVPG [paper|code for SVPG]\nStein Variational Policy Gradient (SVPG; Liu et al, 2017) applies the Stein variational gradient descent (SVGD; Liu and Wang, 2016) algorithm to update the policy parameter $\\theta$.\nIn the setup of maximum entropy policy optimization, $\\theta$ is considered as a random variable $\\theta \\sim q(\\theta)$ and the model is expected to learn this distribution $q(\\theta)$. Assuming we know a prior on how $q$ might look like, $q_0$, and we would like to guide the learning process to not make $\\theta$ too far away from $q_0$ by optimizing the following objective function:\n $$ \\hat{J}(\\theta) = \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) $$ where $\\mathbb{E}_{\\theta \\sim q} [R(\\theta)]$ is the expected reward when $\\theta \\sim q(\\theta)$ and $D_\\text{KL}$ is the KL divergence.\nIf we don\u0026rsquo;t have any prior information, we might set $q_0$ as a uniform distribution and set $q_0(\\theta)$ to a constant. Then the above objective function becomes SAC, where the entropy term encourages exploration:\n $$ \\begin{aligned} \\hat{J}(\\theta) \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) \\\\ \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha \\mathbb{E}_{\\theta \\sim q} [\\log q(\\theta) - \\log q_0(\\theta)] \\\\ \u0026= \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] + \\alpha H(q(\\theta)) \\end{aligned} $$ Let\u0026rsquo;s take the derivative of $\\hat{J}(\\theta) = \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q|q_0)$ w.r.t. $q$:\n $$ \\begin{aligned} \\nabla_q \\hat{J}(\\theta) \u0026= \\nabla_q \\big( \\mathbb{E}_{\\theta \\sim q} [J(\\theta)] - \\alpha D_\\text{KL}(q\\|q_0) \\big) \\\\ \u0026= \\nabla_q \\int_\\theta \\big( q(\\theta) J(\\theta) - \\alpha q(\\theta)\\log q(\\theta) + \\alpha q(\\theta) \\log q_0(\\theta) \\big) \\\\ \u0026= \\int_\\theta \\big( J(\\theta) - \\alpha \\log q(\\theta) -\\alpha + \\alpha \\log q_0(\\theta) \\big) \\\\ \u0026= 0 \\end{aligned} $$ The optimal distribution is:\n $$ \\log q^{*}(\\theta) = \\frac{1}{\\alpha} J(\\theta) + \\log q_0(\\theta) - 1 \\text{ thus } \\underbrace{ q^{*}(\\theta) }_\\textrm{\"posterior\"} \\propto \\underbrace{\\exp ( J(\\theta) / \\alpha )}_\\textrm{\"likelihood\"} \\underbrace{q_0(\\theta)}_\\textrm{prior} $$ The temperature $\\alpha$ decides a tradeoff between exploitation and exploration. When $\\alpha \\rightarrow 0$, $\\theta$ is updated only according to the expected return $J(\\theta)$. When $\\alpha \\rightarrow \\infty$, $\\theta$ always follows the prior belief.\nWhen using the SVGD method to estimate the target posterior distribution $q(\\theta)$, it relies on a set of particle $\\{\\theta_i\\}_{i=1}^n$ (independently trained policy agents) and each is updated:\n $$ \\theta_i \\gets \\theta_i + \\epsilon \\phi^{*}(\\theta_i) \\text{ where } \\phi^{*} = \\max_{\\phi \\in \\mathcal{H}} \\{ - \\nabla_\\epsilon D_\\text{KL} (q'_{[\\theta + \\epsilon \\phi(\\theta)]} \\| q) \\text{ s.t. } \\|\\phi\\|_{\\mathcal{H}} \\leq 1\\} $$ where $\\epsilon$ is a learning rate and $\\phi^{*}$ is the unit ball of a RKHS (reproducing kernel Hilbert space) $\\mathcal{H}$ of $\\theta$-shaped value vectors that maximally decreases the KL divergence between the particles and the target distribution. $q'(.)$ is the distribution of $\\theta + \\epsilon \\phi(\\theta)$.\nComparing different gradient-based update methods:\n Method Update space Plain gradient $\\Delta \\theta$ on the parameter space Natural gradient $\\Delta \\theta$ on the search distribution space SVGD $\\Delta \\theta$ on the kernel function space (edited) One estimation of $\\phi^{*}$ has the following form. A positive definite kernel $k(\\vartheta, \\theta)$, i.e. a Gaussian radial basis function, measures the similarity between particles.\n $$ \\begin{aligned} \\phi^{*}(\\theta_i) \u0026= \\mathbb{E}_{\\vartheta \\sim q'} [\\nabla_\\vartheta \\log q(\\vartheta) k(\\vartheta, \\theta_i) + \\nabla_\\vartheta k(\\vartheta, \\theta_i)]\\\\ \u0026= \\frac{1}{n} \\sum_{j=1}^n [\\color{red}{\\nabla_{\\theta_j} \\log q(\\theta_j) k(\\theta_j, \\theta_i)} + \\color{green}{\\nabla_{\\theta_j} k(\\theta_j, \\theta_i)}] \u0026 \\scriptstyle{\\text{;approximate }q'\\text{ with current particle values}} \\end{aligned} $$ The first term in red encourages $\\theta_i$ learning towards the high probability regions of $q$ that is shared across similar particles. =\u0026gt; to be similar to other particles The second term in green pushes particles away from each other and therefore diversifies the policy. =\u0026gt; to be dissimilar to other particles Usually the temperature $\\alpha$ follows an annealing scheme so that the training process does more exploration at the beginning but more exploitation at a later stage.\nIMPALA [paper|code]\nIn order to scale up RL training to achieve a very high throughput, IMPALA (\u0026ldquo;Importance Weighted Actor-Learner Architecture\u0026rdquo;) framework decouples acting from learning on top of basic actor-critic setup and learns from all experience trajectories with V-trace off-policy correction.\nMultiple actors generate experience in parallel, while the learner optimizes both policy and value function parameters using all the generated experience. Actors update their parameters with the latest policy from the learner periodically. Because acting and learning are decoupled, we can add many more actor machines to generate a lot more trajectories per time unit. As the training policy and the behavior policy are not totally synchronized, there is a gap between them and thus we need off-policy corrections.\nLet the value function $V_\\theta$ parameterized by $\\theta$ and the policy $\\pi_\\phi$ parameterized by $\\phi$. Also we know the trajectories in the replay buffer are collected by a slightly older policy $\\mu$.\nAt the training time $t$, given $(s_t, a_t, s_{t+1}, r_t)$, the value function parameter $\\theta$ is learned through an L2 loss between the current value and a V-trace value target. The $n$-step V-trace target is defined as:\n $$ \\begin{aligned} v_t \u0026= V_\\theta(s_t) + \\sum_{i=t}^{t+n-1} \\gamma^{i-t} \\big(\\prod_{j=t}^{i-1} c_j\\big) \\color{red}{\\delta_i V} \\\\ \u0026= V_\\theta(s_t) + \\sum_{i=t}^{t+n-1} \\gamma^{i-t} \\big(\\prod_{j=t}^{i-1} c_j\\big) \\color{red}{\\rho_i (r_i + \\gamma V_\\theta(s_{i+1}) - V_\\theta(s_i))} \\end{aligned} $$ where the red part $\\delta_i V$ is a temporal difference for $V$. $\\rho_i = \\min\\big(\\bar{\\rho}, \\frac{\\pi(a_i \\vert s_i)}{\\mu(a_i \\vert s_i)}\\big)$ and $c_j = \\min\\big(\\bar{c}, \\frac{\\pi(a_j \\vert s_j)}{\\mu(a_j \\vert s_j)}\\big)$ are truncated importance sampling (IS) weights. The product of $c_t, \\dots, c_{i-1}$ measures how much a temporal difference $\\delta_i V$ observed at time $i$ impacts the update of the value function at a previous time $t$. In the on-policy case, we have $\\rho_i=1$ and $c_j=1$ (assuming $\\bar{c} \\geq 1$) and therefore the V-trace target becomes on-policy $n$-step Bellman target.\n$\\bar{\\rho}$ and $\\bar{c}$ are two truncation constants with $\\bar{\\rho} \\geq \\bar{c}$. $\\bar{\\rho}$ impacts the fixed-point of the value function we converge to and $\\bar{c}$ impacts the speed of convergence. When $\\bar{\\rho} =\\infty$ (untruncated), we converge to the value function of the target policy $V^\\pi$; when $\\bar{\\rho}$ is close to 0, we evaluate the value function of the behavior policy $V^\\mu$; when in-between, we evaluate a policy between $\\pi$ and $\\mu$.\nThe value function parameter is therefore updated in the direction of:\n $$ \\Delta\\theta = (v_t - V_\\theta(s_t))\\nabla_\\theta V_\\theta(s_t) $$ The policy parameter $\\phi$ is updated through policy gradient,\n $$ \\begin{aligned} \\Delta \\phi \u0026= \\rho_t \\nabla_\\phi \\log \\pi_\\phi(a_t \\vert s_t) \\big(r_t + \\gamma v_{t+1} - V_\\theta(s_t)\\big) + \\nabla_\\phi H(\\pi_\\phi)\\\\ \u0026= \\rho_t \\nabla_\\phi \\log \\pi_\\phi(a_t \\vert s_t) \\big(r_t + \\gamma v_{t+1} - V_\\theta(s_t)\\big) - \\nabla_\\phi \\sum_a \\pi_\\phi(a\\vert s_t)\\log \\pi_\\phi(a\\vert s_t) \\end{aligned} $$ where $r_t + \\gamma v_{t+1}$ is the estimated Q value, from which a state-dependent baseline $V_\\theta(s_t)$ is subtracted. $H(\\pi_\\phi)$ is an entropy bonus to encourage exploration.\nIn the experiments, IMPALA is used to train one agent over multiple tasks. Two different model architectures are involved, a shallow model (left) and a deep residual model (right).\nQuick Summary After reading through all the algorithms above, I list a few building blocks or principles that seem to be common among them:\n Try to reduce the variance and keep the bias unchanged to stabilize learning. Off-policy gives us better exploration and helps us use data samples more efficiently. Experience replay (training data sampled from a replay memory buffer); Target network that is either frozen periodically or updated slower than the actively learned policy network; Batch normalization; Entropy-regularized reward; The critic and actor can share lower layer parameters of the network and two output heads for policy and value functions. It is possible to learn with deterministic policy rather than stochastic one. Put constraint on the divergence between policy updates. New optimization methods (such as K-FAC). Entropy maximization of the policy helps encourage exploration. Try not to overestimate the value function. Think twice whether the policy and value network should share parameters. TBA more. Cited as:\n@article{weng2018PG, title = \u0026quot;Policy Gradient Algorithms\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-04-08-policy-gradient/\u0026quot; } References [1] jeremykun.com Markov Chain Monte Carlo Without all the Bullshit\n[2] Richard S. Sutton and Andrew G. Barto. Reinforcement Learning: An Introduction; 2nd Edition. 2017.\n[3] John Schulman, et al. \u0026ldquo;High-dimensional continuous control using generalized advantage estimation.\u0026quot; ICLR 2016.\n[4] Thomas Degris, Martha White, and Richard S. Sutton. \u0026ldquo;Off-policy actor-critic.\u0026quot; ICML 2012.\n[5] timvieira.github.io Importance sampling\n[6] Mnih, Volodymyr, et al. \u0026ldquo;Asynchronous methods for deep reinforcement learning.\u0026quot; ICML. 2016.\n[7] David Silver, et al. \u0026ldquo;Deterministic policy gradient algorithms.\u0026quot; ICML. 2014.\n[8] Timothy P. Lillicrap, et al. \u0026ldquo;Continuous control with deep reinforcement learning.\u0026quot; arXiv preprint arXiv:1509.02971 (2015).\n[9] Ryan Lowe, et al. \u0026ldquo;Multi-agent actor-critic for mixed cooperative-competitive environments.\u0026quot; NIPS. 2017.\n[10] John Schulman, et al. \u0026ldquo;Trust region policy optimization.\u0026quot; ICML. 2015.\n[11] Ziyu Wang, et al. \u0026ldquo;Sample efficient actor-critic with experience replay.\u0026quot; ICLR 2017.\n[12] Rémi Munos, Tom Stepleton, Anna Harutyunyan, and Marc Bellemare. \u0026ldquo;Safe and efficient off-policy reinforcement learning\u0026rdquo; NIPS. 2016.\n[13] Yuhuai Wu, et al. \u0026ldquo;Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation.\u0026quot; NIPS. 2017.\n[14] kvfrans.com A intuitive explanation of natural gradient descent\n[15] Sham Kakade. \u0026ldquo;A Natural Policy Gradient.\u0026quot;. NIPS. 2002.\n[16] \u0026ldquo;Going Deeper Into Reinforcement Learning: Fundamentals of Policy Gradients.\u0026quot; - Seita\u0026rsquo;s Place, Mar 2017.\n[17] \u0026ldquo;Notes on the Generalized Advantage Estimation Paper.\u0026quot; - Seita\u0026rsquo;s Place, Apr, 2017.\n[18] Gabriel Barth-Maron, et al. \u0026ldquo;Distributed Distributional Deterministic Policy Gradients.\u0026quot; ICLR 2018 poster.\n[19] Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. \u0026ldquo;Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.\u0026quot; arXiv preprint arXiv:1801.01290 (2018).\n[20] Scott Fujimoto, Herke van Hoof, and Dave Meger. \u0026ldquo;Addressing Function Approximation Error in Actor-Critic Methods.\u0026quot; arXiv preprint arXiv:1802.09477 (2018).\n[21] Tuomas Haarnoja, et al. \u0026ldquo;Soft Actor-Critic Algorithms and Applications.\u0026quot; arXiv preprint arXiv:1812.05905 (2018).\n[22] David Knowles. \u0026ldquo;Lagrangian Duality for Dummies\u0026rdquo; Nov 13, 2010.\n[23] Yang Liu, et al. \u0026ldquo;Stein variational policy gradient.\u0026quot; arXiv preprint arXiv:1704.02399 (2017).\n[24] Qiang Liu and Dilin Wang. \u0026ldquo;Stein variational gradient descent: A general purpose bayesian inference algorithm.\u0026quot; NIPS. 2016.\n[25] Lasse Espeholt, et al. \u0026ldquo;IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures\u0026rdquo; arXiv preprint 1802.01561 (2018).\n[26] Karl Cobbe, et al. \u0026ldquo;Phasic Policy Gradient.\u0026quot; arXiv preprint arXiv:2009.04416 (2020).\n[27] Chloe Ching-Yun Hsu, et al. \u0026ldquo;Revisiting Design Choices in Proximal Policy Optimization.\u0026quot; arXiv preprint arXiv:2009.10897 (2020).\n","permalink":"https://lilianweng.github.io/posts/2018-04-08-policy-gradient/","summary":"[Updated on 2018-06-30: add two new policy gradient methods, SAC and D4PG.] [Updated on 2018-09-30: add a new policy gradient method, TD3.] [Updated on 2019-02-09: add SAC with automatically adjusted temperature]. [Updated on 2019-06-26: Thanks to Chanseok, we have a version of this post in Korean]. [Updated on 2019-09-12: add a new policy gradient method SVPG.] [Updated on 2019-12-22: add a new policy gradient method IMPALA.","title":"Policy Gradient Algorithms"},{"content":"[Updated on 2020-09-03: Updated the algorithm of SARSA and Q-learning so that the difference is more pronounced. [Updated on 2021-09-19: Thanks to 爱吃猫的鱼, we have this post in Chinese].\nA couple of exciting news in Artificial Intelligence (AI) has just happened in recent years. AlphaGo defeated the best professional human player in the game of Go. Very soon the extended algorithm AlphaGo Zero beat AlphaGo by 100-0 without supervised learning on human knowledge. Top professional game players lost to the bot developed by OpenAI on DOTA2 1v1 competition. After knowing these, it is pretty hard not to be curious about the magic behind these algorithms \u0026mdash; Reinforcement Learning (RL). I\u0026rsquo;m writing this post to briefly go over the field. We will first introduce several fundamental concepts and then dive into classic approaches to solving RL problems. Hopefully, this post could be a good starting point for newbies, bridging the future study on the cutting-edge research.\nWhat is Reinforcement Learning? Say, we have an agent in an unknown environment and this agent can obtain some rewards by interacting with the environment. The agent ought to take actions so as to maximize cumulative rewards. In reality, the scenario could be a bot playing a game to achieve high scores, or a robot trying to complete physical tasks with physical items; and not just limited to these.\nFig. 1. An agent interacts with the environment, trying to take smart actions to maximize cumulative rewards. The goal of Reinforcement Learning (RL) is to learn a good strategy for the agent from experimental trials and relative simple feedback received. With the optimal strategy, the agent is capable to actively adapt to the environment to maximize future rewards.\nKey Concepts Now Let\u0026rsquo;s formally define a set of key concepts in RL.\nThe agent is acting in an environment. How the environment reacts to certain actions is defined by a model which we may or may not know. The agent can stay in one of many states ($s \\in \\mathcal{S}$) of the environment, and choose to take one of many actions ($a \\in \\mathcal{A}$) to switch from one state to another. Which state the agent will arrive in is decided by transition probabilities between states ($P$). Once an action is taken, the environment delivers a reward ($r \\in \\mathcal{R}$) as feedback.\nThe model defines the reward function and transition probabilities. We may or may not know how the model works and this differentiate two circumstances:\n Know the model: planning with perfect information; do model-based RL. When we fully know the environment, we can find the optimal solution by Dynamic Programming (DP). Do you still remember \u0026ldquo;longest increasing subsequence\u0026rdquo; or \u0026ldquo;traveling salesmen problem\u0026rdquo; from your Algorithms 101 class? LOL. This is not the focus of this post though. Does not know the model: learning with incomplete information; do model-free RL or try to learn the model explicitly as part of the algorithm. Most of the following content serves the scenarios when the model is unknown. The agent\u0026rsquo;s policy $\\pi(s)$ provides the guideline on what is the optimal action to take in a certain state with the goal to maximize the total rewards. Each state is associated with a value function $V(s)$ predicting the expected amount of future rewards we are able to receive in this state by acting the corresponding policy. In other words, the value function quantifies how good a state is. Both policy and value functions are what we try to learn in reinforcement learning.\nFig. 2. Summary of approaches in RL based on whether we want to model the value, policy, or the environment. (Image source: reproduced from David Silver's RL course lecture 1.) The interaction between the agent and the environment involves a sequence of actions and observed rewards in time, $t=1, 2, \\dots, T$. During the process, the agent accumulates the knowledge about the environment, learns the optimal policy, and makes decisions on which action to take next so as to efficiently learn the best policy. Let\u0026rsquo;s label the state, action, and reward at time step t as $S_t$, $A_t$, and $R_t$, respectively. Thus the interaction sequence is fully described by one episode (also known as \u0026ldquo;trial\u0026rdquo; or \u0026ldquo;trajectory\u0026rdquo;) and the sequence ends at the terminal state $S_T$:\n $$ S_1, A_1, R_2, S_2, A_2, \\dots, S_T $$ Terms you will encounter a lot when diving into different categories of RL algorithms:\n Model-based: Rely on the model of the environment; either the model is known or the algorithm learns it explicitly. Model-free: No dependency on the model during learning. On-policy: Use the deterministic outcomes or samples from the target policy to train the algorithm. Off-policy: Training on a distribution of transitions or episodes produced by a different behavior policy rather than that produced by the target policy. Model: Transition and Reward The model is a descriptor of the environment. With the model, we can learn or infer how the environment would interact with and provide feedback to the agent. The model has two major parts, transition probability function $P$ and reward function $R$.\nLet\u0026rsquo;s say when we are in state s, we decide to take action a to arrive in the next state s' and obtain reward r. This is known as one transition step, represented by a tuple (s, a, s', r).\nThe transition function P records the probability of transitioning from state s to s' after taking action a while obtaining reward r. We use $\\mathbb{P}$ as a symbol of \u0026ldquo;probability\u0026rdquo;.\n $$ P(s', r \\vert s, a) = \\mathbb{P} [S_{t+1} = s', R_{t+1} = r \\vert S_t = s, A_t = a] $$ Thus the state-transition function can be defined as a function of $P(s', r \\vert s, a)$:\n $$ P_{ss'}^a = P(s' \\vert s, a) = \\mathbb{P} [S_{t+1} = s' \\vert S_t = s, A_t = a] = \\sum_{r \\in \\mathcal{R}} P(s', r \\vert s, a) $$ The reward function R predicts the next reward triggered by one action:\n $$ R(s, a) = \\mathbb{E} [R_{t+1} \\vert S_t = s, A_t = a] = \\sum_{r\\in\\mathcal{R}} r \\sum_{s' \\in \\mathcal{S}} P(s', r \\vert s, a) $$ Policy Policy, as the agent\u0026rsquo;s behavior function $\\pi$, tells us which action to take in state s. It is a mapping from state s to action a and can be either deterministic or stochastic:\n Deterministic: $\\pi(s) = a$. Stochastic: $\\pi(a \\vert s) = \\mathbb{P}_\\pi [A=a \\vert S=s]$. Value Function Value function measures the goodness of a state or how rewarding a state or an action is by a prediction of future reward. The future reward, also known as return, is a total sum of discounted rewards going forward. Let\u0026rsquo;s compute the return $G_t$ starting from time t:\n $$ G_t = R_{t+1} + \\gamma R_{t+2} + \\dots = \\sum_{k=0}^{\\infty} \\gamma^k R_{t+k+1} $$ The discounting factor $\\gamma \\in [0, 1]$ penalize the rewards in the future, because:\n The future rewards may have higher uncertainty; i.e. stock market. The future rewards do not provide immediate benefits; i.e. As human beings, we might prefer to have fun today rather than 5 years later ;). Discounting provides mathematical convenience; i.e., we don\u0026rsquo;t need to track future steps forever to compute return. We don\u0026rsquo;t need to worry about the infinite loops in the state transition graph. The state-value of a state s is the expected return if we are in this state at time t, $S_t = s$:\n $$ V_{\\pi}(s) = \\mathbb{E}_{\\pi}[G_t \\vert S_t = s] $$ Similarly, we define the action-value (\u0026ldquo;Q-value\u0026rdquo;; Q as \u0026ldquo;Quality\u0026rdquo; I believe?) of a state-action pair as:\n $$ Q_{\\pi}(s, a) = \\mathbb{E}_{\\pi}[G_t \\vert S_t = s, A_t = a] $$ Additionally, since we follow the target policy $\\pi$, we can make use of the probility distribution over possible actions and the Q-values to recover the state-value:\n $$ V_{\\pi}(s) = \\sum_{a \\in \\mathcal{A}} Q_{\\pi}(s, a) \\pi(a \\vert s) $$ The difference between action-value and state-value is the action advantage function (\u0026ldquo;A-value\u0026rdquo;):\n $$ A_{\\pi}(s, a) = Q_{\\pi}(s, a) - V_{\\pi}(s) $$ Optimal Value and Policy The optimal value function produces the maximum return:\n $$ V_{*}(s) = \\max_{\\pi} V_{\\pi}(s), Q_{*}(s, a) = \\max_{\\pi} Q_{\\pi}(s, a) $$ The optimal policy achieves optimal value functions:\n $$ \\pi_{*} = \\arg\\max_{\\pi} V_{\\pi}(s), \\pi_{*} = \\arg\\max_{\\pi} Q_{\\pi}(s, a) $$ And of course, we have $V_{\\pi_{*}}(s)=V_{*}(s)$ and $Q_{\\pi_{*}}(s, a) = Q_{*}(s, a)$.\nMarkov Decision Processes In more formal terms, almost all the RL problems can be framed as Markov Decision Processes (MDPs). All states in MDP has \u0026ldquo;Markov\u0026rdquo; property, referring to the fact that the future only depends on the current state, not the history:\n $$ \\mathbb{P}[ S_{t+1} \\vert S_t ] = \\mathbb{P} [S_{t+1} \\vert S_1, \\dots, S_t] $$ Or in other words, the future and the past are conditionally independent given the present, as the current state encapsulates all the statistics we need to decide the future.\nFig. 3. The agent-environment interaction in a Markov decision process. (Image source: Sec. 3.1 Sutton \u0026 Barto (2017).) A Markov deicison process consists of five elements $\\mathcal{M} = \\langle \\mathcal{S}, \\mathcal{A}, P, R, \\gamma \\rangle$, where the symbols carry the same meanings as key concepts in the previous section, well aligned with RL problem settings:\n $\\mathcal{S}$ - a set of states; $\\mathcal{A}$ - a set of actions; $P$ - transition probability function; $R$ - reward function; $\\gamma$ - discounting factor for future rewards. In an unknown environment, we do not have perfect knowledge about $P$ and $R$. Fig. 4. A fun example of Markov decision process: a typical work day. (Image source: randomant.net/reinforcement-learning-concepts) Bellman Equations Bellman equations refer to a set of equations that decompose the value function into the immediate reward plus the discounted future values.\n $$ \\begin{aligned} V(s) \u0026= \\mathbb{E}[G_t \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma R_{t+2} + \\gamma^2 R_{t+3} + \\dots \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma (R_{t+2} + \\gamma R_{t+3} + \\dots) \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma G_{t+1} \\vert S_t = s] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma V(S_{t+1}) \\vert S_t = s] \\end{aligned} $$ Similarly for Q-value,\n $$ \\begin{aligned} Q(s, a) \u0026= \\mathbb{E} [R_{t+1} + \\gamma V(S_{t+1}) \\mid S_t = s, A_t = a] \\\\ \u0026= \\mathbb{E} [R_{t+1} + \\gamma \\mathbb{E}_{a\\sim\\pi} Q(S_{t+1}, a) \\mid S_t = s, A_t = a] \\end{aligned} $$ Bellman Expectation Equations The recursive update process can be further decomposed to be equations built on both state-value and action-value functions. As we go further in future action steps, we extend V and Q alternatively by following the policy $\\pi$.\nFig. 5. Illustration of how Bellman expection equations update state-value and action-value functions. $$ \\begin{aligned} V_{\\pi}(s) \u0026= \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s) Q_{\\pi}(s, a) \\\\ Q_{\\pi}(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_{\\pi} (s') \\\\ V_{\\pi}(s) \u0026= \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s) \\big( R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_{\\pi} (s') \\big) \\\\ Q_{\\pi}(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a \\sum_{a' \\in \\mathcal{A}} \\pi(a' \\vert s') Q_{\\pi} (s', a') \\end{aligned} $$ Bellman Optimality Equations If we are only interested in the optimal values, rather than computing the expectation following a policy, we could jump right into the maximum returns during the alternative updates without using a policy. RECAP: the optimal values $V_*$ and $Q_*$ are the best returns we can obtain, defined here.\n $$ \\begin{aligned} V_*(s) \u0026= \\max_{a \\in \\mathcal{A}} Q_*(s,a)\\\\ Q_*(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_*(s') \\\\ V_*(s) \u0026= \\max_{a \\in \\mathcal{A}} \\big( R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a V_*(s') \\big) \\\\ Q_*(s, a) \u0026= R(s, a) + \\gamma \\sum_{s' \\in \\mathcal{S}} P_{ss'}^a \\max_{a' \\in \\mathcal{A}} Q_*(s', a') \\end{aligned} $$ Unsurprisingly they look very similar to Bellman expectation equations.\nIf we have complete information of the environment, this turns into a planning problem, solvable by DP. Unfortunately, in most scenarios, we do not know $P_{ss'}^a$ or $R(s, a)$, so we cannot solve MDPs by directly applying Bellmen equations, but it lays the theoretical foundation for many RL algorithms.\nCommon Approaches Now it is the time to go through the major approaches and classic algorithms for solving RL problems. In future posts, I plan to dive into each approach further.\nDynamic Programming When the model is fully known, following Bellman equations, we can use Dynamic Programming (DP) to iteratively evaluate value functions and improve policy.\nPolicy Evaluation Policy Evaluation is to compute the state-value $V_\\pi$ for a given policy $\\pi$:\n $$ V_{t+1}(s) = \\mathbb{E}_\\pi [r + \\gamma V_t(s') | S_t = s] = \\sum_a \\pi(a \\vert s) \\sum_{s', r} P(s', r \\vert s, a) (r + \\gamma V_t(s')) $$ Policy Improvement Based on the value functions, Policy Improvement generates a better policy $\\pi' \\geq \\pi$ by acting greedily.\n $$ Q_\\pi(s, a) = \\mathbb{E} [R_{t+1} + \\gamma V_\\pi(S_{t+1}) \\vert S_t=s, A_t=a] = \\sum_{s', r} P(s', r \\vert s, a) (r + \\gamma V_\\pi(s')) $$ Policy Iteration The Generalized Policy Iteration (GPI) algorithm refers to an iterative procedure to improve the policy when combining policy evaluation and improvement.\n $$ \\pi_0 \\xrightarrow[]{\\text{evaluation}} V_{\\pi_0} \\xrightarrow[]{\\text{improve}} \\pi_1 \\xrightarrow[]{\\text{evaluation}} V_{\\pi_1} \\xrightarrow[]{\\text{improve}} \\pi_2 \\xrightarrow[]{\\text{evaluation}} \\dots \\xrightarrow[]{\\text{improve}} \\pi_* \\xrightarrow[]{\\text{evaluation}} V_* $$ In GPI, the value function is approximated repeatedly to be closer to the true value of the current policy and in the meantime, the policy is improved repeatedly to approach optimality. This policy iteration process works and always converges to the optimality, but why this is the case?\nSay, we have a policy $\\pi$ and then generate an improved version $\\pi'$ by greedily taking actions, $\\pi'(s) = \\arg\\max_{a \\in \\mathcal{A}} Q_\\pi(s, a)$. The value of this improved $\\pi'$ is guaranteed to be better because:\n $$ \\begin{aligned} Q_\\pi(s, \\pi'(s)) \u0026= Q_\\pi(s, \\arg\\max_{a \\in \\mathcal{A}} Q_\\pi(s, a)) \\\\ \u0026= \\max_{a \\in \\mathcal{A}} Q_\\pi(s, a) \\geq Q_\\pi(s, \\pi(s)) = V_\\pi(s) \\end{aligned} $$ Monte-Carlo Methods First, let\u0026rsquo;s recall that $V(s) = \\mathbb{E}[ G_t \\vert S_t=s]$. Monte-Carlo (MC) methods uses a simple idea: It learns from episodes of raw experience without modeling the environmental dynamics and computes the observed mean return as an approximation of the expected return. To compute the empirical return $G_t$, MC methods need to learn from complete episodes $S_1, A_1, R_2, \\dots, S_T$ to compute $G_t = \\sum_{k=0}^{T-t-1} \\gamma^k R_{t+k+1}$ and all the episodes must eventually terminate.\nThe empirical mean return for state s is:\n $$ V(s) = \\frac{\\sum_{t=1}^T \\mathbb{1}[S_t = s] G_t}{\\sum_{t=1}^T \\mathbb{1}[S_t = s]} $$ where $\\mathbb{1}[S_t = s]$ is a binary indicator function. We may count the visit of state s every time so that there could exist multiple visits of one state in one episode (\u0026ldquo;every-visit\u0026rdquo;), or only count it the first time we encounter a state in one episode (\u0026ldquo;first-visit\u0026rdquo;). This way of approximation can be easily extended to action-value functions by counting (s, a) pair.\n $$ Q(s, a) = \\frac{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a] G_t}{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a]} $$ To learn the optimal policy by MC, we iterate it by following a similar idea to GPI.\n Improve the policy greedily with respect to the current value function: $\\pi(s) = \\arg\\max_{a \\in \\mathcal{A}} Q(s, a)$. Generate a new episode with the new policy $\\pi$ (i.e. using algorithms like ε-greedy helps us balance between exploitation and exploration.) Estimate Q using the new episode: $q_\\pi(s, a) = \\frac{\\sum_{t=1}^T \\big( \\mathbb{1}[S_t = s, A_t = a] \\sum_{k=0}^{T-t-1} \\gamma^k R_{t+k+1} \\big)}{\\sum_{t=1}^T \\mathbb{1}[S_t = s, A_t = a]}$ Temporal-Difference Learning Similar to Monte-Carlo methods, Temporal-Difference (TD) Learning is model-free and learns from episodes of experience. However, TD learning can learn from incomplete episodes and hence we don\u0026rsquo;t need to track the episode up to termination. TD learning is so important that Sutton \u0026amp; Barto (2017) in their RL book describes it as \u0026ldquo;one idea … central and novel to reinforcement learning\u0026rdquo;.\nBootstrapping TD learning methods update targets with regard to existing estimates rather than exclusively relying on actual rewards and complete returns as in MC methods. This approach is known as bootstrapping.\nValue Estimation The key idea in TD learning is to update the value function $V(S_t)$ towards an estimated return $R_{t+1} + \\gamma V(S_{t+1})$ (known as \u0026ldquo;TD target\u0026quot;). To what extent we want to update the value function is controlled by the learning rate hyperparameter α:\n $$ \\begin{aligned} V(S_t) \u0026\\leftarrow (1- \\alpha) V(S_t) + \\alpha G_t \\\\ V(S_t) \u0026\\leftarrow V(S_t) + \\alpha (G_t - V(S_t)) \\\\ V(S_t) \u0026\\leftarrow V(S_t) + \\alpha (R_{t+1} + \\gamma V(S_{t+1}) - V(S_t)) \\end{aligned} $$ Similarly, for action-value estimation:\n $$ Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)) $$ Next, let\u0026rsquo;s dig into the fun part on how to learn optimal policy in TD learning (aka \u0026ldquo;TD control\u0026rdquo;). Be prepared, you are gonna see many famous names of classic algorithms in this section.\nSARSA: On-Policy TD control \u0026ldquo;SARSA\u0026rdquo; refers to the procedure of updaing Q-value by following a sequence of $\\dots, S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}, \\dots$. The idea follows the same route of GPI. Within one episode, it works as follows:\n Initialize $t=0$. Start with $S_0$ and choose action $A_0 = \\arg\\max_{a \\in \\mathcal{A}} Q(S_0, a)$, where $\\epsilon$-greedy is commonly applied. At time $t$, after applying action $A_t$, we observe reward $R_{t+1}$ and get into the next state $S_{t+1}$. Then pick the next action in the same way as in step 2: $A_{t+1} = \\arg\\max_{a \\in \\mathcal{A}} Q(S_{t+1}, a)$. Update the Q-value function: $ Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)) $. Set $t = t+1$ and repeat from step 3. In each step of SARSA, we need to choose the next action according to the current policy.\nQ-Learning: Off-policy TD control The development of Q-learning (Watkins \u0026amp; Dayan, 1992) is a big breakout in the early days of Reinforcement Learning. Within one episode, it works as follows:\n Initialize $t=0$. Starts with $S_0$. At time step $t$, we pick the action according to Q values, $A_t = \\arg\\max_{a \\in \\mathcal{A}} Q(S_t, a)$ and $\\epsilon$-greedy is commonly applied. After applying action $A_t$, we observe reward $R_{t+1}$ and get into the next state $S_{t+1}$. Update the Q-value function: $Q(S_t, A_t) \\leftarrow Q(S_t, A_t) + \\alpha (R_{t+1} + \\gamma \\max_{a \\in \\mathcal{A}} Q(S_{t+1}, a) - Q(S_t, A_t))$. $t = t+1$ and repeat from step 3. The key difference from SARSA is that Q-learning does not follow the current policy to pick the second action $A_{t+1}$. It estimates $Q^*$ out of the best Q values, but which action (denoted as $a^*$) leads to this maximal Q does not matter and in the next step Q-learning may not follow $a^*$.\nFig. 6. The backup diagrams for Q-learning and SARSA. (Image source: Replotted based on Figure 6.5 in Sutton \u0026 Barto (2017)) Deep Q-Network Theoretically, we can memorize $Q_*(.)$ for all state-action pairs in Q-learning, like in a gigantic table. However, it quickly becomes computationally infeasible when the state and action space are large. Thus people use functions (i.e. a machine learning model) to approximate Q values and this is called function approximation. For example, if we use a function with parameter $\\theta$ to calculate Q values, we can label Q value function as $Q(s, a; \\theta)$.\nUnfortunately Q-learning may suffer from instability and divergence when combined with an nonlinear Q-value function approximation and bootstrapping (See Problems #2).\nDeep Q-Network (\u0026ldquo;DQN\u0026rdquo;; Mnih et al. 2015) aims to greatly improve and stabilize the training procedure of Q-learning by two innovative mechanisms:\n Experience Replay: All the episode steps $e_t = (S_t, A_t, R_t, S_{t+1})$ are stored in one replay memory $D_t = \\{ e_1, \\dots, e_t \\}$. $D_t$ has experience tuples over many episodes. During Q-learning updates, samples are drawn at random from the replay memory and thus one sample could be used multiple times. Experience replay improves data efficiency, removes correlations in the observation sequences, and smooths over changes in the data distribution. Periodically Updated Target: Q is optimized towards target values that are only periodically updated. The Q network is cloned and kept frozen as the optimization target every C steps (C is a hyperparameter). This modification makes the training more stable as it overcomes the short-term oscillations. The loss function looks like this:\n $$ \\mathcal{L}(\\theta) = \\mathbb{E}_{(s, a, r, s') \\sim U(D)} \\Big[ \\big( r + \\gamma \\max_{a'} Q(s', a'; \\theta^{-}) - Q(s, a; \\theta) \\big)^2 \\Big] $$ where $U(D)$ is a uniform distribution over the replay memory D; $\\theta^{-}$ is the parameters of the frozen target Q-network.\nIn addition, it is also found to be helpful to clip the error term to be between [-1, 1]. (I always get mixed feeling with parameter clipping, as many studies have shown that it works empirically but it makes the math much less pretty. :/)\nFig. 7. Algorithm for DQN with experience replay and occasionally frozen optimization target. The prepossessed sequence is the output of some processes running on the input images of Atari games. Don't worry too much about it; just consider them as input feature vectors. (Image source: Mnih et al. 2015) There are many extensions of DQN to improve the original design, such as DQN with dueling architecture (Wang et al. 2016) which estimates state-value function V(s) and advantage function A(s, a) with shared network parameters.\nCombining TD and MC Learning In the previous section on value estimation in TD learning, we only trace one step further down the action chain when calculating the TD target. One can easily extend it to take multiple steps to estimate the return.\nLet\u0026rsquo;s label the estimated return following n steps as $G_t^{(n)}, n=1, \\dots, \\infty$, then:\n $n$ $G_t$ Notes $n=1$ $G_t^{(1)} = R_{t+1} + \\gamma V(S_{t+1})$ TD learning $n=2$ $G_t^{(2)} = R_{t+1} + \\gamma R_{t+2} + \\gamma^2 V(S_{t+2})$ \u0026hellip; $n=n$ $ G_t^{(n)} = R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{n-1} R_{t+n} + \\gamma^n V(S_{t+n}) $ \u0026hellip; $n=\\infty$ $G_t^{(\\infty)} = R_{t+1} + \\gamma R_{t+2} + \\dots + \\gamma^{T-t-1} R_T + \\gamma^{T-t} V(S_T) $ MC estimation The generalized n-step TD learning still has the same form for updating the value function:\n $$ V(S_t) \\leftarrow V(S_t) + \\alpha (G_t^{(n)} - V(S_t)) $$ We are free to pick any $n$ in TD learning as we like. Now the question becomes what is the best $n$? Which $G_t^{(n)}$ gives us the best return approximation? A common yet smart solution is to apply a weighted sum of all possible n-step TD targets rather than to pick a single best n. The weights decay by a factor λ with n, $\\lambda^{n-1}$; the intuition is similar to why we want to discount future rewards when computing the return: the more future we look into the less confident we would be. To make all the weight (n → ∞) sum up to 1, we multiply every weight by (1-λ), because:\n $$ \\begin{aligned} \\text{let } S \u0026= 1 + \\lambda + \\lambda^2 + \\dots \\\\ S \u0026= 1 + \\lambda(1 + \\lambda + \\lambda^2 + \\dots) \\\\ S \u0026= 1 + \\lambda S \\\\ S \u0026= 1 / (1-\\lambda) \\end{aligned} $$ This weighted sum of many n-step returns is called λ-return $G_t^{\\lambda} = (1-\\lambda) \\sum_{n=1}^{\\infty} \\lambda^{n-1} G_t^{(n)}$. TD learning that adopts λ-return for value updating is labeled as TD(λ). The original version we introduced above is equivalent to TD(0).\nFig. 8. Comparison of the backup diagrams of Monte-Carlo, Temporal-Difference learning, and Dynamic Programming for state value functions. (Image source: David Silver's RL course lecture 4: \"Model-Free Prediction\") Policy Gradient All the methods we have introduced above aim to learn the state/action value function and then to select actions accordingly. Policy Gradient methods instead learn the policy directly with a parameterized function respect to $\\theta$, $\\pi(a \\vert s; \\theta)$. Let\u0026rsquo;s define the reward function (opposite of loss function) as the expected return and train the algorithm with the goal to maximize the reward function. My next post described why the policy gradient theorem works (proof) and introduced a number of policy gradient algorithms.\nIn discrete space:\n $$ \\mathcal{J}(\\theta) = V_{\\pi_\\theta}(S_1) = \\mathbb{E}_{\\pi_\\theta}[V_1] $$ where $S_1$ is the initial starting state.\nOr in continuous space:\n $$ \\mathcal{J}(\\theta) = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) V_{\\pi_\\theta}(s) = \\sum_{s \\in \\mathcal{S}} \\Big( d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s, \\theta) Q_\\pi(s, a) \\Big) $$ where $d_{\\pi_\\theta}(s)$ is stationary distribution of Markov chain for $\\pi_\\theta$. If you are unfamiliar with the definition of a \u0026ldquo;stationary distribution,\u0026rdquo; please check this reference.\nUsing gradient ascent we can find the best θ that produces the highest return. It is natural to expect policy-based methods are more useful in continuous space, because there is an infinite number of actions and/or states to estimate the values for in continuous space and hence value-based approaches are computationally much more expensive.\nPolicy Gradient Theorem Computing the gradient numerically can be done by perturbing θ by a small amount ε in the k-th dimension. It works even when $J(\\theta)$ is not differentiable (nice!), but unsurprisingly very slow.\n $$ \\frac{\\partial \\mathcal{J}(\\theta)}{\\partial \\theta_k} \\approx \\frac{\\mathcal{J}(\\theta + \\epsilon u_k) - \\mathcal{J}(\\theta)}{\\epsilon} $$ Or analytically,\n $$ \\mathcal{J}(\\theta) = \\mathbb{E}_{\\pi_\\theta} [r] = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) R(s, a) $$ Actually we have nice theoretical support for (replacing $d(.)$ with $d_\\pi(.)$):\n $$ \\mathcal{J}(\\theta) = \\sum_{s \\in \\mathcal{S}} d_{\\pi_\\theta}(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\propto \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) $$ Check Sec 13.1 in Sutton \u0026amp; Barto (2017) for why this is the case.\nThen,\n $$ \\begin{aligned} \\mathcal{J}(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \\nabla \\mathcal{J}(\\theta) \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\nabla \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \u0026= \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) \\frac{\\nabla \\pi(a \\vert s; \\theta)}{\\pi(a \\vert s; \\theta)} Q_\\pi(s, a) \\\\ \u0026 = \\sum_{s \\in \\mathcal{S}} d(s) \\sum_{a \\in \\mathcal{A}} \\pi(a \\vert s; \\theta) \\nabla \\ln \\pi(a \\vert s; \\theta) Q_\\pi(s, a) \\\\ \u0026 = \\mathbb{E}_{\\pi_\\theta} [\\nabla \\ln \\pi(a \\vert s; \\theta) Q_\\pi(s, a)] \\end{aligned} $$ This result is named \u0026ldquo;Policy Gradient Theorem\u0026rdquo; which lays the theoretical foundation for various policy gradient algorithms:\n $$ \\nabla \\mathcal{J}(\\theta) = \\mathbb{E}_{\\pi_\\theta} [\\nabla \\ln \\pi(a \\vert s, \\theta) Q_\\pi(s, a)] $$ REINFORCE REINFORCE, also known as Monte-Carlo policy gradient, relies on $Q_\\pi(s, a)$, an estimated return by MC methods using episode samples, to update the policy parameter $\\theta$.\nA commonly used variation of REINFORCE is to subtract a baseline value from the return $G_t$ to reduce the variance of gradient estimation while keeping the bias unchanged. For example, a common baseline is state-value, and if applied, we would use $A(s, a) = Q(s, a) - V(s)$ in the gradient ascent update.\n Initialize θ at random Generate one episode $S_1, A_1, R_2, S_2, A_2, \\dots, S_T$ For t=1, 2, \u0026hellip; , T: Estimate the the return G_t since the time step t. $\\theta \\leftarrow \\theta + \\alpha \\gamma^t G_t \\nabla \\ln \\pi(A_t \\vert S_t, \\theta)$. Actor-Critic If the value function is learned in addition to the policy, we would get Actor-Critic algorithm.\n Critic: updates value function parameters w and depending on the algorithm it could be action-value $Q(a \\vert s; w)$ or state-value $V(s; w)$. Actor: updates policy parameters θ, in the direction suggested by the critic, $\\pi(a \\vert s; \\theta)$. Let\u0026rsquo;s see how it works in an action-value actor-critic algorithm.\n Initialize s, θ, w at random; sample $a \\sim \\pi(a \\vert s; \\theta)$. For t = 1… T: Sample reward $r_t \\sim R(s, a)$ and next state $s' \\sim P(s' \\vert s, a)$. Then sample the next action $a' \\sim \\pi(s', a'; \\theta)$. Update policy parameters: $\\theta \\leftarrow \\theta + \\alpha_\\theta Q(s, a; w) \\nabla_\\theta \\ln \\pi(a \\vert s; \\theta)$. Compute the correction for action-value at time t: $G_{t:t+1} = r_t + \\gamma Q(s', a'; w) - Q(s, a; w)$ and use it to update value function parameters: $w \\leftarrow w + \\alpha_w G_{t:t+1} \\nabla_w Q(s, a; w) $. Update $a \\leftarrow a'$ and $s \\leftarrow s'$. $\\alpha_\\theta$ and $\\alpha_w$ are two learning rates for policy and value function parameter updates, respectively.\nA3C Asynchronous Advantage Actor-Critic (Mnih et al., 2016), short for A3C, is a classic policy gradient method with the special focus on parallel training.\nIn A3C, the critics learn the state-value function, $V(s; w)$, while multiple actors are trained in parallel and get synced with global parameters from time to time. Hence, A3C is good for parallel training by default, i.e. on one machine with multi-core CPU.\nThe loss function for state-value is to minimize the mean squared error, $\\mathcal{J}_v (w) = (G_t - V(s; w))^2$ and we use gradient descent to find the optimal w. This state-value function is used as the baseline in the policy gradient update.\nHere is the algorithm outline:\n We have global parameters, θ and w; similar thread-specific parameters, θ' and w'. Initialize the time step t = 1 While T \u0026lt;= T_MAX: Reset gradient: dθ = 0 and dw = 0. Synchronize thread-specific parameters with global ones: θ' = θ and w' = w. $t_\\text{start}$ = t and get $s_t$. While ($s_t \\neq \\text{TERMINAL}$) and ($t - t_\\text{start} \u0026lt;= t_\\text{max}$): Pick the action $a_t \\sim \\pi(a_t \\vert s_t; \\theta')$ and receive a new reward $r_t$ and a new state $s_{t+1}$. Update t = t + 1 and T = T + 1. Initialize the variable that holds the return estimation $$R = \\begin{cases} 0 \u0026amp; \\text{if } s_t \\text{ is TERMINAL} \\ V(s_t; w') \u0026amp; \\text{otherwise} \\end{cases}$$. For $i = t-1, \\dots, t_\\text{start}$: $R \\leftarrow r_i + \\gamma R$; here R is a MC measure of $G_i$. Accumulate gradients w.r.t. θ': $d\\theta \\leftarrow d\\theta + \\nabla_{\\theta'} \\log \\pi(a_i \\vert s_i; \\theta')(R - V(s_i; w'))$; Accumulate gradients w.r.t. w': $dw \\leftarrow dw + \\nabla_{w'} (R - V(s_i; w'))^2$. Update synchronously θ using dθ, and w using dw. A3C enables the parallelism in multiple agent training. The gradient accumulation step (6.2) can be considered as a reformation of minibatch-based stochastic gradient update: the values of w or θ get corrected by a little bit in the direction of each training thread independently.\nEvolution Strategies Evolution Strategies (ES) is a type of model-agnostic optimization approach. It learns the optimal solution by imitating Darwin\u0026rsquo;s theory of the evolution of species by natural selection. Two prerequisites for applying ES: (1) our solutions can freely interact with the environment and see whether they can solve the problem; (2) we are able to compute a fitness score of how good each solution is. We don\u0026rsquo;t have to know the environment configuration to solve the problem.\nSay, we start with a population of random solutions. All of them are capable of interacting with the environment and only candidates with high fitness scores can survive (only the fittest can survive in a competition for limited resources). A new generation is then created by recombining the settings (gene mutation) of high-fitness survivors. This process is repeated until the new solutions are good enough.\nVery different from the popular MDP-based approaches as what we have introduced above, ES aims to learn the policy parameter $\\theta$ without value approximation. Let\u0026rsquo;s assume the distribution over the parameter $\\theta$ is an isotropic multivariate Gaussian with mean $\\mu$ and fixed covariance $\\sigma^2I$. The gradient of $F(\\theta)$ is calculated:\n $$ \\begin{aligned} \u0026 \\nabla_\\theta \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} F(\\theta) \\\\ =\u0026 \\nabla_\\theta \\int_\\theta F(\\theta) \\Pr(\\theta) \u0026\u0026 \\text{Pr(.) is the Gaussian density function.} \\\\ =\u0026 \\int_\\theta F(\\theta) \\Pr(\\theta) \\frac{\\nabla_\\theta \\Pr(\\theta)}{\\Pr(\\theta)} \\\\ =\u0026 \\int_\\theta F(\\theta) \\Pr(\\theta) \\nabla_\\theta \\log \\Pr(\\theta) \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} [F(\\theta) \\nabla_\\theta \\log \\Pr(\\theta)] \u0026\u0026 \\text{Similar to how we do policy gradient update.} \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\nabla_\\theta \\log \\Big( \\frac{1}{\\sqrt{2\\pi\\sigma^2}} e^{-\\frac{(\\theta - \\mu)^2}{2 \\sigma^2 }} \\Big) \\Big] \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\nabla_\\theta \\Big( -\\log \\sqrt{2\\pi\\sigma^2} - \\frac{(\\theta - \\mu)^2}{2 \\sigma^2} \\Big) \\Big] \\\\ =\u0026 \\mathbb{E}_{\\theta \\sim N(\\mu, \\sigma^2)} \\Big[ F(\\theta) \\frac{\\theta - \\mu}{\\sigma^2} \\Big] \\end{aligned} $$ We can rewrite this formula in terms of a \u0026ldquo;mean\u0026rdquo; parameter $\\theta$ (different from the $\\theta$ above; this $\\theta$ is the base gene for further mutation), $\\epsilon \\sim N(0, I)$ and therefore $\\theta + \\epsilon \\sigma \\sim N(\\theta, \\sigma^2)$. $\\epsilon$ controls how much Gaussian noises should be added to create mutation:\n $$ \\nabla_\\theta \\mathbb{E}_{\\epsilon \\sim N(0, I)} F(\\theta + \\sigma \\epsilon) = \\frac{1}{\\sigma} \\mathbb{E}_{\\epsilon \\sim N(0, I)} [F(\\theta + \\sigma \\epsilon) \\epsilon] $$ Fig. 9. A simple parallel evolution-strategies-based RL algorithm. Parallel workers share the random seeds so that they can reconstruct the Gaussian noises with tiny communication bandwidth. (Image source: Salimans et al. 2017.) ES, as a black-box optimization algorithm, is another approach to RL problems (In my original writing, I used the phrase \u0026ldquo;a nice alternative\u0026rdquo;; Seita pointed me to this discussion and thus I updated my wording.). It has a couple of good characteristics (Salimans et al., 2017) keeping it fast and easy to train:\n ES does not need value function approximation; ES does not perform gradient back-propagation; ES is invariant to delayed or long-term rewards; ES is highly parallelizable with very little data communication. Known Problems Exploration-Exploitation Dilemma The problem of exploration vs exploitation dilemma has been discussed in my previous post. When the RL problem faces an unknown environment, this issue is especially a key to finding a good solution: without enough exploration, we cannot learn the environment well enough; without enough exploitation, we cannot complete our reward optimization task.\nDifferent RL algorithms balance between exploration and exploitation in different ways. In MC methods, Q-learning or many on-policy algorithms, the exploration is commonly implemented by ε-greedy; In ES, the exploration is captured by the policy parameter perturbation. Please keep this into consideration when develop a new RL algorithm.\nDeadly Triad Issue We do seek the efficiency and flexibility of TD methods that involve bootstrapping. However, when off-policy, nonlinear function approximation, and bootstrapping are combined in one RL algorithm, the training could be unstable and hard to converge. This issue is known as the deadly triad (Sutton \u0026amp; Barto, 2017). Many architectures using deep learning models were proposed to resolve the problem, including DQN to stabilize the training with experience replay and occasionally frozen target network.\nCase Study: AlphaGo Zero The game of Go has been an extremely hard problem in the field of Artificial Intelligence for decades until recent years. AlphaGo and AlphaGo Zero are two programs developed by a team at DeepMind. Both involve deep Convolutional Neural Networks (CNN) and Monte Carlo Tree Search (MCTS) and both have been approved to achieve the level of professional human Go players. Different from AlphaGo that relied on supervised learning from expert human moves, AlphaGo Zero used only reinforcement learning and self-play without human knowledge beyond the basic rules.\nFig. 10. The board of Go. Two players play black and white stones alternatively on the vacant intersections of a board with 19 x 19 lines. A group of stones must have at least one open point (an intersection, called a \"liberty\") to remain on the board and must have at least two or more enclosed liberties (called \"eyes\") to stay \"alive\". No stone shall repeat a previous position. With all the knowledge of RL above, let\u0026rsquo;s take a look at how AlphaGo Zero works. The main component is a deep CNN over the game board configuration (precisely, a ResNet with batch normalization and ReLU). This network outputs two values:\n $$ (p, v) = f_\\theta(s) $$ $s$: the game board configuration, 19 x 19 x 17 stacked feature planes; 17 features for each position, 8 past configurations (including current) for the current player + 8 past configurations for the opponent + 1 feature indicating the color (1=black, 0=white). We need to code the color specifically because the network is playing with itself and the colors of current player and opponents are switching between steps. $p$: the probability of selecting a move over 19^2 + 1 candidates (19^2 positions on the board, in addition to passing). $v$: the winning probability given the current setting. During self-play, MCTS further improves the action probability distribution $\\pi \\sim p(.)$ and then the action $a_t$ is sampled from this improved policy. The reward $z_t$ is a binary value indicating whether the current player eventually wins the game. Each move generates an episode tuple $(s_t, \\pi_t, z_t)$ and it is saved into the replay memory. The details on MCTS are skipped for the sake of space in this post; please read the original paper if you are interested.\nFig. 11. AlphaGo Zero is trained by self-play while MCTS improves the output policy further in every step. (Image source: Figure 1a in Silver et al., 2017). The network is trained with the samples in the replay memory to minimize the loss:\n $$ \\mathcal{L} = (z - v)^2 - \\pi^\\top \\log p + c \\| \\theta \\|^2 $$ where $c$ is a hyperparameter controlling the intensity of L2 penalty to avoid overfitting.\nAlphaGo Zero simplified AlphaGo by removing supervised learning and merging separated policy and value networks into one. It turns out that AlphaGo Zero achieved largely improved performance with a much shorter training time! I strongly recommend reading these two papers side by side and compare the difference, super fun.\nI know this is a long read, but hopefully worth it. If you notice mistakes and errors in this post, don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com]. See you in the next post! :)\n Cited as:\n@article{weng2018bandit, title = \u0026quot;A (Long) Peek into Reinforcement Learning\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-02-19-rl-overview/\u0026quot; } References [1] Yuxi Li. Deep reinforcement learning: An overview. arXiv preprint arXiv:1701.07274. 2017.\n[2] Richard S. Sutton and Andrew G. Barto. Reinforcement Learning: An Introduction; 2nd Edition. 2017.\n[3] Volodymyr Mnih, et al. Asynchronous methods for deep reinforcement learning. ICML. 2016.\n[4] Tim Salimans, et al. Evolution strategies as a scalable alternative to reinforcement learning. arXiv preprint arXiv:1703.03864 (2017).\n[5] David Silver, et al. Mastering the game of go without human knowledge. Nature 550.7676 (2017): 354.\n[6] David Silver, et al. Mastering the game of Go with deep neural networks and tree search. Nature 529.7587 (2016): 484-489.\n[7] Volodymyr Mnih, et al. Human-level control through deep reinforcement learning. Nature 518.7540 (2015): 529.\n[8] Ziyu Wang, et al. Dueling network architectures for deep reinforcement learning. ICML. 2016.\n[9] Reinforcement Learning lectures by David Silver on YouTube.\n[10] OpenAI Blog: Evolution Strategies as a Scalable Alternative to Reinforcement Learning\n[11] Frank Sehnke, et al. Parameter-exploring policy gradients. Neural Networks 23.4 (2010): 551-559.\n[12] Csaba Szepesvári. Algorithms for reinforcement learning. 1st Edition. Synthesis lectures on artificial intelligence and machine learning 4.1 (2010): 1-103.\n If you notice mistakes and errors in this post, please don\u0026rsquo;t hesitate to contact me at [lilian dot wengweng at gmail dot com] and I would be super happy to correct them right away!\n","permalink":"https://lilianweng.github.io/posts/2018-02-19-rl-overview/","summary":"[Updated on 2020-09-03: Updated the algorithm of SARSA and Q-learning so that the difference is more pronounced. [Updated on 2021-09-19: Thanks to 爱吃猫的鱼, we have this post in Chinese].\nA couple of exciting news in Artificial Intelligence (AI) has just happened in recent years. AlphaGo defeated the best professional human player in the game of Go. Very soon the extended algorithm AlphaGo Zero beat AlphaGo by 100-0 without supervised learning on human knowledge.","title":"A (Long) Peek into Reinforcement Learning"},{"content":"The algorithms are implemented for Bernoulli bandit in lilianweng/multi-armed-bandit.\nExploitation vs Exploration The exploration vs exploitation dilemma exists in many aspects of our life. Say, your favorite restaurant is right around the corner. If you go there every day, you would be confident of what you will get, but miss the chances of discovering an even better option. If you try new places all the time, very likely you are gonna have to eat unpleasant food from time to time. Similarly, online advisors try to balance between the known most attractive ads and the new ads that might be even more successful.\nFig. 1. A real-life example of the exploration vs exploitation dilemma: where to eat? (Image source: UC Berkeley AI course slide, lecture 11.) If we have learned all the information about the environment, we are able to find the best strategy by even just simulating brute-force, let alone many other smart approaches. The dilemma comes from the incomplete information: we need to gather enough information to make best overall decisions while keeping the risk under control. With exploitation, we take advantage of the best option we know. With exploration, we take some risk to collect information about unknown options. The best long-term strategy may involve short-term sacrifices. For example, one exploration trial could be a total failure, but it warns us of not taking that action too often in the future.\nWhat is Multi-Armed Bandit? The multi-armed bandit problem is a classic problem that well demonstrates the exploration vs exploitation dilemma. Imagine you are in a casino facing multiple slot machines and each is configured with an unknown probability of how likely you can get a reward at one play. The question is: What is the best strategy to achieve highest long-term rewards?\nIn this post, we will only discuss the setting of having an infinite number of trials. The restriction on a finite number of trials introduces a new type of exploration problem. For instance, if the number of trials is smaller than the number of slot machines, we cannot even try every machine to estimate the reward probability (!) and hence we have to behave smartly w.r.t. a limited set of knowledge and resources (i.e. time).\nFig. 2. An illustration of how a Bernoulli multi-armed bandit works. The reward probabilities are **unknown** to the player. A naive approach can be that you continue to playing with one machine for many many rounds so as to eventually estimate the \u0026ldquo;true\u0026rdquo; reward probability according to the law of large numbers. However, this is quite wasteful and surely does not guarantee the best long-term reward.\nDefinition Now let\u0026rsquo;s give it a scientific definition.\nA Bernoulli multi-armed bandit can be described as a tuple of $\\langle \\mathcal{A}, \\mathcal{R} \\rangle$, where:\n We have $K$ machines with reward probabilities, $\\{ \\theta_1, \\dots, \\theta_K \\}$. At each time step t, we take an action a on one slot machine and receive a reward r. $\\mathcal{A}$ is a set of actions, each referring to the interaction with one slot machine. The value of action a is the expected reward, $Q(a) = \\mathbb{E} [r \\vert a] = \\theta$. If action $a_t$ at the time step t is on the i-th machine, then $Q(a_t) = \\theta_i$. $\\mathcal{R}$ is a reward function. In the case of Bernoulli bandit, we observe a reward r in a stochastic fashion. At the time step t, $r_t = \\mathcal{R}(a_t)$ may return reward 1 with a probability $Q(a_t)$ or 0 otherwise. It is a simplified version of Markov decision process, as there is no state $\\mathcal{S}$.\nThe goal is to maximize the cumulative reward $\\sum_{t=1}^T r_t$. If we know the optimal action with the best reward, then the goal is same as to minimize the potential regret or loss by not picking the optimal action.\nThe optimal reward probability $\\theta^{*}$ of the optimal action $a^{*}$ is:\n $$ \\theta^{*}=Q(a^{*})=\\max_{a \\in \\mathcal{A}} Q(a) = \\max_{1 \\leq i \\leq K} \\theta_i $$ Our loss function is the total regret we might have by not selecting the optimal action up to the time step T:\n $$ \\mathcal{L}_T = \\mathbb{E} \\Big[ \\sum_{t=1}^T \\big( \\theta^{*} - Q(a_t) \\big) \\Big] $$ Bandit Strategies Based on how we do exploration, there several ways to solve the multi-armed bandit.\n No exploration: the most naive approach and a bad one. Exploration at random Exploration smartly with preference to uncertainty ε-Greedy Algorithm The ε-greedy algorithm takes the best action most of the time, but does random exploration occasionally. The action value is estimated according to the past experience by averaging the rewards associated with the target action a that we have observed so far (up to the current time step t):\n $$ \\hat{Q}_t(a) = \\frac{1}{N_t(a)} \\sum_{\\tau=1}^t r_\\tau \\mathbb{1}[a_\\tau = a] $$ where $\\mathbb{1}$ is a binary indicator function and $N_t(a)$ is how many times the action a has been selected so far, $N_t(a) = \\sum_{\\tau=1}^t \\mathbb{1}[a_\\tau = a]$.\nAccording to the ε-greedy algorithm, with a small probability $\\epsilon$ we take a random action, but otherwise (which should be the most of the time, probability 1-$\\epsilon$) we pick the best action that we have learnt so far: $\\hat{a}^{*}_t = \\arg\\max_{a \\in \\mathcal{A}} \\hat{Q}_t(a)$.\nCheck my toy implementation here.\nUpper Confidence Bounds Random exploration gives us an opportunity to try out options that we have not known much about. However, due to the randomness, it is possible we end up exploring a bad action which we have confirmed in the past (bad luck!). To avoid such inefficient exploration, one approach is to decrease the parameter ε in time and the other is to be optimistic about options with high uncertainty and thus to prefer actions for which we haven\u0026rsquo;t had a confident value estimation yet. Or in other words, we favor exploration of actions with a strong potential to have a optimal value.\nThe Upper Confidence Bounds (UCB) algorithm measures this potential by an upper confidence bound of the reward value, $\\hat{U}_t(a)$, so that the true value is below with bound $Q(a) \\leq \\hat{Q}_t(a) + \\hat{U}_t(a)$ with high probability. The upper bound $\\hat{U}_t(a)$ is a function of $N_t(a)$; a larger number of trials $N_t(a)$ should give us a smaller bound $\\hat{U}_t(a)$.\nIn UCB algorithm, we always select the greediest action to maximize the upper confidence bound:\n $$ a^{UCB}_t = argmax_{a \\in \\mathcal{A}} \\hat{Q}_t(a) + \\hat{U}_t(a) $$ Now, the question is how to estimate the upper confidence bound.\nHoeffding\u0026rsquo;s Inequality If we do not want to assign any prior knowledge on how the distribution looks like, we can get help from \u0026ldquo;Hoeffding\u0026rsquo;s Inequality\u0026rdquo; \u0026mdash; a theorem applicable to any bounded distribution.\nLet $X_1, \\dots, X_t$ be i.i.d. (independent and identically distributed) random variables and they are all bounded by the interval [0, 1]. The sample mean is $\\overline{X}_t = \\frac{1}{t}\\sum_{\\tau=1}^t X_\\tau$. Then for u \u0026gt; 0, we have:\n $$ \\mathbb{P} [ \\mathbb{E}[X] \\overline{X}_t + u] \\leq e^{-2tu^2} $$ Given one target action a, let us consider:\n $r_t(a)$ as the random variables, $Q(a)$ as the true mean, $\\hat{Q}_t(a)$ as the sample mean, And $u$ as the upper confidence bound, $u = U_t(a)$ Then we have,\n $$ \\mathbb{P} [ Q(a) \\hat{Q}_t(a) + U_t(a)] \\leq e^{-2t{U_t(a)}^2} $$ We want to pick a bound so that with high chances the true mean is blow the sample mean + the upper confidence bound. Thus $e^{-2t U_t(a)^2}$ should be a small probability. Let\u0026rsquo;s say we are ok with a tiny threshold p:\n $$ e^{-2t U_t(a)^2} = p \\text{ Thus, } U_t(a) = \\sqrt{\\frac{-\\log p}{2 N_t(a)}} $$ UCB1 One heuristic is to reduce the threshold p in time, as we want to make more confident bound estimation with more rewards observed. Set $p=t^{-4}$ we get UCB1 algorithm:\n $$ U_t(a) = \\sqrt{\\frac{2 \\log t}{N_t(a)}} \\text{ and } a^{UCB1}_t = \\arg\\max_{a \\in \\mathcal{A}} Q(a) + \\sqrt{\\frac{2 \\log t}{N_t(a)}} $$ Bayesian UCB In UCB or UCB1 algorithm, we do not assume any prior on the reward distribution and therefore we have to rely on the Hoeffding\u0026rsquo;s Inequality for a very generalize estimation. If we are able to know the distribution upfront, we would be able to make better bound estimation.\nFor example, if we expect the mean reward of every slot machine to be Gaussian as in Fig 2, we can set the upper bound as 95% confidence interval by setting $\\hat{U}_t(a)$ to be twice the standard deviation.\nFig. 3. When the expected reward has a Gaussian distribution. $\\sigma(a\\_i)$ is the standard deviation and $c\\sigma(a\\_i)$ is the upper confidence bound. The constant $c$ is a adjustable hyperparameter. (Image source: UCL RL course lecture 9's slides) Check my toy implementation of UCB1 and Bayesian UCB with Beta prior on θ.\nThompson Sampling Thompson sampling has a simple idea but it works great for solving the multi-armed bandit problem.\nFig. 4. Oops, I guess not this Thompson? (Credit goes to Ben Taborsky; he has a full theorem of how Thompson invented while pondering over who to pass the ball. Yes I stole his joke.) At each time step, we want to select action a according to the probability that a is optimal:\n $$ \\begin{aligned} \\pi(a \\; \\vert \\; h_t) \u0026= \\mathbb{P} [ Q(a) Q(a'), \\forall a' \\neq a \\; \\vert \\; h_t] \\\\ \u0026= \\mathbb{E}_{\\mathcal{R} \\vert h_t} [ \\mathbb{1}(a = \\arg\\max_{a \\in \\mathcal{A}} Q(a)) ] \\end{aligned} $$ where $\\pi(a ; \\vert ; h_t)$ is the probability of taking action a given the history $h_t$.\nFor the Bernoulli bandit, it is natural to assume that $Q(a)$ follows a Beta distribution, as $Q(a)$ is essentially the success probability θ in Bernoulli distribution. The value of $\\text{Beta}(\\alpha, \\beta)$ is within the interval [0, 1]; α and β correspond to the counts when we succeeded or failed to get a reward respectively.\nFirst, let us initialize the Beta parameters α and β based on some prior knowledge or belief for every action. For example,\n α = 1 and β = 1; we expect the reward probability to be 50% but we are not very confident. α = 1000 and β = 9000; we strongly believe that the reward probability is 10%. At each time t, we sample an expected reward, $\\tilde{Q}(a)$, from the prior distribution $\\text{Beta}(\\alpha_i, \\beta_i)$ for every action. The best action is selected among samples: $a^{TS}_t = \\arg\\max_{a \\in \\mathcal{A}} \\tilde{Q}(a)$. After the true reward is observed, we can update the Beta distribution accordingly, which is essentially doing Bayesian inference to compute the posterior with the known prior and the likelihood of getting the sampled data.\n $$ \\begin{aligned} \\alpha_i \u0026 \\leftarrow \\alpha_i + r_t \\mathbb{1}[a^{TS}_t = a_i] \\\\ \\beta_i \u0026 \\leftarrow \\beta_i + (1-r_t) \\mathbb{1}[a^{TS}_t = a_i] \\end{aligned} $$ Thompson sampling implements the idea of probability matching. Because its reward estimations $\\tilde{Q}$ are sampled from posterior distributions, each of these probabilities is equivalent to the probability that the corresponding action is optimal, conditioned on observed history.\nHowever, for many practical and complex problems, it can be computationally intractable to estimate the posterior distributions with observed true rewards using Bayesian inference. Thompson sampling still can work out if we are able to approximate the posterior distributions using methods like Gibbs sampling, Laplace approximate, and the bootstraps. This tutorial presents a comprehensive review; strongly recommend it if you want to learn more about Thompson sampling.\nCase Study I implemented the above algorithms in lilianweng/multi-armed-bandit. A BernoulliBandit object can be constructed with a list of random or predefined reward probabilities. The bandit algorithms are implemented as subclasses of Solver, taking a Bandit object as the target problem. The cumulative regrets are tracked in time.\nFig. 4. The result of a small experiment on solving a Bernoulli bandit with K = 10 slot machines with reward probabilities, {0.0, 0.1, 0.2, ..., 0.9}. Each solver runs 10000 steps. (Left) The plot of time step vs the cumulative regrets. (Middle) The plot of true reward probability vs estimated probability. (Right) The fraction of each action is picked during the 10000-step run.* Summary We need exploration because information is valuable. In terms of the exploration strategies, we can do no exploration at all, focusing on the short-term returns. Or we occasionally explore at random. Or even further, we explore and we are picky about which options to explore \u0026mdash; actions with higher uncertainty are favored because they can provide higher information gain.\n Cited as:\n@article{weng2018bandit, title = \u0026quot;The Multi-Armed Bandit Problem and Its Solutions\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2018\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/\u0026quot; } References [1] CS229 Supplemental Lecture notes: Hoeffding\u0026rsquo;s inequality.\n[2] RL Course by David Silver - Lecture 9: Exploration and Exploitation\n[3] Olivier Chapelle and Lihong Li. \u0026ldquo;An empirical evaluation of thompson sampling.\u0026quot; NIPS. 2011.\n[4] Russo, Daniel, et al. \u0026ldquo;A Tutorial on Thompson Sampling.\u0026quot; arXiv:1707.02038 (2017).\n","permalink":"https://lilianweng.github.io/posts/2018-01-23-multi-armed-bandit/","summary":"The algorithms are implemented for Bernoulli bandit in lilianweng/multi-armed-bandit.\nExploitation vs Exploration The exploration vs exploitation dilemma exists in many aspects of our life. Say, your favorite restaurant is right around the corner. If you go there every day, you would be confident of what you will get, but miss the chances of discovering an even better option. If you try new places all the time, very likely you are gonna have to eat unpleasant food from time to time.","title":"The Multi-Armed Bandit Problem and Its Solutions"},{"content":"[Updated on 2018-12-20: Remove YOLO here. Part 4 will cover multiple fast object detection algorithms, including YOLO.] [Updated on 2018-12-27: Add bbox regression and tricks sections for R-CNN.]\nIn the series of \u0026ldquo;Object Detection for Dummies\u0026rdquo;, we started with basic concepts in image processing, such as gradient vectors and HOG, in Part 1. Then we introduced classic convolutional neural network architecture designs for classification and pioneer models for object recognition, Overfeat and DPM, in Part 2. In the third post of this series, we are about to review a set of models in the R-CNN (\u0026ldquo;Region-based CNN\u0026rdquo;) family.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nHere is a list of papers covered in this post ;)\n Model Goal Resources R-CNN Object recognition [paper][code] Fast R-CNN Object recognition [paper][code] Faster R-CNN Object recognition [paper][code] Mask R-CNN Image segmentation [paper][code] R-CNN R-CNN (Girshick et al., 2014) is short for \u0026ldquo;Region-based Convolutional Neural Networks\u0026rdquo;. The main idea is composed of two steps. First, using selective search, it identifies a manageable number of bounding-box object region candidates (\u0026ldquo;region of interest\u0026rdquo; or \u0026ldquo;RoI\u0026rdquo;). And then it extracts CNN features from each region independently for classification.\nFig. 1. The architecture of R-CNN. (Image source: Girshick et al., 2014) Model Workflow How R-CNN works can be summarized as follows:\n Pre-train a CNN network on image classification tasks; for example, VGG or ResNet trained on ImageNet dataset. The classification task involves N classes. NOTE: You can find a pre-trained AlexNet in Caffe Model Zoo. I don’t think you can find it in Tensorflow, but Tensorflow-slim model library provides pre-trained ResNet, VGG, and others.\n Propose category-independent regions of interest by selective search (~2k candidates per image). Those regions may contain target objects and they are of different sizes. Region candidates are warped to have a fixed size as required by CNN. Continue fine-tuning the CNN on warped proposal regions for K + 1 classes; The additional one class refers to the background (no object of interest). In the fine-tuning stage, we should use a much smaller learning rate and the mini-batch oversamples the positive cases because most proposed regions are just background. Given every image region, one forward propagation through the CNN generates a feature vector. This feature vector is then consumed by a binary SVM trained for each class independently. The positive samples are proposed regions with IoU (intersection over union) overlap threshold \u0026gt;= 0.3, and negative samples are irrelevant others. To reduce the localization errors, a regression model is trained to correct the predicted detection window on bounding box correction offset using CNN features. Bounding Box Regression Given a predicted bounding box coordinate $\\mathbf{p} = (p_x, p_y, p_w, p_h)$ (center coordinate, width, height) and its corresponding ground truth box coordinates $\\mathbf{g} = (g_x, g_y, g_w, g_h)$ , the regressor is configured to learn scale-invariant transformation between two centers and log-scale transformation between widths and heights. All the transformation functions take $\\mathbf{p}$ as input.\n $$ \\begin{aligned} \\hat{g}_x \u0026= p_w d_x(\\mathbf{p}) + p_x \\\\ \\hat{g}_y \u0026= p_h d_y(\\mathbf{p}) + p_y \\\\ \\hat{g}_w \u0026= p_w \\exp({d_w(\\mathbf{p})}) \\\\ \\hat{g}_h \u0026= p_h \\exp({d_h(\\mathbf{p})}) \\end{aligned} $$ Fig. 2. Illustration of transformation between predicted and ground truth bounding boxes. An obvious benefit of applying such transformation is that all the bounding box correction functions, $d_i(\\mathbf{p})$ where $i \\in \\{ x, y, w, h \\}$, can take any value between [-∞, +∞]. The targets for them to learn are:\n $$ \\begin{aligned} t_x \u0026= (g_x - p_x) / p_w \\\\ t_y \u0026= (g_y - p_y) / p_h \\\\ t_w \u0026= \\log(g_w/p_w) \\\\ t_h \u0026= \\log(g_h/p_h) \\end{aligned} $$ A standard regression model can solve the problem by minimizing the SSE loss with regularization:\n $$ \\mathcal{L}_\\text{reg} = \\sum_{i \\in \\{x, y, w, h\\}} (t_i - d_i(\\mathbf{p}))^2 + \\lambda \\|\\mathbf{w}\\|^2 $$ The regularization term is critical here and RCNN paper picked the best λ by cross validation. It is also noteworthy that not all the predicted bounding boxes have corresponding ground truth boxes. For example, if there is no overlap, it does not make sense to run bbox regression. Here, only a predicted box with a nearby ground truth box with at least 0.6 IoU is kept for training the bbox regression model.\nCommon Tricks Several tricks are commonly used in RCNN and other detection models.\nNon-Maximum Suppression\nLikely the model is able to find multiple bounding boxes for the same object. Non-max suppression helps avoid repeated detection of the same instance. After we get a set of matched bounding boxes for the same object category: Sort all the bounding boxes by confidence score. Discard boxes with low confidence scores. While there is any remaining bounding box, repeat the following: Greedily select the one with the highest score. Skip the remaining boxes with high IoU (i.e. \u0026gt; 0.5) with previously selected one.\nFig. 3. Multiple bounding boxes detect the car in the image. After non-maximum suppression, only the best remains and the rest are ignored as they have large overlaps with the selected one. (Image source: DPM paper) Hard Negative Mining\nWe consider bounding boxes without objects as negative examples. Not all the negative examples are equally hard to be identified. For example, if it holds pure empty background, it is likely an “easy negative”; but if the box contains weird noisy texture or partial object, it could be hard to be recognized and these are “hard negative”.\nThe hard negative examples are easily misclassified. We can explicitly find those false positive samples during the training loops and include them in the training data so as to improve the classifier.\nSpeed Bottleneck Looking through the R-CNN learning steps, you could easily find out that training an R-CNN model is expensive and slow, as the following steps involve a lot of work:\n Running selective search to propose 2000 region candidates for every image; Generating the CNN feature vector for every image region (N images * 2000). The whole process involves three models separately without much shared computation: the convolutional neural network for image classification and feature extraction; the top SVM classifier for identifying target objects; and the regression model for tightening region bounding boxes. Fast R-CNN To make R-CNN faster, Girshick (2015) improved the training procedure by unifying three independent models into one jointly trained framework and increasing shared computation results, named Fast R-CNN. Instead of extracting CNN feature vectors independently for each region proposal, this model aggregates them into one CNN forward pass over the entire image and the region proposals share this feature matrix. Then the same feature matrix is branched out to be used for learning the object classifier and the bounding-box regressor. In conclusion, computation sharing speeds up R-CNN.\nFig. 4. The architecture of Fast R-CNN. (Image source: Girshick, 2015) RoI Pooling It is a type of max pooling to convert features in the projected region of the image of any size, h x w, into a small fixed window, H x W. The input region is divided into H x W grids, approximately every subwindow of size h/H x w/W. Then apply max-pooling in each grid.\nFig. 5. RoI pooling (Image source: Stanford CS231n slides.) Model Workflow How Fast R-CNN works is summarized as follows; many steps are same as in R-CNN:\n First, pre-train a convolutional neural network on image classification tasks. Propose regions by selective search (~2k candidates per image). Alter the pre-trained CNN: Replace the last max pooling layer of the pre-trained CNN with a RoI pooling layer. The RoI pooling layer outputs fixed-length feature vectors of region proposals. Sharing the CNN computation makes a lot of sense, as many region proposals of the same images are highly overlapped. Replace the last fully connected layer and the last softmax layer (K classes) with a fully connected layer and softmax over K + 1 classes. Finally the model branches into two output layers: A softmax estimator of K + 1 classes (same as in R-CNN, +1 is the \u0026ldquo;background\u0026rdquo; class), outputting a discrete probability distribution per RoI. A bounding-box regression model which predicts offsets relative to the original RoI for each of K classes. Loss Function The model is optimized for a loss combining two tasks (classification + localization):\n| Symbol | Explanation | | $u$ | True class label, $ u \\in 0, 1, \\dots, K$; by convention, the catch-all background class has $u = 0$. | | $p$ | Discrete probability distribution (per RoI) over K + 1 classes: $p = (p_0, \\dots, p_K)$, computed by a softmax over the K + 1 outputs of a fully connected layer. | | $v$ | True bounding box $ v = (v_x, v_y, v_w, v_h) $. | | $t^u$ | Predicted bounding box correction, $t^u = (t^u_x, t^u_y, t^u_w, t^u_h)$. See above. | {:.info}\nThe loss function sums up the cost of classification and bounding box prediction: $\\mathcal{L} = \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box}$. For \u0026ldquo;background\u0026rdquo; RoI, $\\mathcal{L}_\\text{box}$ is ignored by the indicator function $\\mathbb{1} [u \\geq 1]$, defined as:\n $$ \\mathbb{1} [u = 1] = \\begin{cases} 1 \u0026 \\text{if } u \\geq 1\\\\ 0 \u0026 \\text{otherwise} \\end{cases} $$ The overall loss function is:\n $$ \\begin{align*} \\mathcal{L}(p, u, t^u, v) \u0026= \\mathcal{L}_\\text{cls} (p, u) + \\mathbb{1} [u \\geq 1] \\mathcal{L}_\\text{box}(t^u, v) \\\\ \\mathcal{L}_\\text{cls}(p, u) \u0026= -\\log p_u \\\\ \\mathcal{L}_\\text{box}(t^u, v) \u0026= \\sum_{i \\in \\{x, y, w, h\\}} L_1^\\text{smooth} (t^u_i - v_i) \\end{align*} $$ The bounding box loss $\\mathcal{L}_{box}$ should measure the difference between $t^u_i$ and $v_i$ using a robust loss function. The smooth L1 loss is adopted here and it is claimed to be less sensitive to outliers.\n $$ L_1^\\text{smooth}(x) = \\begin{cases} 0.5 x^2 \u0026 \\text{if } \\vert x \\vert Fig. 6. The plot of smooth L1 loss, $y = L\\_1^\\text{smooth}(x)$. (Image source: link) Speed Bottleneck Fast R-CNN is much faster in both training and testing time. However, the improvement is not dramatic because the region proposals are generated separately by another model and that is very expensive.\nFaster R-CNN An intuitive speedup solution is to integrate the region proposal algorithm into the CNN model. Faster R-CNN (Ren et al., 2016) is doing exactly this: construct a single, unified model composed of RPN (region proposal network) and fast R-CNN with shared convolutional feature layers.\nFig. 7. An illustration of Faster R-CNN model. (Image source: Ren et al., 2016) Model Workflow Pre-train a CNN network on image classification tasks. Fine-tune the RPN (region proposal network) end-to-end for the region proposal task, which is initialized by the pre-train image classifier. Positive samples have IoU (intersection-over-union) \u0026gt; 0.7, while negative samples have IoU \u0026lt; 0.3. Slide a small n x n spatial window over the conv feature map of the entire image. At the center of each sliding window, we predict multiple regions of various scales and ratios simultaneously. An anchor is a combination of (sliding window center, scale, ratio). For example, 3 scales + 3 ratios =\u0026gt; k=9 anchors at each sliding position. Train a Fast R-CNN object detection model using the proposals generated by the current RPN Then use the Fast R-CNN network to initialize RPN training. While keeping the shared convolutional layers, only fine-tune the RPN-specific layers. At this stage, RPN and the detection network have shared convolutional layers! Finally fine-tune the unique layers of Fast R-CNN Step 4-5 can be repeated to train RPN and Fast R-CNN alternatively if needed. Loss Function Faster R-CNN is optimized for a multi-task loss function, similar to fast R-CNN.\n| Symbol | Explanation | | $p_i$ | Predicted probability of anchor i being an object. | | $p^*_i$ | Ground truth label (binary) of whether anchor i is an object. | | $t_i$ | Predicted four parameterized coordinates. | | $t^*_i$ | Ground truth coordinates. | | $N_\\text{cls}$ | Normalization term, set to be mini-batch size (~256) in the paper. | | $N_\\text{box}$ | Normalization term, set to the number of anchor locations (~2400) in the paper. | | $\\lambda$ | A balancing parameter, set to be ~10 in the paper (so that both $\\mathcal{L}_\\text{cls}$ and $\\mathcal{L}_\\text{box}$ terms are roughly equally weighted). | {:.info}\nThe multi-task loss function combines the losses of classification and bounding box regression:\n $$ \\begin{align*} \\mathcal{L} \u0026= \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box} \\\\ \\mathcal{L}(\\{p_i\\}, \\{t_i\\}) \u0026= \\frac{1}{N_\\text{cls}} \\sum_i \\mathcal{L}_\\text{cls} (p_i, p^*_i) + \\frac{\\lambda}{N_\\text{box}} \\sum_i p^*_i \\cdot L_1^\\text{smooth}(t_i - t^*_i) \\\\ \\end{align*} $$ where $\\mathcal{L}_\\text{cls}$ is the log loss function over two classes, as we can easily translate a multi-class classification into a binary classification by predicting a sample being a target object versus not. $L_1^\\text{smooth}$ is the smooth L1 loss.\n $$ \\mathcal{L}_\\text{cls} (p_i, p^*_i) = - p^*_i \\log p_i - (1 - p^*_i) \\log (1 - p_i) $$ Mask R-CNN Mask R-CNN (He et al., 2017) extends Faster R-CNN to pixel-level image segmentation. The key point is to decouple the classification and the pixel-level mask prediction tasks. Based on the framework of Faster R-CNN, it added a third branch for predicting an object mask in parallel with the existing branches for classification and localization. The mask branch is a small fully-connected network applied to each RoI, predicting a segmentation mask in a pixel-to-pixel manner.\nFig. 8. Mask R-CNN is Faster R-CNN model with image segmentation. (Image source: He et al., 2017) Because pixel-level segmentation requires much more fine-grained alignment than bounding boxes, mask R-CNN improves the RoI pooling layer (named \u0026ldquo;RoIAlign layer\u0026rdquo;) so that RoI can be better and more precisely mapped to the regions of the original image.\nFig. 9. Predictions by Mask R-CNN on COCO test set. (Image source: He et al., 2017) RoIAlign The RoIAlign layer is designed to fix the location misalignment caused by quantization in the RoI pooling. RoIAlign removes the hash quantization, for example, by using x/16 instead of [x/16], so that the extracted features can be properly aligned with the input pixels. Bilinear interpolation is used for computing the floating-point location values in the input.\nFig. 10. A region of interest is mapped **accurately** from the original image onto the feature map without rounding up to integers. (Image source: link) Loss Function The multi-task loss function of Mask R-CNN combines the loss of classification, localization and segmentation mask: $ \\mathcal{L} = \\mathcal{L}_\\text{cls} + \\mathcal{L}_\\text{box} + \\mathcal{L}_\\text{mask}$, where $\\mathcal{L}_\\text{cls}$ and $\\mathcal{L}_\\text{box}$ are same as in Faster R-CNN.\nThe mask branch generates a mask of dimension m x m for each RoI and each class; K classes in total. Thus, the total output is of size $K \\cdot m^2$. Because the model is trying to learn a mask for each class, there is no competition among classes for generating masks.\n$\\mathcal{L}_\\text{mask}$ is defined as the average binary cross-entropy loss, only including k-th mask if the region is associated with the ground truth class k.\n $$ \\mathcal{L}_\\text{mask} = - \\frac{1}{m^2} \\sum_{1 \\leq i, j \\leq m} \\big[ y_{ij} \\log \\hat{y}^k_{ij} + (1-y_{ij}) \\log (1- \\hat{y}^k_{ij}) \\big] $$ where $y_{ij}$ is the label of a cell (i, j) in the true mask for the region of size m x m; $\\hat{y}_{ij}^k$ is the predicted value of the same cell in the mask learned for the ground-truth class k.\nSummary of Models in the R-CNN family Here I illustrate model designs of R-CNN, Fast R-CNN, Faster R-CNN and Mask R-CNN. You can track how one model evolves to the next version by comparing the small differences.\n Cited as:\n@article{weng2017detection3, title = \u0026quot;Object Detection for Dummies Part 3: R-CNN Family\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-12-31-object-recognition-part-3/\u0026quot; } Reference [1] Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik. \u0026ldquo;Rich feature hierarchies for accurate object detection and semantic segmentation.\u0026quot; In Proc. IEEE Conf. on computer vision and pattern recognition (CVPR), pp. 580-587. 2014.\n[2] Ross Girshick. \u0026ldquo;Fast R-CNN.\u0026quot; In Proc. IEEE Intl. Conf. on computer vision, pp. 1440-1448. 2015.\n[3] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. \u0026ldquo;Faster R-CNN: Towards real-time object detection with region proposal networks.\u0026quot; In Advances in neural information processing systems (NIPS), pp. 91-99. 2015.\n[4] Kaiming He, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. \u0026ldquo;Mask R-CNN.\u0026quot; arXiv preprint arXiv:1703.06870, 2017.\n[5] Joseph Redmon, Santosh Divvala, Ross Girshick, and Ali Farhadi. \u0026ldquo;You only look once: Unified, real-time object detection.\u0026quot; In Proc. IEEE Conf. on computer vision and pattern recognition (CVPR), pp. 779-788. 2016.\n[6] \u0026ldquo;A Brief History of CNNs in Image Segmentation: From R-CNN to Mask R-CNN\u0026rdquo; by Athelas.\n[7] Smooth L1 Loss: https://github.com/rbgirshick/py-faster-rcnn/files/764206/SmoothL1Loss.1.pdf\n","permalink":"https://lilianweng.github.io/posts/2017-12-31-object-recognition-part-3/","summary":"[Updated on 2018-12-20: Remove YOLO here. Part 4 will cover multiple fast object detection algorithms, including YOLO.] [Updated on 2018-12-27: Add bbox regression and tricks sections for R-CNN.]\nIn the series of \u0026ldquo;Object Detection for Dummies\u0026rdquo;, we started with basic concepts in image processing, such as gradient vectors and HOG, in Part 1. Then we introduced classic convolutional neural network architecture designs for classification and pioneer models for object recognition, Overfeat and DPM, in Part 2.","title":"Object Detection for Dummies Part 3: R-CNN Family"},{"content":"Part 1 of the \u0026ldquo;Object Detection for Dummies\u0026rdquo; series introduced: (1) the concept of image gradient vector and how HOG algorithm summarizes the information across all the gradient vectors in one image; (2) how the image segmentation algorithm works to detect regions that potentially contain objects; (3) how the Selective Search algorithm refines the outcomes of image segmentation for better region proposal.\nIn Part 2, we are about to find out more on the classic convolution neural network architectures for image classification. They lay the foundation for further progress on the deep learning models for object detection. Go check Part 3 if you want to learn more on R-CNN and related models.\nLinks to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nCNN for Image Classification CNN, short for \u0026ldquo;Convolutional Neural Network\u0026rdquo;, is the go-to solution for computer vision problems in the deep learning world. It was, to some extent, inspired by how human visual cortex system works.\nConvolution Operation I strongly recommend this guide to convolution arithmetic, which provides a clean and solid explanation with tons of visualizations and examples. Here let\u0026rsquo;s focus on two-dimensional convolution as we are working with images in this post.\nIn short, convolution operation slides a predefined kernel (also called \u0026ldquo;filter\u0026rdquo;) on top of the input feature map (matrix of image pixels), multiplying and adding the values of the kernel and partial input features to generate the output. The values form an output matrix, as usually, the kernel is much smaller than the input image.\nFig. 1. An illustration of applying a kernel on the input feature map to generate the output. (Image source: River Trail documentation) Figure 2 showcases two real examples of how to convolve a 3x3 kernel over a 5x5 2D matrix of numeric values to generate a 3x3 matrix. By controlling the padding size and the stride length, we can generate an output matrix of a certain size.\nFig. 2. Two examples of 2D convolution operation: (top) no padding and 1x1 strides; (bottom) 1x1 border zeros padding and 2x2 strides. (Image source: deeplearning.net) AlexNet (Krizhevsky et al, 2012) 5 convolution [+ optional max pooling] layers + 2 MLP layers + 1 LR layer Use data augmentation techniques to expand the training dataset, such as image translations, horizontal reflections, and patch extractions. Fig. 3. The architecture of AlexNet. (Image source: link) VGG (Simonyan and Zisserman, 2014) The network is considered as \u0026ldquo;very deep\u0026rdquo; at its time; 19 layers The architecture is extremely simplified with only 3x3 convolutional layers and 2x2 pooling layers. The stacking of small filters simulates a larger filter with fewer parameters. ResNet (He et al., 2015) The network is indeed very deep; 152 layers of simple architecture. Residual Block: Some input of a certain layer can be passed to the component two layers later. Residual blocks are essential for keeping a deep network trainable and eventually work. Without residual blocks, the training loss of a plain network does not monotonically decrease as the number of layers increases due to vanishing and exploding gradients. Fig. 4. An illustration of the residual block of ResNet. In some way, we can say the design of residual blocks is inspired by V4 getting input directly from V1 in the human visual cortex system. (left image source: Wang et al., 2017) Evaluation Metrics: mAP A common evaluation metric used in many object recognition and detection tasks is \u0026ldquo;mAP\u0026rdquo;, short for \u0026ldquo;mean average precision\u0026rdquo;. It is a number from 0 to 100; higher value is better.\n Combine all detections from all test images to draw a precision-recall curve (PR curve) for each class; The \u0026ldquo;average precision\u0026rdquo; (AP) is the area under the PR curve. Given that target objects are in different classes, we first compute AP separately for each class, and then average over classes. A detection is a true positive if it has \u0026ldquo;intersection over union\u0026rdquo; (IoU) with a ground-truth box greater than some threshold (usually 0.5; if so, the metric is \u0026ldquo;mAP@0.5\u0026rdquo;) Deformable Parts Model The Deformable Parts Model (DPM) (Felzenszwalb et al., 2010) recognizes objects with a mixture graphical model (Markov random fields) of deformable parts. The model consists of three major components:\n A coarse root filter defines a detection window that approximately covers an entire object. A filter specifies weights for a region feature vector. Multiple part filters that cover smaller parts of the object. Parts filters are learned at twice resolution of the root filter. A spatial model for scoring the locations of part filters relative to the root. Fig. 5. The DPM model contains (a) a root filter, (b) multiple part filters at twice the resolution, and (c) a model for scoring the location and deformation of parts. The quality of detecting an object is measured by the score of filters minus the deformation costs. The matching score $f$, in laymen\u0026rsquo;s terms, is:\n $$ f(\\text{model}, x) = f(\\beta_\\text{root}, x) + \\sum_{\\beta_\\text{part} \\in \\text{part filters}} \\max_y [f(\\beta_\\text{part}, y) - \\text{cost}(\\beta_\\text{part}, x, y)] $$ in which,\n $x$ is an image with a specified position and scale; $y$ is a sub region of $x$. $\\beta_\\text{root}$ is the root filter. $\\beta_\\text{part}$ is one part filter. cost() measures the penalty of the part deviating from its ideal location relative to the root. The basic score model is the dot product between the filter $\\beta$ and the region feature vector $\\Phi(x)$: $f(\\beta, x) = \\beta \\cdot \\Phi(x)$. The feature set $\\Phi(x)$ can be defined by HOG or other similar algorithms.\nA root location with high score detects a region with high chances to contain an object, while the locations of the parts with high scores confirm a recognized object hypothesis. The paper adopted latent SVM to model the classifier.\nFig. 6. The matching process by DPM. (Image source: Felzenszwalb et al., 2010) The author later claimed that DPM and CNN models are not two distinct approaches to object recognition. Instead, a DPM model can be formulated as a CNN by unrolling the DPM inference algorithm and mapping each step to an equivalent CNN layer. (Check the details in Girshick et al., 2015!)\nOverfeat Overfeat [paper][code] is a pioneer model of integrating the object detection, localization and classification tasks all into one convolutional neural network. The main idea is to (i) do image classification at different locations on regions of multiple scales of the image in a sliding window fashion, and (ii) predict the bounding box locations with a regressor trained on top of the same convolution layers.\nThe Overfeat model architecture is very similar to AlexNet. It is trained as follows:\nFig. 7. The training stages of the Overfeat model. (Image source: link) Train a CNN model (similar to AlexNet) on the image classification task. Then, we replace the top classifier layers by a regression network and train it to predict object bounding boxes at each spatial location and scale. The regressor is class-specific, each generated for one image class. Input: Images with classification and bounding box. Output: $(x_\\text{left}, x_\\text{right}, y_\\text{top}, y_\\text{bottom})$, 4 values in total, representing the coordinates of the bounding box edges. Loss: The regressor is trained to minimize $l2$ norm between generated bounding box and the ground truth for each training example. At the detection time,\n Perform classification at each location using the pretrained CNN model. Predict object bounding boxes on all classified regions generated by the classifier. Merge bounding boxes with sufficient overlap from localization and sufficient confidence of being the same object from the classifier. Cited as:\n@article{weng2017detection2, title = \u0026quot;Object Detection for Dummies Part 2: CNN, DPM and Overfeat\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-12-15-object-recognition-part-2/\u0026quot; } Reference [1] Vincent Dumoulin and Francesco Visin. \u0026ldquo;A guide to convolution arithmetic for deep learning.\u0026quot; arXiv preprint arXiv:1603.07285 (2016).\n[2] Haohan Wang, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800 (2017).\n[3] Pedro F. Felzenszwalb, Ross B. Girshick, David McAllester, and Deva Ramanan. \u0026ldquo;Object detection with discriminatively trained part-based models.\u0026quot; IEEE transactions on pattern analysis and machine intelligence 32, no. 9 (2010): 1627-1645.\n[4] Ross B. Girshick, Forrest Iandola, Trevor Darrell, and Jitendra Malik. \u0026ldquo;Deformable part models are convolutional neural networks.\u0026quot; In Proc. IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 437-446. 2015.\n[5] Sermanet, Pierre, David Eigen, Xiang Zhang, Michaël Mathieu, Rob Fergus, and Yann LeCun. \u0026ldquo;OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks\u0026rdquo; arXiv preprint arXiv:1312.6229 (2013).\n","permalink":"https://lilianweng.github.io/posts/2017-12-15-object-recognition-part-2/","summary":"Part 1 of the \u0026ldquo;Object Detection for Dummies\u0026rdquo; series introduced: (1) the concept of image gradient vector and how HOG algorithm summarizes the information across all the gradient vectors in one image; (2) how the image segmentation algorithm works to detect regions that potentially contain objects; (3) how the Selective Search algorithm refines the outcomes of image segmentation for better region proposal.\nIn Part 2, we are about to find out more on the classic convolution neural network architectures for image classification.","title":"Object Detection for Dummies Part 2: CNN, DPM and Overfeat"},{"content":"I\u0026rsquo;ve never worked in the field of computer vision and has no idea how the magic could work when an autonomous car is configured to tell apart a stop sign from a pedestrian in a red hat. To motivate myself to look into the maths behind object recognition and detection algorithms, I\u0026rsquo;m writing a few posts on this topic \u0026ldquo;Object Detection for Dummies\u0026rdquo;. This post, part 1, starts with super rudimentary concepts in image processing and a few methods for image segmentation. Nothing related to deep neural networks yet. Deep learning models for object detection and recognition will be discussed in Part 2 and Part 3.\n Disclaimer: When I started, I was using \u0026ldquo;object recognition\u0026rdquo; and \u0026ldquo;object detection\u0026rdquo; interchangeably. I don\u0026rsquo;t think they are the same: the former is more about telling whether an object exists in an image while the latter needs to spot where the object is. However, they are highly related and many object recognition algorithms lay the foundation for detection.\n Links to all the posts in the series: [Part 1] [Part 2] [Part 3] [Part 4].\nImage Gradient Vector First of all, I would like to make sure we can distinguish the following terms. They are very similar, closely related, but not exactly the same.\n Derivative Directional Derivative Gradient Value type Scalar Scalar Vector Definition The rate of change of a function $f(x,y,z,\u0026hellip;)$ at a point $(x_0,y_0,z_0,\u0026hellip;)$, which is the slope of the tangent line at the point. The instantaneous rate of change of $f(x,y,z, \u0026hellip;)$ in the direction of an unit vector $\\vec{u}$. It points in the direction of the greatest rate of increase of the function, containing all the partial derivative information of a multivariable function. In the image processing, we want to know the direction of colors changing from one extreme to the other (i.e. black to white on a grayscale image). Therefore, we want to measure \u0026ldquo;gradient\u0026rdquo; on pixels of colors. The gradient on an image is discrete because each pixel is independent and cannot be further split.\nThe image gradient vector is defined as a metric for every individual pixel, containing the pixel color changes in both x-axis and y-axis. The definition is aligned with the gradient of a continuous multi-variable function, which is a vector of partial derivatives of all the variables. Suppose f(x, y) records the color of the pixel at location (x, y), the gradient vector of the pixel (x, y) is defined as follows:\n $$ \\begin{align*} \\nabla f(x, y) = \\begin{bmatrix} g_x \\\\ g_y \\end{bmatrix} = \\begin{bmatrix} \\frac{\\partial f}{\\partial x} \\\\[6pt] \\frac{\\partial f}{\\partial y} \\end{bmatrix} = \\begin{bmatrix} f(x+1, y) - f(x-1, y)\\\\ f(x, y+1) - f(x, y-1) \\end{bmatrix} \\end{align*} $$ The $\\frac{\\partial f}{\\partial x}$ term is the partial derivative on the x-direction, which is computed as the color difference between the adjacent pixels on the left and right of the target, f(x+1, y) - f(x-1, y). Similarly, the $\\frac{\\partial f}{\\partial y}$ term is the partial derivative on the y-direction, measured as f(x, y+1) - f(x, y-1), the color difference between the adjacent pixels above and below the target.\nThere are two important attributes of an image gradient:\n Magnitude is the L2-norm of the vector, $g = \\sqrt{ g_x^2 + g_y^2 }$. Direction is the arctangent of the ratio between the partial derivatives on two directions, $\\theta = \\arctan{(g_y / g_x)}$. Fig. 1. To compute the gradient vector of a target pixel at location (x, y), we need to know the colors of its four neighbors (or eight surrounding pixels depending on the kernel). The gradient vector of the example in Fig. 1. is:\n $$ \\begin{align*} \\nabla f = \\begin{bmatrix} f(x+1, y) - f(x-1, y)\\\\ f(x, y+1) - f(x, y-1) \\end{bmatrix} = \\begin{bmatrix} 55-105\\\\ 90-40 \\end{bmatrix} = \\begin{bmatrix} -50\\\\ 50 \\end{bmatrix} \\end{align*} $$ Thus,\n the magnitude is $\\sqrt{50^2 + (-50)^2} = 70.7107$, and the direction is $\\arctan{(-50/50)} = -45^{\\circ}$. Repeating the gradient computation process for every pixel iteratively is too slow. Instead, it can be well translated into applying a convolution operator on the entire image matrix, labeled as $\\mathbf{A}$ using one of the specially designed convolutional kernels.\nLet\u0026rsquo;s start with the x-direction of the example in Fig 1. using the kernel $[-1,0,1]$ sliding over the x-axis; $\\ast$ is the convolution operator:\n $$ \\begin{align*} \\mathbf{G}_x \u0026= [-1, 0, 1] \\ast [105, 255, 55] = -105 + 0 + 55 = -50 \\end{align*} $$ Similarly, on the y-direction, we adopt the kernel $[+1, 0, -1]^\\top$:\n $$ \\begin{align*} \\mathbf{G}_y \u0026= [+1, 0, -1]^\\top \\ast \\begin{bmatrix} 90\\\\ 255\\\\ 40 \\end{bmatrix} = 90 + 0 - 40 = 50 \\end{align*} $$ Try this in python:\nimport numpy as np import scipy.signal as sig data = np.array([[0, 105, 0], [40, 255, 90], [0, 55, 0]]) G_x = sig.convolve2d(data, np.array([[-1, 0, 1]]), mode=\u0026#39;valid\u0026#39;) G_y = sig.convolve2d(data, np.array([[-1], [0], [1]]), mode=\u0026#39;valid\u0026#39;) These two functions return array([[0], [-50], [0]]) and array([[0, 50, 0]]) respectively. (Note that in the numpy array representation, 40 is shown in front of 90, so -1 is listed before 1 in the kernel correspondingly.)\nCommon Image Processing Kernels Prewitt operator: Rather than only relying on four directly adjacent neighbors, the Prewitt operator utilizes eight surrounding pixels for smoother results.\n $$ \\mathbf{G}_x = \\begin{bmatrix} -1 \u0026 0 \u0026 +1 \\\\ -1 \u0026 0 \u0026 +1 \\\\ -1 \u0026 0 \u0026 +1 \\end{bmatrix} \\ast \\mathbf{A} \\text{ and } \\mathbf{G}_y = \\begin{bmatrix} +1 \u0026 +1 \u0026 +1 \\\\ 0 \u0026 0 \u0026 0 \\\\ -1 \u0026 -1 \u0026 -1 \\end{bmatrix} \\ast \\mathbf{A} $$ Sobel operator: To emphasize the impact of directly adjacent pixels more, they get assigned with higher weights.\n $$ \\mathbf{G}_x = \\begin{bmatrix} -1 \u0026 0 \u0026 +1 \\\\ -2 \u0026 0 \u0026 +2 \\\\ -1 \u0026 0 \u0026 +1 \\end{bmatrix} \\ast \\mathbf{A} \\text{ and } \\mathbf{G}_y = \\begin{bmatrix} +1 \u0026 +2 \u0026 +1 \\\\ 0 \u0026 0 \u0026 0 \\\\ -1 \u0026 -2 \u0026 -1 \\end{bmatrix} \\ast \\mathbf{A} $$ Different kernels are created for different goals, such as edge detection, blurring, sharpening and many more. Check this wiki page for more examples and references.\nExample: Manu in 2004 Let\u0026rsquo;s run a simple experiment on the photo of Manu Ginobili in 2004 [[Download Image]({{ \u0026lsquo;/assets/data/manu-2004.jpg\u0026rsquo; | relative_url }}){:target=\u0026quot;_blank\u0026quot;}] when he still had a lot of hair. For simplicity, the photo is converted to grayscale first. For colored images, we just need to repeat the same process in each color channel respectively.\nFig. 2. Manu Ginobili in 2004 with hair. (Image source: Manu Ginobili's bald spot through the years) import numpy as np import scipy import scipy.signal as sig # With mode=\u0026#34;L\u0026#34;, we force the image to be parsed in the grayscale, so it is # actually unnecessary to convert the photo color beforehand. img = scipy.misc.imread(\u0026#34;manu-2004.jpg\u0026#34;, mode=\u0026#34;L\u0026#34;) # Define the Sobel operator kernels. kernel_x = np.array([[-1, 0, 1],[-2, 0, 2],[-1, 0, 1]]) kernel_y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) G_x = sig.convolve2d(img, kernel_x, mode=\u0026#39;same\u0026#39;) G_y = sig.convolve2d(img, kernel_y, mode=\u0026#39;same\u0026#39;) # Plot them! fig = plt.figure() ax1 = fig.add_subplot(121) ax2 = fig.add_subplot(122) # Actually plt.imshow() can handle the value scale well even if I don\u0026#39;t do # the transformation (G_x + 255) / 2. ax1.imshow((G_x + 255) / 2, cmap=\u0026#39;gray\u0026#39;); ax1.set_xlabel(\u0026#34;Gx\u0026#34;) ax2.imshow((G_y + 255) / 2, cmap=\u0026#39;gray\u0026#39;); ax2.set_xlabel(\u0026#34;Gy\u0026#34;) plt.show() Fig. 3. Apply Sobel operator kernel on the example image. You might notice that most area is in gray. Because the difference between two pixel is between -255 and 255 and we need to convert them back to [0, 255] for the display purpose. A simple linear transformation ($\\mathbf{G}$ + 255)/2 would interpret all the zeros (i.e., constant colored background shows no change in gradient) as 125 (shown as gray).\nHistogram of Oriented Gradients (HOG) The Histogram of Oriented Gradients (HOG) is an efficient way to extract features out of the pixel colors for building an object recognition classifier. With the knowledge of image gradient vectors, it is not hard to understand how HOG works. Let\u0026rsquo;s start!\nHow HOG works Preprocess the image, including resizing and color normalization.\n Compute the gradient vector of every pixel, as well as its magnitude and direction.\n Divide the image into many 8x8 pixel cells. In each cell, the magnitude values of these 64 cells are binned and cumulatively added into 9 buckets of unsigned direction (no sign, so 0-180 degree rather than 0-360 degree; this is a practical choice based on empirical experiments). For better robustness, if the direction of the gradient vector of a pixel lays between two buckets, its magnitude does not all go into the closer one but proportionally split between two. For example, if a pixel\u0026rsquo;s gradient vector has magnitude 8 and degree 15, it is between two buckets for degree 0 and 20 and we would assign 2 to bucket 0 and 6 to bucket 20. This interesting configuration makes the histogram much more stable when small distortion is applied to the image.\n Fig. 4. How to split one gradient vector's magnitude if its degress is between two degree bins. (Image source: https://www.learnopencv.com/histogram-of-oriented-gradients/) Then we slide a 2x2 cells (thus 16x16 pixels) block across the image. In each block region, 4 histograms of 4 cells are concatenated into one-dimensional vector of 36 values and then normalized to have an unit weight. The final HOG feature vector is the concatenation of all the block vectors. It can be fed into a classifier like SVM for learning object recognition tasks. Example: Manu in 2004 Let\u0026rsquo;s reuse the same example image in the previous section. Remember that we have computed $\\mathbf{G}_x$ and $\\mathbf{G}_y$ for the whole image.\nN_BUCKETS = 9 CELL_SIZE = 8 # Each cell is 8x8 pixels BLOCK_SIZE = 2 # Each block is 2x2 cells def assign_bucket_vals(m, d, bucket_vals): left_bin = int(d / 20.) # Handle the case when the direction is between [160, 180) right_bin = (int(d / 20.) + 1) % N_BUCKETS assert 0 \u0026lt;= left_bin \u0026lt; right_bin \u0026lt; N_BUCKETS left_val= m * (right_bin * 20 - d) / 20 right_val = m * (d - left_bin * 20) / 20 bucket_vals[left_bin] += left_val bucket_vals[right_bin] += right_val def get_magnitude_hist_cell(loc_x, loc_y): # (loc_x, loc_y) defines the top left corner of the target cell. cell_x = G_x[loc_x:loc_x + CELL_SIZE, loc_y:loc_y + CELL_SIZE] cell_y = G_y[loc_x:loc_x + CELL_SIZE, loc_y:loc_y + CELL_SIZE] magnitudes = np.sqrt(cell_x * cell_x + cell_y * cell_y) directions = np.abs(np.arctan(cell_y / cell_x) * 180 / np.pi) buckets = np.linspace(0, 180, N_BUCKETS + 1) bucket_vals = np.zeros(N_BUCKETS) map( lambda (m, d): assign_bucket_vals(m, d, bucket_vals), zip(magnitudes.flatten(), directions.flatten()) ) return bucket_vals def get_magnitude_hist_block(loc_x, loc_y): # (loc_x, loc_y) defines the top left corner of the target block. return reduce( lambda arr1, arr2: np.concatenate((arr1, arr2)), [get_magnitude_hist_cell(x, y) for x, y in zip( [loc_x, loc_x + CELL_SIZE, loc_x, loc_x + CELL_SIZE], [loc_y, loc_y, loc_y + CELL_SIZE, loc_y + CELL_SIZE], )] ) The following code simply calls the functions to construct a histogram and plot it.\n# Random location [200, 200] as an example. loc_x = loc_y = 200 ydata = get_magnitude_hist_block(loc_x, loc_y) ydata = ydata / np.linalg.norm(ydata) xdata = range(len(ydata)) bucket_names = np.tile(np.arange(N_BUCKETS), BLOCK_SIZE * BLOCK_SIZE) assert len(ydata) == N_BUCKETS * (BLOCK_SIZE * BLOCK_SIZE) assert len(bucket_names) == len(ydata) plt.figure(figsize=(10, 3)) plt.bar(xdata, ydata, align=\u0026#39;center\u0026#39;, alpha=0.8, width=0.9) plt.xticks(xdata, bucket_names * 20, rotation=90) plt.xlabel(\u0026#39;Direction buckets\u0026#39;) plt.ylabel(\u0026#39;Magnitude\u0026#39;) plt.grid(ls=\u0026#39;--\u0026#39;, color=\u0026#39;k\u0026#39;, alpha=0.1) plt.title(\u0026#34;HOG of block at [%d, %d]\u0026#34; % (loc_x, loc_y)) plt.tight_layout() In the code above, I use the block with top left corner located at [200, 200] as an example and here is the final normalized histogram of this block. You can play with the code to change the block location to be identified by a sliding window.\nFig. 5. Demonstration of a HOG histogram for one block. The code is mostly for demonstrating the computation process. There are many off-the-shelf libraries with HOG algorithm implemented, such as OpenCV, SimpleCV and scikit-image.\nImage Segmentation (Felzenszwalb\u0026rsquo;s Algorithm) When there exist multiple objects in one image (true for almost every real-world photos), we need to identify a region that potentially contains a target object so that the classification can be executed more efficiently.\nFelzenszwalb and Huttenlocher (2004) proposed an algorithm for segmenting an image into similar regions using a graph-based approach. It is also the initialization method for Selective Search (a popular region proposal algorithm) that we are gonna discuss later.\nSay, we use a undirected graph $G=(V, E)$ to represent an input image. One vertex $v_i \\in V$ represents one pixel. One edge $e = (v_i, v_j) \\in E$ connects two vertices $v_i$ and $v_j$. Its associated weight $w(v_i, v_j)$ measures the dissimilarity between $v_i$ and $v_j$. The dissimilarity can be quantified in dimensions like color, location, intensity, etc. The higher the weight, the less similar two pixels are. A segmentation solution $S$ is a partition of $V$ into multiple connected components, $\\{C\\}$. Intuitively similar pixels should belong to the same components while dissimilar ones are assigned to different components.\nGraph Construction There are two approaches to constructing a graph out of an image.\n Grid Graph: Each pixel is only connected with surrounding neighbours (8 other cells in total). The edge weight is the absolute difference between the intensity values of the pixels. Nearest Neighbor Graph: Each pixel is a point in the feature space (x, y, r, g, b), in which (x, y) is the pixel location and (r, g, b) is the color values in RGB. The weight is the Euclidean distance between two pixels' feature vectors. Key Concepts Before we lay down the criteria for a good graph partition (aka image segmentation), let us define a couple of key concepts:\n Internal difference: $Int(C) = \\max_{e\\in MST(C, E)} w(e)$, where $MST$ is the minimum spanning tree of the components. A component $C$ can still remain connected even when we have removed all the edges with weights \u0026lt; $Int(C)$. Difference between two components: $Dif(C_1, C_2) = \\min_{v_i \\in C_1, v_j \\in C_2, (v_i, v_j) \\in E} w(v_i, v_j)$. $Dif(C_1, C_2) = \\infty$ if there is no edge in-between. Minimum internal difference: $MInt(C_1, C_2) = min(Int(C_1) + \\tau(C_1), Int(C_2) + \\tau(C_2))$, where $\\tau(C) = k / \\vert C \\vert$ helps make sure we have a meaningful threshold for the difference between components. With a higher $k$, it is more likely to result in larger components. The quality of a segmentation is assessed by a pairwise region comparison predicate defined for given two regions $C_1$ and $C_2$:\n $$ D(C_1, C_2) = \\begin{cases} \\text{True} \u0026 \\text{ if } Dif(C_1, C_2) MInt(C_1, C_2) \\\\ \\text{False} \u0026 \\text{ otherwise} \\end{cases} $$ Only when the predicate holds True, we consider them as two independent components; otherwise the segmentation is too fine and they probably should be merged.\nHow Image Segmentation Works The algorithm follows a bottom-up procedure. Given $G=(V, E)$ and $|V|=n, |E|=m$:\n Edges are sorted by weight in ascending order, labeled as $e_1, e_2, \\dots, e_m$. Initially, each pixel stays in its own component, so we start with $n$ components. Repeat for $k=1, \\dots, m$: The segmentation snapshot at the step $k$ is denoted as $S^k$. We take the k-th edge in the order, $e_k = (v_i, v_j)$. If $v_i$ and $v_j$ belong to the same component, do nothing and thus $S^k = S^{k-1}$. If $v_i$ and $v_j$ belong to two different components $C_i^{k-1}$ and $C_j^{k-1}$ as in the segmentation $S^{k-1}$, we want to merge them into one if $w(v_i, v_j) \\leq MInt(C_i^{k-1}, C_j^{k-1})$; otherwise do nothing. If you are interested in the proof of the segmentation properties and why it always exists, please refer to the paper.\nFig. 6. An indoor scene with segmentation detected by the grid graph construction in Felzenszwalb's graph-based segmentation algorithm (k=300). Example: Manu in 2013 This time I would use the photo of old Manu Ginobili in 2013 [[Image]({{ \u0026lsquo;/assets/data/manu-2013.jpg\u0026rsquo; | relative_url }})] as the example image when his bald spot has grown up strong. Still for simplicity, we use the picture in grayscale.\nFig. 7. Manu Ginobili in 2013 with bald spot. (Image source: Manu Ginobili's bald spot through the years) Rather than coding from scratch, let us apply skimage.segmentation.felzenszwalb to the image.\nimport skimage.segmentation from matplotlib import pyplot as plt img2 = scipy.misc.imread(\u0026#34;manu-2013.jpg\u0026#34;, mode=\u0026#34;L\u0026#34;) segment_mask1 = skimage.segmentation.felzenszwalb(img2, scale=100) segment_mask2 = skimage.segmentation.felzenszwalb(img2, scale=1000) fig = plt.figure(figsize=(12, 5)) ax1 = fig.add_subplot(121) ax2 = fig.add_subplot(122) ax1.imshow(segment_mask1); ax1.set_xlabel(\u0026#34;k=100\u0026#34;) ax2.imshow(segment_mask2); ax2.set_xlabel(\u0026#34;k=1000\u0026#34;) fig.suptitle(\u0026#34;Felsenszwalb\u0026#39;s efficient graph based image segmentation\u0026#34;) plt.tight_layout() plt.show() The code ran two versions of Felzenszwalb\u0026rsquo;s algorithms as shown in Fig. 8. The left k=100 generates a finer-grained segmentation with small regions where Manu\u0026rsquo;s bald spot is identified. The right one k=1000 outputs a coarser-grained segmentation where regions tend to be larger.\nFig. 8. Felsenszwalb's efficient graph-based image segmentation is applied on the photo of Manu in 2013. Selective Search Selective search is a common algorithm to provide region proposals that potentially contain objects. It is built on top of the image segmentation output and use region-based characteristics (NOTE: not just attributes of a single pixel) to do a bottom-up hierarchical grouping.\nHow Selective Search Works At the initialization stage, apply Felzenszwalb and Huttenlocher\u0026rsquo;s graph-based image segmentation algorithm to create regions to start with. Use a greedy algorithm to iteratively group regions together: First the similarities between all neighbouring regions are calculated. The two most similar regions are grouped together, and new similarities are calculated between the resulting region and its neighbours. The process of grouping the most similar regions (Step 2) is repeated until the whole image becomes a single region. Fig. 9. The detailed algorithm of Selective Search. Configuration Variations Given two regions $(r_i, r_j)$, selective search proposed four complementary similarity measures:\n Color similarity Texture: Use algorithm that works well for material recognition such as SIFT. Size: Small regions are encouraged to merge early. Shape: Ideally one region can fill the gap of the other. By (i) tuning the threshold $k$ in Felzenszwalb and Huttenlocher\u0026rsquo;s algorithm, (ii) changing the color space and (iii) picking different combinations of similarity metrics, we can produce a diverse set of Selective Search strategies. The version that produces the region proposals with best quality is configured with (i) a mixture of various initial segmentation proposals, (ii) a blend of multiple color spaces and (iii) a combination of all similarity measures. Unsurprisingly we need to balance between the quality (the model complexity) and the speed.\n Cited as:\n@article{weng2017detection1, title = \u0026quot;Object Detection for Dummies Part 1: Gradient Vector, HOG, and SS\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-10-29-object-recognition-part-1/\u0026quot; } References [1] Dalal, Navneet, and Bill Triggs. \u0026ldquo;Histograms of oriented gradients for human detection.\u0026quot; Computer Vision and Pattern Recognition (CVPR), 2005.\n[2] Pedro F. Felzenszwalb, and Daniel P. Huttenlocher. \u0026ldquo;Efficient graph-based image segmentation.\u0026quot; Intl. journal of computer vision 59.2 (2004): 167-181.\n[3] Histogram of Oriented Gradients by Satya Mallick\n[4] Gradient Vectors by Chris McCormick\n[5] HOG Person Detector Tutorial by Chris McCormick\n","permalink":"https://lilianweng.github.io/posts/2017-10-29-object-recognition-part-1/","summary":"I\u0026rsquo;ve never worked in the field of computer vision and has no idea how the magic could work when an autonomous car is configured to tell apart a stop sign from a pedestrian in a red hat. To motivate myself to look into the maths behind object recognition and detection algorithms, I\u0026rsquo;m writing a few posts on this topic \u0026ldquo;Object Detection for Dummies\u0026rdquo;. This post, part 1, starts with super rudimentary concepts in image processing and a few methods for image segmentation.","title":"Object Detection for Dummies Part 1: Gradient Vector, HOG, and SS"},{"content":"Human vocabulary comes in free text. In order to make a machine learning model understand and process the natural language, we need to transform the free-text words into numeric values. One of the simplest transformation approaches is to do a one-hot encoding in which each distinct word stands for one dimension of the resulting vector and a binary value indicates whether the word presents (1) or not (0).\nHowever, one-hot encoding is impractical computationally when dealing with the entire vocabulary, as the representation demands hundreds of thousands of dimensions. Word embedding represents words and phrases in vectors of (non-binary) numeric values with much lower and thus denser dimensions. An intuitive assumption for good word embedding is that they can approximate the similarity between words (i.e., \u0026ldquo;cat\u0026rdquo; and \u0026ldquo;kitten\u0026rdquo; are similar words, and thus they are expected to be close in the reduced vector space) or disclose hidden semantic relationships (i.e., the relationship between \u0026ldquo;cat\u0026rdquo; and \u0026ldquo;kitten\u0026rdquo; is an analogy to the one between \u0026ldquo;dog\u0026rdquo; and \u0026ldquo;puppy\u0026rdquo;). Contextual information is super useful for learning word meaning and relationship, as similar words may appear in the similar context often.\nThere are two main approaches for learning word embedding, both relying on the contextual knowledge.\n Count-based: The first one is unsupervised, based on matrix factorization of a global word co-occurrence matrix. Raw co-occurrence counts do not work well, so we want to do smart things on top. Context-based: The second approach is supervised. Given a local context, we want to design a model to predict the target words and in the meantime, this model learns the efficient word embedding representation. Count-Based Vector Space Model Count-based vector space models heavily rely on the word frequency and co-occurrence matrix with the assumption that words in the same contexts share similar or related semantic meanings. The models map count-based statistics like co-occurrences between neighboring words down to a small and dense word vectors. PCA, topic models, and neural probabilistic language models are all good examples of this category.\n Different from the count-based approaches, context-based methods build predictive models that directly target at predicting a word given its neighbors. The dense word vectors are part of the model parameters. The best vector representation of each word is learned during the model training process.\nContext-Based: Skip-Gram Model Suppose that you have a sliding window of a fixed size moving along a sentence: the word in the middle is the \u0026ldquo;target\u0026rdquo; and those on its left and right within the sliding window are the context words. The skip-gram model (Mikolov et al., 2013) is trained to predict the probabilities of a word being a context word for the given target.\nThe following example demonstrates multiple pairs of target and context words as training samples, generated by a 5-word window sliding along the sentence.\n \u0026ldquo;The man who passes the sentence should swing the sword.\u0026rdquo; \u0026ndash; Ned Stark\n Sliding window (size = 5) Target word Context [The man who] the man, who [The man who passes] man the, who, passes [The man who passes the] who the, man, passes, the [man who passes the sentence] passes man, who, the, sentence \u0026hellip; \u0026hellip; \u0026hellip; [sentence should swing the sword] swing sentence, should, the, sword [should swing the sword] the should, swing, sword [swing the sword] sword swing, the {:.info} Each context-target pair is treated as a new observation in the data. For example, the target word \u0026ldquo;swing\u0026rdquo; in the above case produces four training samples: (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;sentence\u0026rdquo;), (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;should\u0026rdquo;), (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;the\u0026rdquo;), and (\u0026ldquo;swing\u0026rdquo;, \u0026ldquo;sword\u0026rdquo;).\nFig. 1. The skip-gram model. Both the input vector $\\mathbf{x}$ and the output $\\mathbf{y}$ are one-hot encoded word representations. The hidden layer is the word embedding of size $N$. Given the vocabulary size $V$, we are about to learn word embedding vectors of size $N$. The model learns to predict one context word (output) using one target word (input) at a time.\nAccording to Fig. 1,\n Both input word $w_i$ and the output word $w_j$ are one-hot encoded into binary vectors $\\mathbf{x}$ and $\\mathbf{y}$ of size $V$. First, the multiplication of the binary vector $\\mathbf{x}$ and the word embedding matrix $W$ of size $V \\times N$ gives us the embedding vector of the input word $w_i$: the i-th row of the matrix $W$. This newly discovered embedding vector of dimension $N$ forms the hidden layer. The multiplication of the hidden layer and the word context matrix $W’$ of size $N \\times V$ produces the output one-hot encoded vector $\\mathbf{y}$. The output context matrix $W’$ encodes the meanings of words as context, different from the embedding matrix $W$. NOTE: Despite the name, $W’$ is independent of $W$, not a transpose or inverse or whatsoever. Context-Based: Continuous Bag-of-Words (CBOW) The Continuous Bag-of-Words (CBOW) is another similar model for learning word vectors. It predicts the target word (i.e. \u0026ldquo;swing\u0026rdquo;) from source context words (i.e., \u0026ldquo;sentence should the sword\u0026rdquo;).\nFig. 2. The CBOW model. Word vectors of multiple context words are averaged to get a fixed-length vector as in the hidden layer. Other symbols have the same meanings as in Fig 1. Because there are multiple contextual words, we average their corresponding word vectors, constructed by the multiplication of the input vector and the matrix $W$. Because the averaging stage smoothes over a lot of the distributional information, some people believe the CBOW model is better for small dataset.\nLoss Functions Both the skip-gram model and the CBOW model should be trained to minimize a well-designed loss/objective function. There are several loss functions we can incorporate to train these language models. In the following discussion, we will use the skip-gram model as an example to describe how the loss is computed.\nFull Softmax The skip-gram model defines the embedding vector of every word by the matrix $W$ and the context vector by the output matrix $W'$. Given an input word $w_I$, let us label the corresponding row of $W$ as vector $v_{w_I}$ (embedding vector) and its corresponding column of $W'$ as $v'_{w_I}$ (context vector). The final output layer applies softmax to compute the probability of predicting the output word $w_O$ given $w_I$, and therefore:\n $$ p(w_O \\vert w_I) = \\frac{\\exp({v'_{w_O}}^{\\top} v_{w_I})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top} v_{w_I})} $$ This is accurate as presented in Fig. 1. However, when $V$ is extremely large, calculating the denominator by going through all the words for every single sample is computationally impractical. The demand for more efficient conditional probability estimation leads to the new methods like hierarchical softmax.\nHierarchical Softmax Morin and Bengio (2005) proposed hierarchical softmax to make the sum calculation faster with the help of a binary tree structure. The hierarchical softmax encodes the language model\u0026rsquo;s output softmax layer into a tree hierarchy, where each leaf is one word and each internal node stands for relative probabilities of the children nodes.\nFig. 3. An illustration of the hierarchical softmax binary tree. The leaf nodes in white are words in the vocabulary. The gray inner nodes carry information on the probabilities of reaching its child nodes. One path starting from the root to the leaf $w\\_i$. $n(w\\_i, j)$ denotes the j-th node on this path. (Image source: word2vec Parameter Learning Explained) Each word $w_i$ has a unique path from the root down to its corresponding leaf. The probability of picking this word is equivalent to the probability of taking this path from the root down through the tree branches. Since we know the embedding vector $v_n$ of the internal node $n$, the probability of getting the word can be computed by the product of taking left or right turn at every internal node stop.\nAccording to Fig. 3, the probability of one node is ($\\sigma$ is the sigmoid function):\n $$ \\begin{align} p(\\text{turn right} \\to \\dots w_I \\vert n) \u0026= \\sigma({v'_n}^{\\top} v_{w_I})\\\\ p(\\text{turn left } \\to \\dots w_I \\vert n) \u0026= 1 - p(\\text{turn right} \\vert n) = \\sigma(-{v'_n}^{\\top} v_{w_I}) \\end{align} $$ The final probability of getting a context word $w_O$ given an input word $w_I$ is:\n $$ p(w_O \\vert w_I) = \\prod_{k=1}^{L(w_O)} \\sigma(\\mathbb{I}_{\\text{turn}}(n(w_O, k), n(w_O, k+1)) \\cdot {v'_{n(w_O, k)}}^{\\top} v_{w_I}) $$ where $L(w_O)$ is the depth of the path leading to the word $w_O$ and $\\mathbb{I}_{\\text{turn}}$ is a specially indicator function which returns 1 if $n(w_O, k+1)$ is the left child of $n(w_O, k)$ otherwise -1. The internal nodes' embeddings are learned during the model training. The tree structure helps greatly reduce the complexity of the denominator estimation from O(V) (vocabulary size) to O(log V) (the depth of the tree) at the training time. However, at the prediction time, we still to compute the probability of every word and pick the best, as we don\u0026rsquo;t know which leaf to reach for in advance.\nA good tree structure is crucial to the model performance. Several handy principles are: group words by frequency like what is implemented by Huffman tree for simple speedup; group similar words into same or close branches (i.e. use predefined word clusters, WordNet).\nCross Entropy Another approach completely steers away from the softmax framework. Instead, the loss function measures the cross entropy between the predicted probabilities $p$ and the true binary labels $\\mathbf{y}$.\nFirst, let\u0026rsquo;s recall that the cross entropy between two distributions $p$ and $q$ is measured as $ H(p, q) = -\\sum_x p(x) \\log q(x) $. In our case, the true label $y_i$ is 1 only when $w_i$ is the output word; $y_j$ is 0 otherwise. The loss function $\\mathcal{L}_\\theta$ of the model with parameter config $\\theta$ aims to minimize the cross entropy between the prediction and the ground truth, as lower cross entropy indicates high similarity between two distributions.\n $$ \\mathcal{L}_\\theta = - \\sum_{i=1}^V y_i \\log p(w_i | w_I) = - \\log p(w_O \\vert w_I) $$ Recall that,\n $$ p(w_O \\vert w_I) = \\frac{\\exp({v'_{w_O}}^{\\top} v_{w_I})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top} v_{w_I})} $$ Therefore,\n $$ \\mathcal{L}_{\\theta} = - \\log \\frac{\\exp({v'_{w_O}}^{\\top}{v_{w_I}})}{\\sum_{i=1}^V \\exp({v'_{w_i}}^{\\top}{v_{w_I} })} = - {v'_{w_O}}^{\\top}{v_{w_I} } + \\log \\sum_{i=1}^V \\exp({v'_{w_i} }^{\\top}{v_{w_I}}) $$ To start training the model using back-propagation with SGD, we need to compute the gradient of the loss function. For simplicity, let\u0026rsquo;s label $z_{IO} = {v'_{w_O}}^{\\top}{v_{w_I}}$.\n $$ \\begin{align} \\nabla_\\theta \\mathcal{L}_{\\theta} \u0026= \\nabla_\\theta\\big( - z_{IO} + \\log \\sum_{i=1}^V e^{z_{Ii}} \\big) \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\nabla_\\theta \\big( \\log \\sum_{i=1}^V e^{z_{Ii}} \\big) \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\frac{1}{\\sum_{i=1}^V e^{z_{Ii}}} \\sum_{i=1}^V e^{z_{Ii}} \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\sum_{i=1}^V \\frac{e^{z_{Ii}}}{\\sum_{i=1}^V e^{z_{Ii}}} \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\sum_{i=1}^V p(w_i \\vert w_I) \\nabla_\\theta z_{Ii} \\\\ \u0026= - \\nabla_\\theta z_{IO} + \\mathbb{E}_{w_i \\sim Q(\\tilde{w})} \\nabla_\\theta z_{Ii} \\end{align} $$ where $Q(\\tilde{w})$ is the distribution of noise samples.\nAccording to the formula above, the correct output word has a positive reinforcement according to the first term (the larger $\\nabla_\\theta z_{IO}$ the better loss we have), while other words have a negative impact as captured by the second term.\nHow to estimate $\\mathbb{E}_{w_i \\sim Q(\\tilde{w})} \\nabla_\\theta {v'_{w_i}}^{\\top}{v_{w_I}}$ with a sample set of noise words rather than scanning through the entire vocabulary is the key of using cross-entropy-based sampling approach.\nNoise Contrastive Estimation (NCE) The Noise Contrastive Estimation (NCE) metric intends to differentiate the target word from noise samples using a logistic regression classifier (Gutmann and Hyvärinen, 2010).\nGiven an input word $w_I$, the correct output word is known as $w$. In the meantime, we sample $N$ other words from the noise sample distribution $Q$, denoted as $\\tilde{w}_1, \\tilde{w}_2, \\dots, \\tilde{w}_N \\sim Q$. Let\u0026rsquo;s label the decision of the binary classifier as $d$ and $d$$ can only take a binary value.\n $$ \\mathcal{L}_\\theta = - [ \\log p(d=1 \\vert w, w_I) + \\sum_{i=1, \\tilde{w}_i \\sim Q}^N \\log p(d=0|\\tilde{w}_i, w_I) ] $$ When $N$ is big enough, according to the Law of large numbers,\n $$ \\mathcal{L}_\\theta = - [ \\log p(d=1 \\vert w, w_I) + N\\mathbb{E}_{\\tilde{w}_i \\sim Q} \\log p(d=0|\\tilde{w}_i, w_I)] $$ To compute the probability $p(d=1 \\vert w, w_I)$, we can start with the joint probability $p(d, w \\vert w_I)$. Among $w, \\tilde{w}_1, \\tilde{w}_2, \\dots, \\tilde{w}_N$, we have 1 out of (N+1) chance to pick the true word $w$, which is sampled from the conditional probability $p(w \\vert w_I)$; meanwhile, we have N out of (N+1) chances to pick a noise word, each sampled from $q(\\tilde{w}) \\sim Q$. Thus,\n $$ p(d, w | w_I) = \\begin{cases} \\frac{1}{N+1} p(w \\vert w_I) \u0026 \\text{if } d=1 \\\\ \\frac{N}{N+1} q(\\tilde{w}) \u0026 \\text{if } d=0 \\end{cases} $$ Then we can figure out $p(d=1 \\vert w, w_I)$ and $p(d=0 \\vert w, w_I)$:\n $$ \\begin{align} p(d=1 \\vert w, w_I) \u0026= \\frac{p(d=1, w \\vert w_I)}{p(d=1, w \\vert w_I) + p(d=0, w \\vert w_I)} \u0026= \\frac{p(w \\vert w_I)}{p(w \\vert w_I) + Nq(\\tilde{w})} \\end{align} $$ $$ \\begin{align} p(d=0 \\vert w, w_I) \u0026= \\frac{p(d=0, w \\vert w_I)}{p(d=1, w \\vert w_I) + p(d=0, w \\vert w_I)} \u0026= \\frac{Nq(\\tilde{w})}{p(w \\vert w_I) + Nq(\\tilde{w})} \\end{align} $$ Finally the loss function of NCE\u0026rsquo;s binary classifier becomes:\n $$ \\begin{align} \\mathcal{L}_\\theta \u0026 = - [ \\log p(d=1 \\vert w, w_I) + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log p(d=0|\\tilde{w}_i, w_I)] \\\\ \u0026 = - [ \\log \\frac{p(w \\vert w_I)}{p(w \\vert w_I) + Nq(\\tilde{w})} + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\frac{Nq(\\tilde{w}_i)}{p(w \\vert w_I) + Nq(\\tilde{w}_i)}] \\end{align} $$ However, $p(w \\vert w_I)$ still involves summing up the entire vocabulary in the denominator. Let’s label the denominator as a partition function of the input word, $Z(w_I)$. A common assumption is $Z(w) \\approx 1$ given that we expect the softmax output layer to be normalized (Minh and Teh, 2012). Then the loss function is simplified to:\n $$ \\mathcal{L}_\\theta = - [ \\log \\frac{\\exp({v'_w}^{\\top}{v_{w_I}})}{\\exp({v'_w}^{\\top}{v_{w_I}}) + Nq(\\tilde{w})} + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\frac{Nq(\\tilde{w}_i)}{\\exp({v'_w}^{\\top}{v_{w_I}}) + Nq(\\tilde{w}_i)}] $$ The noise distribution $Q$ is a tunable parameter and we would like to design it in a way so that:\n intuitively it should be very similar to the real data distribution; and it should be easy to sample from. For example, the sampling implementation (log_uniform_candidate_sampler) of NCE loss in tensorflow assumes that such noise samples follow a log-uniform distribution, also known as Zipfian’s law. The probability of a given word in logarithm is expected to be reversely proportional to its rank, while high-frequency words are assigned with lower ranks. In this case, $q(\\tilde{w}) = \\frac{1}{ \\log V}(\\log (r_{\\tilde{w}} + 1) - \\log r_{\\tilde{w}})$, where $r_{\\tilde{w}} \\in [1, V]$ is the rank of a word by frequency in descending order.\nNegative Sampling (NEG) The Negative Sampling (NEG) proposed by Mikolov et al. (2013) is a simplified variation of NCE loss. It is especially famous for training Google\u0026rsquo;s word2vec project. Different from NCE Loss which attempts to approximately maximize the log probability of the softmax output, negative sampling did further simplification because it focuses on learning high-quality word embedding rather than modeling the word distribution in natural language.\nNEG approximates the binary classifier\u0026rsquo;s output with sigmoid functions as follows:\n $$ \\begin{align} p(d=1 \\vert w_, w_I) \u0026= \\sigma({v'_{w}}^\\top v_{w_I}) \\\\ p(d=0 \\vert w, w_I) \u0026= 1 - \\sigma({v'_{w}}^\\top v_{w_I}) = \\sigma(-{v'_{w}}^\\top v_{w_I}) \\end{align} $$ The final NCE loss function looks like:\n $$ \\mathcal{L}_\\theta = - [ \\log \\sigma({v'_{w}}^\\top v_{w_I}) + \\sum_{\\substack{i=1 \\\\ \\tilde{w}_i \\sim Q}}^N \\log \\sigma(-{v'_{\\tilde{w}_i}}^\\top v_{w_I})] $$ Other Tips for Learning Word Embedding Mikolov et al. (2013) suggested several helpful practices that could result in good word embedding learning outcomes.\n Soft sliding window. When pairing the words within the sliding window, we could assign less weight to more distant words. One heuristic is \u0026mdash; given a maximum window size parameter defined, $s_{\\text{max}}$, the actual window size is randomly sampled between 1 and $s_{\\text{max}}$ for every training sample. Thus, each context word has the probability of 1/(its distance to the target word) being observed, while the adjacent words are always observed.\n Subsampling frequent words. Extremely frequent words might be too general to differentiate the context (i.e. think about stopwords). While on the other hand, rare words are more likely to carry distinct information. To balance the frequent and rare words, Mikolov et al. proposed to discard words $w$ with probability $1-\\sqrt{t/f(w)}$ during sampling. Here $f(w)$ is the word frequency and $t$ is an adjustable threshold.\n Learning phrases first. A phrase often stands as a conceptual unit, rather than a simple composition of individual words. For example, we cannot really tell \u0026ldquo;New York\u0026rdquo; is a city name even we know the meanings of \u0026ldquo;new\u0026rdquo; and \u0026ldquo;york\u0026rdquo;. Learning such phrases first and treating them as word units before training the word embedding model improves the outcome quality. A simple data-driven approach is based on unigram and bigram counts: $s_{\\text{phrase}} = \\frac{C(w_i w_j) - \\delta}{ C(w_i)C(w_j)}$, where $C(.)$ is simple count of an unigram $w_i$ or bigram $w_i w_j$ and $\\delta$ is a discounting threshold to prevent super infrequent words and phrases. Higher scores indicate higher chances of being phrases. To form phrases longer than two words, we can scan the vocabulary multiple times with decreasing score cutoff values.\n GloVe: Global Vectors The Global Vector (GloVe) model proposed by Pennington et al. (2014) aims to combine the count-based matrix factorization and the context-based skip-gram model together.\nWe all know the counts and co-occurrences can reveal the meanings of words. To distinguish from $p(w_O \\vert w_I)$ in the context of a word embedding word, we would like to define the co-ocurrence probability as:\n $$ p_{\\text{co}}(w_k \\vert w_i) = \\frac{C(w_i, w_k)}{C(w_i)} $$ $C(w_i, w_k)$ counts the co-occurrence between words $w_i$ and $w_k$.\nSay, we have two words, $w_i$=\u0026ldquo;ice\u0026rdquo; and $w_j$=\u0026ldquo;steam\u0026rdquo;. The third word $\\tilde{w}_k$=\u0026ldquo;solid\u0026rdquo; is related to \u0026ldquo;ice\u0026rdquo; but not \u0026ldquo;steam\u0026rdquo;, and thus we expect $p_{\\text{co}}(\\tilde{w}_k \\vert w_i)$ to be much larger than $p_{\\text{co}}(\\tilde{w}_k \\vert w_j)$ and therefore $\\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)}$ to be very large. If the third word $\\tilde{w}_k$ = \u0026ldquo;water\u0026rdquo; is related to both or $\\tilde{w}_k$ = \u0026ldquo;fashion\u0026rdquo; is unrelated to either of them, $\\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)}$ is expected to be close to one.\nThe intuition here is that the word meanings are captured by the ratios of co-occurrence probabilities rather than the probabilities themselves. The global vector models the relationship between two words regarding to the third context word as:\n $$ F(w_i, w_j, \\tilde{w}_k) = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} $$ Further, since the goal is to learn meaningful word vectors, $F$ is designed to be a function of the linear difference between two words $w_i - w_j$:\n $$ F((w_i - w_j)^\\top \\tilde{w}_k) = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} $$ With the consideration of $F$ being symmetric between target words and context words, the final solution is to model $F$ as an exponential function. Please read the original paper (Pennington et al., 2014) for more details of the equations.\n $$ \\begin{align} F({w_i}^\\top \\tilde{w}_k) \u0026= \\exp({w_i}^\\top \\tilde{w}_k) = p_{\\text{co}}(\\tilde{w}_k \\vert w_i) \\\\ F((w_i - w_j)^\\top \\tilde{w}_k) \u0026= \\exp((w_i - w_j)^\\top \\tilde{w}_k) = \\frac{\\exp(w_i^\\top \\tilde{w}_k)}{\\exp(w_j^\\top \\tilde{w}_k)} = \\frac{p_{\\text{co}}(\\tilde{w}_k \\vert w_i)}{p_{\\text{co}}(\\tilde{w}_k \\vert w_j)} \\end{align} $$ Finally,\n $$ {w_i}^\\top \\tilde{w}_k = \\log p_{\\text{co}}(\\tilde{w}_k \\vert w_i) = \\log \\frac{C(w_i, \\tilde{w}_k)}{C(w_i)} = \\log C(w_i, \\tilde{w}_k) - \\log C(w_i) $$ Since the second term $-\\log C(w_i)$ is independent of $k$, we can add bias term $b_i$ for $w_i$ to capture $-\\log C(w_i)$. To keep the symmetric form, we also add in a bias $\\tilde{b}_k$ for $\\tilde{w}_k$.\n $$ \\log C(w_i, \\tilde{w}_k) = {w_i}^\\top \\tilde{w}_k + b_i + \\tilde{b}_k $$ The loss function for the GloVe model is designed to preserve the above formula by minimizing the sum of the squared errors:\n $$ \\mathcal{L}_\\theta = \\sum_{i=1, j=1}^V f(C(w_i,w_j)) ({w_i}^\\top \\tilde{w}_j + b_i + \\tilde{b}_j - \\log C(w_i, \\tilde{w}_j))^2 $$ The weighting schema $f(c)$ is a function of the co-occurrence of $w_i$ and $w_j$ and it is an adjustable model configuration. It should be close to zero as $c \\to 0$; should be non-decreasing as higher co-occurrence should have more impact; should saturate when $c$ become extremely large. The paper proposed the following weighting function.\n $$ f(c) = \\begin{cases} (\\frac{c}{c_{\\max}})^\\alpha \u0026 \\text{if } c Examples: word2vec on \u0026ldquo;Game of Thrones\u0026rdquo; After reviewing all the theoretical knowledge above, let\u0026rsquo;s try a little experiment in word embedding extracted from \u0026ldquo;the Games of Thrones corpus\u0026rdquo;. The process is super straightforward using gensim.\nStep 1: Extract words\nimport sys from nltk.corpus import stopwords from nltk.tokenize import sent_tokenize STOP_WORDS = set(stopwords.words(\u0026#39;english\u0026#39;)) def get_words(txt): return filter( lambda x: x not in STOP_WORDS, re.findall(r\u0026#39;\\b(\\w+)\\b\u0026#39;, txt) ) def parse_sentence_words(input_file_names): \u0026#34;\u0026#34;\u0026#34;Returns a list of a list of words. Each sublist is a sentence.\u0026#34;\u0026#34;\u0026#34; sentence_words = [] for file_name in input_file_names: for line in open(file_name): line = line.strip().lower() line = line.decode(\u0026#39;unicode_escape\u0026#39;).encode(\u0026#39;ascii\u0026#39;,\u0026#39;ignore\u0026#39;) sent_words = map(get_words, sent_tokenize(line)) sent_words = filter(lambda sw: len(sw) \u0026gt; 1, sent_words) if len(sent_words) \u0026gt; 1: sentence_words += sent_words return sentence_words # You would see five .txt files after unzip \u0026#39;a_song_of_ice_and_fire.zip\u0026#39; input_file_names = [\u0026#34;001ssb.txt\u0026#34;, \u0026#34;002ssb.txt\u0026#34;, \u0026#34;003ssb.txt\u0026#34;, \u0026#34;004ssb.txt\u0026#34;, \u0026#34;005ssb.txt\u0026#34;] GOT_SENTENCE_WORDS= parse_sentence_words(input_file_names) Step 2: Feed a word2vec model\nfrom gensim.models import Word2Vec # size: the dimensionality of the embedding vectors. # window: the maximum distance between the current and predicted word within a sentence. model = Word2Vec(GOT_SENTENCE_WORDS, size=128, window=3, min_count=5, workers=4) model.wv.save_word2vec_format(\u0026#34;got_word2vec.txt\u0026#34;, binary=False) Step 3: Check the results\nIn the GoT word embedding space, the top similar words to \u0026ldquo;king\u0026rdquo; and \u0026ldquo;queen\u0026rdquo; are:\n model.most_similar('king', topn=10) (word, similarity with \u0026lsquo;king\u0026rsquo;) model.most_similar('queen', topn=10) (word, similarity with \u0026lsquo;queen\u0026rsquo;) (\u0026lsquo;kings\u0026rsquo;, 0.897245) (\u0026lsquo;cersei\u0026rsquo;, 0.942618) (\u0026lsquo;baratheon\u0026rsquo;, 0.809675) (\u0026lsquo;joffrey\u0026rsquo;, 0.933756) (\u0026lsquo;son\u0026rsquo;, 0.763614) (\u0026lsquo;margaery\u0026rsquo;, 0.931099) (\u0026lsquo;robert\u0026rsquo;, 0.708522) (\u0026lsquo;sister\u0026rsquo;, 0.928902) (\u0026lsquo;lords\u0026rsquo;, 0.698684) (\u0026lsquo;prince\u0026rsquo;, 0.927364) (\u0026lsquo;joffrey\u0026rsquo;, 0.696455) (\u0026lsquo;uncle\u0026rsquo;, 0.922507) (\u0026lsquo;prince\u0026rsquo;, 0.695699) (\u0026lsquo;varys\u0026rsquo;, 0.918421) (\u0026lsquo;brother\u0026rsquo;, 0.685239) (\u0026lsquo;ned\u0026rsquo;, 0.917492) (\u0026lsquo;aerys\u0026rsquo;, 0.684527) (\u0026lsquo;melisandre\u0026rsquo;, 0.915403) (\u0026lsquo;stannis\u0026rsquo;, 0.682932) (\u0026lsquo;robb\u0026rsquo;, 0.915272) Cited as:\n@article{weng2017wordembedding, title = \u0026quot;Learning word embedding\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-10-15-word-embedding/\u0026quot; } References [1] Tensorflow Tutorial Vector Representations of Words.\n[2] \u0026ldquo;Word2Vec Tutorial - The Skip-Gram Model\u0026rdquo; by Chris McCormick.\n[3] \u0026ldquo;On word embeddings - Part 2: Approximating the Softmax\u0026rdquo; by Sebastian Ruder.\n[4] Xin Rong. word2vec Parameter Learning Explained\n[5] Mikolov, Tomas, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Efficient estimation of word representations in vector space.\u0026quot; arXiv preprint arXiv:1301.3781 (2013).\n[6] Frederic Morin and Yoshua Bengio. \u0026ldquo;Hierarchical Probabilistic Neural Network Language Model.\u0026quot; Aistats. Vol. 5. 2005.\n[7] Michael Gutmann and Aapo Hyvärinen. \u0026ldquo;Noise-contrastive estimation: A new estimation principle for unnormalized statistical models.\u0026quot; Proc. Intl. Conf. on Artificial Intelligence and Statistics. 2010.\n[8] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Distributed representations of words and phrases and their compositionality.\u0026quot; Advances in neural information processing systems. 2013.\n[9] Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. \u0026ldquo;Efficient estimation of word representations in vector space.\u0026quot; arXiv preprint arXiv:1301.3781 (2013).\n[10] Marco Baroni, Georgiana Dinu, and Germán Kruszewski. \u0026ldquo;Don\u0026rsquo;t count, predict! A systematic comparison of context-counting vs. context-predicting semantic vectors.\u0026quot; ACL (1). 2014.\n[11] Jeffrey Pennington, Richard Socher, and Christopher Manning. \u0026ldquo;Glove: Global vectors for word representation.\u0026quot; Proc. Conf. on empirical methods in natural language processing (EMNLP). 2014.\n","permalink":"https://lilianweng.github.io/posts/2017-10-15-word-embedding/","summary":"Human vocabulary comes in free text. In order to make a machine learning model understand and process the natural language, we need to transform the free-text words into numeric values. One of the simplest transformation approaches is to do a one-hot encoding in which each distinct word stands for one dimension of the resulting vector and a binary value indicates whether the word presents (1) or not (0).\nHowever, one-hot encoding is impractical computationally when dealing with the entire vocabulary, as the representation demands hundreds of thousands of dimensions.","title":"Learning Word Embedding"},{"content":"Professor Naftali Tishby passed away in 2021. Hope the post can introduce his cool idea of information bottleneck to more people.\nRecently I watched the talk \u0026ldquo;Information Theory in Deep Learning\u0026rdquo; by Prof Naftali Tishby and found it very interesting. He presented how to apply the information theory to study the growth and transformation of deep neural networks during training. Using the Information Bottleneck (IB) method, he proposed a new learning bound for deep neural networks (DNN), as the traditional learning theory fails due to the exponentially large number of parameters. Another keen observation is that DNN training involves two distinct phases: First, the network is trained to fully represent the input data and minimize the generalization error; then, it learns to forget the irrelevant details by compressing the representation of the input.\nMost of the materials in this post are from Prof Tishby’s talk and related papers.\nBasic Concepts Markov Chain\nA Markov process is a \u0026ldquo;memoryless\u0026rdquo; (also called \u0026ldquo;Markov Property\u0026rdquo;) stochastic process. A Markov chain is a type of Markov process containing multiple discrete states. That is being said, the conditional probability of future states of the process is only determined by the current state and does not depend on the past states.\nKullback–Leibler (KL) Divergence\nKL divergence measures how one probability distribution $p$ diverges from a second expected probability distribution $q$. It is asymmetric.\n $$ \\begin{aligned} D_{KL}(p \\| q) \u0026= \\sum_x p(x) \\log \\frac{p(x)}{q(x)} \\\\ \u0026= - \\sum_x p(x)\\log q(x) + \\sum_x p(x)\\log p(x) \\\\ \u0026= H(P, Q) - H(P) \\end{aligned} $$ $D_{KL}$ achieves the minimum zero when $p(x)$ == $q(x)$ everywhere.\nMutual Information\nMutual information measures the mutual dependence between two variables. It quantifies the \u0026ldquo;amount of information\u0026rdquo; obtained about one random variable through the other random variable. Mutual information is symmetric.\n $$ \\begin{aligned} I(X;Y) \u0026= D_{KL}[p(x,y) \\| p(x)p(y)] \\\\ \u0026= \\sum_{x \\in X, y \\in Y} p(x, y) \\log(\\frac{p(x, y)}{p(x)p(y)}) \\\\ \u0026= \\sum_{x \\in X, y \\in Y} p(x, y) \\log(\\frac{p(x|y)}{p(x)}) \\\\ \u0026= H(X) - H(X|Y) \\\\ \\end{aligned} $$ Data Processing Inequality (DPI)\nFor any markov chain: $X \\to Y \\to Z$, we would have $I(X; Y) \\geq I(X; Z)$.\nA deep neural network can be viewed as a Markov chain, and thus when we are moving down the layers of a DNN, the mutual information between the layer and the input can only decrease.\nReparametrization invariance\nFor two invertible functions $\\phi$, $\\psi$, the mutual information still holds: $I(X; Y) = I(\\phi(X); \\psi(Y))$.\nFor example, if we shuffle the weights in one layer of DNN, it would not affect the mutual information between this layer and another.\nDeep Neural Networks as Markov Chains The training data contains sampled observations from the joint distribution of $X$ and $Y$. The input variable $X$ and weights of hidden layers are all high-dimensional random variable. The ground truth target $Y$ and the predicted value $\\hat{Y}$ are random variables of smaller dimensions in the classification settings.\nFig. 1. The structure of a deep neural network, which consists of the target label $Y$, input layer $X$, hidden layers $h\\_1, \\dots, h\\_m$ and the final prediction $\\hat{Y}$. (Image source: Tishby and Zaslavsky, 2015) If we label the hidden layers of a DNN as $h_1, h_2, \\dots, h_m$ as in Fig. 1, we can view each layer as one state of a Markov Chain: $ h_i \\to h_{i+1}$. According to DPI, we would have:\n $$ \\begin{aligned} H(X) \\geq I(X; h_1) \\geq I(X; h_2) \\geq \\dots \\geq I(X; h_m) \\geq I(X; \\hat{Y}) \\\\ I(X; Y) \\geq I(h_1; Y) \\geq I(h_2; Y) \\geq \\dots \\geq I(h_m; Y) \\geq I(\\hat{Y}; Y) \\end{aligned} $$ A DNN is designed to learn how to describe $X$ to predict $Y$ and eventually, to compress $X$ to only hold the information related to $Y$. Tishby describes this processing as \u0026ldquo;successive refinement of relevant information\u0026rdquo;.\nInformation Plane Theorem A DNN has successive internal representations of $X$, a set of hidden layers $\\{T_i\\}$. The information plane theorem characterizes each layer by its encoder and decoder information. The encoder is a representation of the input data $X$, while the decoder translates the information in the current layer to the target ouput $Y$.\nPrecisely, in an information plane plot:\n X-axis: The sample complexity of $T_i$ is determined by the encoder mutual information $I(X; T_i)$. Sample complexity refers to how many samples you need to achieve certain accuracy and generalization. Y-axis: The accuracy (generalization error) is determined by the decoder mutual information $I(T_i; Y)$. Fig. 2. The encoder vs decoder mutual information of DNN hidden layers of 50 experiments. Different layers are color-coders, with green being the layer right next to the input and the orange being the furthest. There are three snapshots, at the initial epoch, 400 epochs and 9000 epochs respectively. (Image source: Shwartz-Ziv and Tishby, 2017) Each dot in Fig. 2. marks the encoder/ decoder mutual information of one hidden layer of one network simulation (no regularization is applied; no weights decay, no dropout, etc.). They move up as expected because the knowledge about the true labels is increasing (accuracy increases). At the early stage, the hidden layers learn a lot about the input $X$, but later they start to compress to forget some information about the input. Tishby believes that \u0026ldquo;the most important part of learning is actually forgetting\u0026rdquo;. Check out this nice video that demonstrates how the mutual information measures of layers are changing in epoch time.\nFig. 3. Here is an aggregated view of Fig 2. The compression happens after the generalization error becomes very small. (Image source: Tishby’ talk 15:15) Two Optimization Phases Tracking the normalized mean and standard deviation of each layer\u0026rsquo;s weights in time also reveals two optimization phases of the training process.\nFig. 4. The norm of mean and standard deviation of each layer's weight gradients for each layer as a function of training epochs. Different layers are color-coded. (Image source: Shwartz-Ziv and Tishby, 2017) Among early epochs, the mean values are three magnitudes larger than the standard deviations. After a sufficient number of epochs, the error saturates and the standard deviations become much noisier afterward. The further a layer is away from the output, the noisier it gets, because the noises can get amplified and accumulated through the back-prop process (not due to the width of the layer).\nLearning Theory \u0026ldquo;Old\u0026rdquo; Generalization Bounds The generalization bounds defined by the classic learning theory is:\n $$ \\epsilon^2 $\\epsilon$: The difference between the training error and the generalization error. The generalization error measures how accurate the prediction of an algorithm is for previously unseen data. $H_\\epsilon$: $\\epsilon$-cover of the hypothesis class. Typically we assume the size $\\vert H_\\epsilon \\vert \\sim (1/\\epsilon)^d$. $\\delta$: Confidence. $m$: The number of training samples. $d$: The VC dimension of the hypothesis. This definition states that the difference between the training error and the generalization error is bounded by a function of the hypothesis space size and the dataset size. The bigger the hypothesis space gets, the bigger the generalization error becomes. I recommend this tutorial on ML theory, part1 and part2, if you are interested in reading more on generalization bounds.\nHowever, it does not work for deep learning. The larger a network is, the more parameters it needs to learn. With this generalization bounds, larger networks (larger $d$) would have worse bounds. This is contrary to the intuition that larger networks are able to achieve better performance with higher expressivity.\n\u0026ldquo;New\u0026rdquo; Input compression bound To solve this counterintuitive observation, Tishby et al. proposed a new input compression bound for DNN.\nFirst let us have $T_\\epsilon$ as an $\\epsilon$-partition of the input variable $X$. This partition compresses the input with respect to the homogeneity to the labels into small cells. The cells in total can cover the whole input space. If the prediction outputs binary values, we can replace the cardinality of the hypothesis, $\\vert H_\\epsilon \\vert$, with $2^{\\vert T_\\epsilon \\vert}$.\n $$ |H_\\epsilon| \\sim 2^{|X|} \\to 2^{|T_\\epsilon|} $$ When $X$ is large, the size of $X$ is approximately $2^{H(X)}$. Each cell in the $\\epsilon$-partition is of size $2^{H(X \\vert T_\\epsilon)}$. Therefore we have $\\vert T_\\epsilon \\vert \\sim \\frac{2^{H(X)}}{2^{H(X \\vert T_\\epsilon)}} = 2^{I(T_\\epsilon; X)}$. Then the input compression bound becomes:\n $$ \\epsilon^2 Fig. 5. The black line is the optimal achievable information bottleneck (IB) limit. The red line corresponds to the upper bound on the out-of-sample IB distortion, when trained on a finite sample set. $\\Delta C$ is the complexity gap and $\\Delta G$ is the generalization gap. (Recreated based on Tishby’ talk 24:50) Network Size and Training Data Size The Benefit of More Hidden Layers Having more layers give us computational benefits and speed up the training process for good generalization.\nFig. 6. The optimization time is much shorter (fewer epochs) with more hidden layers. (Image source: Shwartz-Ziv and Tishby, 2017) Compression through stochastic relaxation: According to the diffusion equation, the relaxation time of layer $k$ is proportional to the exponential of this layer\u0026rsquo;s compression amount $\\Delta S_k$: $\\Delta t_k \\sim \\exp(\\Delta S_k)$. We can compute the layer compression as $\\Delta S_k = I(X; T_k) - I(X; T_{k-1})$. Because $\\exp(\\sum_k \\Delta S_k) \\geq \\sum_k \\exp(\\Delta S_k)$, we would expect an exponential decrease in training epochs with more hidden layers (larger $k$).\nThe Benefit of More Training Samples Fitting more training data requires more information captured by the hidden layers. With increased training data size, the decoder mutual information (recall that this is directly related to the generalization error), $I(T; Y)$, is pushed up and gets closer to the theoretical information bottleneck bound. Tishby emphasized that It is the mutual information, not the layer size or the VC dimension, that determines generalization, different from standard theories.\nFig. 7. The training data of different sizes is color-coded. The information plane of multiple converged networks are plotted. More training data leads to better generalization. (Image source: Shwartz-Ziv and Tishby, 2017) Cited as:\n@article{weng2017infotheory, title = \u0026quot;Anatomize Deep Learning with Information Theory\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-09-28-information-bottleneck/\u0026quot; } References [1] Naftali Tishby. Information Theory of Deep Learning\n[2] Machine Learning Theory - Part 1: Introduction\n[3] Machine Learning Theory - Part 2: Generalization Bounds\n[4] New Theory Cracks Open the Black Box of Deep Learning by Quanta Magazine.\n[5] Naftali Tishby and Noga Zaslavsky. \u0026ldquo;Deep learning and the information bottleneck principle.\u0026quot; IEEE Information Theory Workshop (ITW), 2015.\n[6] Ravid Shwartz-Ziv and Naftali Tishby. \u0026ldquo;Opening the Black Box of Deep Neural Networks via Information.\u0026quot; arXiv preprint arXiv:1703.00810, 2017.\n","permalink":"https://lilianweng.github.io/posts/2017-09-28-information-bottleneck/","summary":"Professor Naftali Tishby passed away in 2021. Hope the post can introduce his cool idea of information bottleneck to more people.\nRecently I watched the talk \u0026ldquo;Information Theory in Deep Learning\u0026rdquo; by Prof Naftali Tishby and found it very interesting. He presented how to apply the information theory to study the growth and transformation of deep neural networks during training. Using the Information Bottleneck (IB) method, he proposed a new learning bound for deep neural networks (DNN), as the traditional learning theory fails due to the exponentially large number of parameters.","title":"Anatomize Deep Learning with Information Theory"},{"content":"[Updated on 2018-09-30: thanks to Yoonju, we have this post translated in Korean!] [Updated on 2019-04-18: this post is also available on arXiv.]\nGenerative adversarial network (GAN) has shown great results in many generative tasks to replicate the real-world rich content such as images, human language, and music. It is inspired by game theory: two models, a generator and a critic, are competing with each other while making each other stronger at the same time. However, it is rather challenging to train a GAN model, as people are facing issues like training instability or failure to converge.\nHere I would like to explain the maths behind the generative adversarial network framework, why it is hard to be trained, and finally introduce a modified version of GAN intended to solve the training difficulties.\nKullback–Leibler and Jensen–Shannon Divergence Before we start examining GANs closely, let us first review two metrics for quantifying the similarity between two probability distributions.\n(1) KL (Kullback–Leibler) divergence measures how one probability distribution $p$ diverges from a second expected probability distribution $q$.\n $$ D_{KL}(p \\| q) = \\int_x p(x) \\log \\frac{p(x)}{q(x)} dx $$ $D_{KL}$ achieves the minimum zero when $p(x)$ == $q(x)$ everywhere.\nIt is noticeable according to the formula that KL divergence is asymmetric. In cases where $p(x)$ is close to zero, but $q(x)$ is significantly non-zero, the $q$\u0026rsquo;s effect is disregarded. It could cause buggy results when we just want to measure the similarity between two equally important distributions.\n(2) Jensen–Shannon Divergence is another measure of similarity between two probability distributions, bounded by $[0, 1]$. JS divergence is symmetric (yay!) and more smooth. Check this Quora post if you are interested in reading more about the comparison between KL divergence and JS divergence.\n $$ D_{JS}(p \\| q) = \\frac{1}{2} D_{KL}(p \\| \\frac{p + q}{2}) + \\frac{1}{2} D_{KL}(q \\| \\frac{p + q}{2}) $$ Fig. 1. Given two Gaussian distribution, $p$ with mean=0 and std=1 and $q$ with mean=1 and std=1. The average of two distributions is labelled as $m=(p+q)/2$. KL divergence $D_{KL}$ is asymmetric but JS divergence $D_{JS}$ is symmetric. Some believe (Huszar, 2015) that one reason behind GANs' big success is switching the loss function from asymmetric KL divergence in traditional maximum-likelihood approach to symmetric JS divergence. We will discuss more on this point in the next section.\nGenerative Adversarial Network (GAN) GAN consists of two models:\n A discriminator $D$ estimates the probability of a given sample coming from the real dataset. It works as a critic and is optimized to tell the fake samples from the real ones. A generator $G$ outputs synthetic samples given a noise variable input $z$ ($z$ brings in potential output diversity). It is trained to capture the real data distribution so that its generative samples can be as real as possible, or in other words, can trick the discriminator to offer a high probability. Fig. 2. Architecture of a generative adversarial network. (Image source: www.kdnuggets.com/2017/01/generative-...-learning.html) These two models compete against each other during the training process: the generator $G$ is trying hard to trick the discriminator, while the critic model $D$ is trying hard not to be cheated. This interesting zero-sum game between two models motivates both to improve their functionalities.\nGiven,\n Symbol Meaning Notes $p_{z}$ Data distribution over noise input $z$ Usually, just uniform. $p_{g}$ The generator\u0026rsquo;s distribution over data $x$ $p_{r}$ Data distribution over real sample $x$ On one hand, we want to make sure the discriminator $D$\u0026rsquo;s decisions over real data are accurate by maximizing $\\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)]$. Meanwhile, given a fake sample $G(z), z \\sim p_z(z)$, the discriminator is expected to output a probability, $D(G(z))$, close to zero by maximizing $\\mathbb{E}_{z \\sim p_{z}(z)} [\\log (1 - D(G(z)))]$.\nOn the other hand, the generator is trained to increase the chances of $D$ producing a high probability for a fake example, thus to minimize $\\mathbb{E}_{z \\sim p_{z}(z)} [\\log (1 - D(G(z)))]$.\nWhen combining both aspects together, $D$ and $G$ are playing a minimax game in which we should optimize the following loss function:\n $$ \\begin{aligned} \\min_G \\max_D L(D, G) \u0026 = \\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)] + \\mathbb{E}_{z \\sim p_z(z)} [\\log(1 - D(G(z)))] \\\\ \u0026 = \\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)] + \\mathbb{E}_{x \\sim p_g(x)} [\\log(1 - D(x)] \\end{aligned} $$ ($\\mathbb{E}_{x \\sim p_{r}(x)} [\\log D(x)]$ has no impact on $G$ during gradient descent updates.)\nWhat is the optimal value for D? Now we have a well-defined loss function. Let\u0026rsquo;s first examine what is the best value for $D$.\n $$ L(G, D) = \\int_x \\bigg( p_{r}(x) \\log(D(x)) + p_g (x) \\log(1 - D(x)) \\bigg) dx $$ Since we are interested in what is the best value of $D(x)$ to maximize $L(G, D)$, let us label\n $$ \\tilde{x} = D(x), A=p_{r}(x), B=p_g(x) $$ And then what is inside the integral (we can safely ignore the integral because $x$ is sampled over all the possible values) is:\n $$ \\begin{aligned} f(\\tilde{x}) \u0026 = A log\\tilde{x} + B log(1-\\tilde{x}) \\\\ \\frac{d f(\\tilde{x})}{d \\tilde{x}} \u0026 = A \\frac{1}{ln10} \\frac{1}{\\tilde{x}} - B \\frac{1}{ln10} \\frac{1}{1 - \\tilde{x}} \\\\ \u0026 = \\frac{1}{ln10} (\\frac{A}{\\tilde{x}} - \\frac{B}{1-\\tilde{x}}) \\\\ \u0026 = \\frac{1}{ln10} \\frac{A - (A + B)\\tilde{x}}{\\tilde{x} (1 - \\tilde{x})} \\\\ \\end{aligned} $$ Thus, set $\\frac{d f(\\tilde{x})}{d \\tilde{x}} = 0$, we get the best value of the discriminator: $D^*(x) = \\tilde{x}^* = \\frac{A}{A + B} = \\frac{p_{r}(x)}{p_{r}(x) + p_g(x)} \\in [0, 1]$.\nOnce the generator is trained to its optimal, $p_g$ gets very close to $p_{r}$. When $p_g = p_{r}$, $D^*(x)$ becomes $1/2$.\nWhat is the global optimal? When both $G$ and $D$ are at their optimal values, we have $p_g = p_{r}$ and $D^*(x) = 1/2$ and the loss function becomes:\n $$ \\begin{aligned} L(G, D^*) \u0026= \\int_x \\bigg( p_{r}(x) \\log(D^*(x)) + p_g (x) \\log(1 - D^*(x)) \\bigg) dx \\\\ \u0026= \\log \\frac{1}{2} \\int_x p_{r}(x) dx + \\log \\frac{1}{2} \\int_x p_g(x) dx \\\\ \u0026= -2\\log2 \\end{aligned} $$ What does the loss function represent? According to the formula listed in the previous section, JS divergence between $p_{r}$ and $p_g$ can be computed as:\n $$ \\begin{aligned} D_{JS}(p_{r} \\| p_g) =\u0026 \\frac{1}{2} D_{KL}(p_{r} || \\frac{p_{r} + p_g}{2}) + \\frac{1}{2} D_{KL}(p_{g} || \\frac{p_{r} + p_g}{2}) \\\\ =\u0026 \\frac{1}{2} \\bigg( \\log2 + \\int_x p_{r}(x) \\log \\frac{p_{r}(x)}{p_{r} + p_g(x)} dx \\bigg) + \\\\\u0026 \\frac{1}{2} \\bigg( \\log2 + \\int_x p_g(x) \\log \\frac{p_g(x)}{p_{r} + p_g(x)} dx \\bigg) \\\\ =\u0026 \\frac{1}{2} \\bigg( \\log4 + L(G, D^*) \\bigg) \\end{aligned} $$ Thus,\n $$ L(G, D^*) = 2D_{JS}(p_{r} \\| p_g) - 2\\log2 $$ Essentially the loss function of GAN quantifies the similarity between the generative data distribution $p_g$ and the real sample distribution $p_{r}$ by JS divergence when the discriminator is optimal. The best $G^*$ that replicates the real data distribution leads to the minimum $L(G^*, D^*) = -2\\log2$ which is aligned with equations above.\n Other Variations of GAN: There are many variations of GANs in different contexts or designed for different tasks. For example, for semi-supervised learning, one idea is to update the discriminator to output real class labels, $1, \\dots, K-1$, as well as one fake class label $K$. The generator model aims to trick the discriminator to output a classification label smaller than $K$.\n Tensorflow Implementation: carpedm20/DCGAN-tensorflow\nProblems in GANs Although GAN has shown great success in the realistic image generation, the training is not easy; The process is known to be slow and unstable.\nHard to achieve Nash equilibrium Salimans et al. (2016) discussed the problem with GAN\u0026rsquo;s gradient-descent-based training procedure. Two models are trained simultaneously to find a Nash equilibrium to a two-player non-cooperative game. However, each model updates its cost independently with no respect to another player in the game. Updating the gradient of both models concurrently cannot guarantee a convergence.\nLet\u0026rsquo;s check out a simple example to better understand why it is difficult to find a Nash equilibrium in an non-cooperative game. Suppose one player takes control of $x$ to minimize $f_1(x) = xy$, while at the same time the other player constantly updates $y$ to minimize $f_2(y) = -xy$.\nBecause $\\frac{\\partial f_1}{\\partial x} = y$ and $\\frac{\\partial f_2}{\\partial y} = -x$, we update $x$ with $x-\\eta \\cdot y$ and $y$ with $y+ \\eta \\cdot x$ simulitanously in one iteration, where $\\eta$ is the learning rate. Once $x$ and $y$ have different signs, every following gradient update causes huge oscillation and the instability gets worse in time, as shown in Fig. 3.\nFig. 3. A simulation of our example for updating $x$ to minimize $xy$ and updating $y$ to minimize $-xy$. The learning rate $\\eta = 0.1$. With more iterations, the oscillation grows more and more unstable. Low dimensional supports Term Explanation Manifold A topological space that locally resembles Euclidean space near each point. Precisely, when this Euclidean space is of dimension $n$, the manifold is referred as $n$-manifold. Support A real-valued function $f$ is the subset of the domain containing those elements which are not mapped to zero. Arjovsky and Bottou (2017) discussed the problem of the supports of $p_r$ and $p_g$ lying on low dimensional manifolds and how it contributes to the instability of GAN training thoroughly in a very theoretical paper \u0026ldquo;Towards principled methods for training generative adversarial networks\u0026rdquo;.\nThe dimensions of many real-world datasets, as represented by $p_r$, only appear to be artificially high. They have been found to concentrate in a lower dimensional manifold. This is actually the fundamental assumption for Manifold Learning. Thinking of the real world images, once the theme or the contained object is fixed, the images have a lot of restrictions to follow, i.e., a dog should have two ears and a tail, and a skyscraper should have a straight and tall body, etc. These restrictions keep images aways from the possibility of having a high-dimensional free form.\n$p_g$ lies in a low dimensional manifolds, too. Whenever the generator is asked to a much larger image like 64x64 given a small dimension, such as 100, noise variable input $z$, the distribution of colors over these 4096 pixels has been defined by the small 100-dimension random number vector and can hardly fill up the whole high dimensional space.\nBecause both $p_g$ and $p_r$ rest in low dimensional manifolds, they are almost certainly gonna be disjoint (See Fig. 4). When they have disjoint supports, we are always capable of finding a perfect discriminator that separates real and fake samples 100% correctly. Check the paper if you are curious about the proof.\nFig. 4. Low dimensional manifolds in high dimension space can hardly have overlaps. (Left) Two lines in a three-dimension space. (Right) Two surfaces in a three-dimension space. Vanishing gradient When the discriminator is perfect, we are guaranteed with $D(x) = 1, \\forall x \\in p_r$ and $D(x) = 0, \\forall x \\in p_g$. Therefore the loss function $L$ falls to zero and we end up with no gradient to update the loss during learning iterations. Fig. 5 demonstrates an experiment when the discriminator gets better, the gradient vanishes fast.\nFig. 5. First, a DCGAN is trained for 1, 10 and 25 epochs. Then, with the **generator fixed**, a discriminator is trained from scratch and measure the gradients with the original cost function. We see the gradient norms **decay quickly** (in log scale), in the best case 5 orders of magnitude after 4000 discriminator iterations. (Image source: Arjovsky and Bottou, 2017) As a result, training a GAN faces a dilemma:\n If the discriminator behaves badly, the generator does not have accurate feedback and the loss function cannot represent the reality. If the discriminator does a great job, the gradient of the loss function drops down to close to zero and the learning becomes super slow or even jammed. This dilemma clearly is capable to make the GAN training very tough.\nMode collapse During the training, the generator may collapse to a setting where it always produces same outputs. This is a common failure case for GANs, commonly referred to as Mode Collapse. Even though the generator might be able to trick the corresponding discriminator, it fails to learn to represent the complex real-world data distribution and gets stuck in a small space with extremely low variety.\nFig. 6. A DCGAN model is trained with an MLP network with 4 layers, 512 units and ReLU activation function, configured to lack a strong inductive bias for image generation. The results shows a significant degree of mode collapse. (Image source: Arjovsky, Chintala, \u0026 Bottou, 2017.) Lack of a proper evaluation metric Generative adversarial networks are not born with a good objection function that can inform us the training progress. Without a good evaluation metric, it is like working in the dark. No good sign to tell when to stop; No good indicator to compare the performance of multiple models.\nImproved GAN Training The following suggestions are proposed to help stabilize and improve the training of GANs.\nFirst five methods are practical techniques to achieve faster convergence of GAN training, proposed in \u0026ldquo;Improve Techniques for Training GANs\u0026rdquo;. The last two are proposed in \u0026ldquo;Towards principled methods for training generative adversarial networks\u0026rdquo; to solve the problem of disjoint distributions.\n(1) Feature Matching\nFeature matching suggests to optimize the discriminator to inspect whether the generator\u0026rsquo;s output matches expected statistics of the real samples. In such a scenario, the new loss function is defined as $| \\mathbb{E}_{x \\sim p_r} f(x) - \\mathbb{E}_{z \\sim p_z(z)}f(G(z)) |_2^2 $, where $f(x)$ can be any computation of statistics of features, such as mean or median.\n(2) Minibatch Discrimination\nWith minibatch discrimination, the discriminator is able to digest the relationship between training data points in one batch, instead of processing each point independently.\nIn one minibatch, we approximate the closeness between every pair of samples, $c(x_i, x_j)$, and get the overall summary of one data point by summing up how close it is to other samples in the same batch, $o(x_i) = \\sum_{j} c(x_i, x_j)$. Then $o(x_i)$ is explicitly added to the input of the model.\n(3) Historical Averaging\nFor both models, add $ | \\Theta - \\frac{1}{t} \\sum_{i=1}^t \\Theta_i |^2 $ into the loss function, where $\\Theta$ is the model parameter and $\\Theta_i$ is how the parameter is configured at the past training time $i$. This addition piece penalizes the training speed when $\\Theta$ is changing too dramatically in time.\n(4) One-sided Label Smoothing\nWhen feeding the discriminator, instead of providing 1 and 0 labels, use soften values such as 0.9 and 0.1. It is shown to reduce the networks' vulnerability.\n(5) Virtual Batch Normalization (VBN)\nEach data sample is normalized based on a fixed batch (\u0026ldquo;reference batch\u0026rdquo;) of data rather than within its minibatch. The reference batch is chosen once at the beginning and stays the same through the training.\nTheano Implementation: openai/improved-gan\n(6) Adding Noises.\nBased on the discussion in the previous section, we now know $p_r$ and $p_g$ are disjoint in a high dimensional space and it causes the problem of vanishing gradient. To artificially \u0026ldquo;spread out\u0026rdquo; the distribution and to create higher chances for two probability distributions to have overlaps, one solution is to add continuous noises onto the inputs of the discriminator $D$.\n(7) Use Better Metric of Distribution Similarity\nThe loss function of the vanilla GAN measures the JS divergence between the distributions of $p_r$ and $p_g$. This metric fails to provide a meaningful value when two distributions are disjoint.\nWasserstein metric is proposed to replace JS divergence because it has a much smoother value space. See more in the next section.\nWasserstein GAN (WGAN) What is Wasserstein distance? Wasserstein Distance is a measure of the distance between two probability distributions. It is also called Earth Mover\u0026rsquo;s distance, short for EM distance, because informally it can be interpreted as the minimum energy cost of moving and transforming a pile of dirt in the shape of one probability distribution to the shape of the other distribution. The cost is quantified by: the amount of dirt moved x the moving distance.\nLet us first look at a simple case where the probability domain is discrete. For example, suppose we have two distributions $P$ and $Q$, each has four piles of dirt and both have ten shovelfuls of dirt in total. The numbers of shovelfuls in each dirt pile are assigned as follows:\n $$ P_1 = 3, P_2 = 2, P_3 = 1, P_4 = 4\\\\ Q_1 = 1, Q_2 = 2, Q_3 = 4, Q_4 = 3 $$ In order to change $P$ to look like $Q$, as illustrated in Fig. 7, we:\n First move 2 shovelfuls from $P_1$ to $P_2$ =\u0026gt; $(P_1, Q_1)$ match up. Then move 2 shovelfuls from $P_2$ to $P_3$ =\u0026gt; $(P_2, Q_2)$ match up. Finally move 1 shovelfuls from $Q_3$ to $Q_4$ =\u0026gt; $(P_3, Q_3)$ and $(P_4, Q_4)$ match up. If we label the cost to pay to make $P_i$ and $Q_i$ match as $\\delta_i$, we would have $\\delta_{i+1} = \\delta_i + P_i - Q_i$ and in the example:\n $$ \\begin{aligned} \\delta_0 \u0026= 0\\\\ \\delta_1 \u0026= 0 + 3 - 1 = 2\\\\ \\delta_2 \u0026= 2 + 2 - 2 = 2\\\\ \\delta_3 \u0026= 2 + 1 - 4 = -1\\\\ \\delta_4 \u0026= -1 + 4 - 3 = 0 \\end{aligned} $$ Finally the Earth Mover\u0026rsquo;s distance is $W = \\sum \\vert \\delta_i \\vert = 5$.\nFig. 7. Step-by-step plan of moving dirt between piles in $P$ and $Q$ to make them match. When dealing with the continuous probability domain, the distance formula becomes:\n $$ W(p_r, p_g) = \\inf_{\\gamma \\sim \\Pi(p_r, p_g)} \\mathbb{E}_{(x, y) \\sim \\gamma}[\\| x-y \\|] $$ In the formula above, $\\Pi(p_r, p_g)$ is the set of all possible joint probability distributions between $p_r$ and $p_g$. One joint distribution $\\gamma \\in \\Pi(p_r, p_g)$ describes one dirt transport plan, same as the discrete example above, but in the continuous probability space. Precisely $\\gamma(x, y)$ states the percentage of dirt should be transported from point $x$ to $y$ so as to make $x$ follows the same probability distribution of $y$. That\u0026rsquo;s why the marginal distribution over $x$ adds up to $p_g$, $\\sum_{x} \\gamma(x, y) = p_g(y)$ (Once we finish moving the planned amount of dirt from every possible $x$ to the target $y$, we end up with exactly what $y$ has according to $p_g$.) and vice versa $\\sum_{y} \\gamma(x, y) = p_r(x)$.\nWhen treating $x$ as the starting point and $y$ as the destination, the total amount of dirt moved is $\\gamma(x, y)$ and the travelling distance is $| x-y |$ and thus the cost is $\\gamma(x, y) \\cdot | x-y |$. The expected cost averaged across all the $(x,y)$ pairs can be easily computed as:\n $$ \\sum_{x, y} \\gamma(x, y) \\| x-y \\| = \\mathbb{E}_{x, y \\sim \\gamma} \\| x-y \\| $$ Finally, we take the minimum one among the costs of all dirt moving solutions as the EM distance. In the definition of Wasserstein distance, the $\\inf$ (infimum, also known as greatest lower bound) indicates that we are only interested in the smallest cost.\nWhy Wasserstein is better than JS or KL divergence? Even when two distributions are located in lower dimensional manifolds without overlaps, Wasserstein distance can still provide a meaningful and smooth representation of the distance in-between.\nThe WGAN paper exemplified the idea with a simple example.\nSuppose we have two probability distributions, $P$ and $Q$:\n $$ \\forall (x, y) \\in P, x = 0 \\text{ and } y \\sim U(0, 1)\\\\ \\forall (x, y) \\in Q, x = \\theta, 0 \\leq \\theta \\leq 1 \\text{ and } y \\sim U(0, 1)\\\\ $$ Fig. 8. There is no overlap between $P$ and $Q$ when $\\theta \\neq 0$. When $\\theta \\neq 0$:\n $$ \\begin{aligned} D_{KL}(P \\| Q) \u0026= \\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{0} = +\\infty \\\\ D_{KL}(Q \\| P) \u0026= \\sum_{x=\\theta, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{0} = +\\infty \\\\ D_{JS}(P, Q) \u0026= \\frac{1}{2}(\\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{1/2} + \\sum_{x=0, y \\sim U(0, 1)} 1 \\cdot \\log\\frac{1}{1/2}) = \\log 2\\\\ W(P, Q) \u0026= |\\theta| \\end{aligned} $$ But when $\\theta = 0$, two distributions are fully overlapped:\n $$ \\begin{aligned} D_{KL}(P \\| Q) \u0026= D_{KL}(Q \\| P) = D_{JS}(P, Q) = 0\\\\ W(P, Q) \u0026= 0 = \\lvert \\theta \\rvert \\end{aligned} $$ $D_{KL}$ gives us inifity when two distributions are disjoint. The value of $D_{JS}$ has sudden jump, not differentiable at $\\theta = 0$. Only Wasserstein metric provides a smooth measure, which is super helpful for a stable learning process using gradient descents.\nUse Wasserstein distance as GAN loss function It is intractable to exhaust all the possible joint distributions in $\\Pi(p_r, p_g)$ to compute $\\inf_{\\gamma \\sim \\Pi(p_r, p_g)}$. Thus the authors proposed a smart transformation of the formula based on the Kantorovich-Rubinstein duality to:\n $$ W(p_r, p_g) = \\frac{1}{K} \\sup_{\\| f \\|_L \\leq K} \\mathbb{E}_{x \\sim p_r}[f(x)] - \\mathbb{E}_{x \\sim p_g}[f(x)] $$ where $\\sup$ (supremum) is the opposite of $inf$ (infimum); we want to measure the least upper bound or, in even simpler words, the maximum value.\nLipschitz continuity?\nThe function $f$ in the new form of Wasserstein metric is demanded to satisfy $| f |_L \\leq K$, meaning it should be K-Lipschitz continuous.\nA real-valued function $f: \\mathbb{R} \\rightarrow \\mathbb{R}$ is called $K$-Lipschitz continuous if there exists a real constant $K \\geq 0$ such that, for all $x_1, x_2 \\in \\mathbb{R}$,\n$$ \\lvert f(x_1) - f(x_2) \\rvert \\leq K \\lvert x_1 - x_2 \\rvert $$\nHere $K$ is known as a Lipschitz constant for function $f(.)$. Functions that are everywhere continuously differentiable is Lipschitz continuous, because the derivative, estimated as $\\frac{\\lvert f(x_1) - f(x_2) \\rvert}{\\lvert x_1 - x_2 \\rvert}$, has bounds. However, a Lipschitz continuous function may not be everywhere differentiable, such as $f(x) = \\lvert x \\rvert$.\nExplaining how the transformation happens on the Wasserstein distance formula is worthy of a long post by itself, so I skip the details here. If you are interested in how to compute Wasserstein metric using linear programming, or how to transfer Wasserstein metric into its dual form according to the Kantorovich-Rubinstein Duality, read this awesome post.\nSuppose this function $f$ comes from a family of K-Lipschitz continuous functions, $\\{ f_w \\}_{w \\in W}$, parameterized by $w$. In the modified Wasserstein-GAN, the \u0026ldquo;discriminator\u0026rdquo; model is used to learn $w$ to find a good $f_w$ and the loss function is configured as measuring the Wasserstein distance between $p_r$ and $p_g$.\n $$ L(p_r, p_g) = W(p_r, p_g) = \\max_{w \\in W} \\mathbb{E}_{x \\sim p_r}[f_w(x)] - \\mathbb{E}_{z \\sim p_r(z)}[f_w(g_\\theta(z))] $$ Thus the \u0026ldquo;discriminator\u0026rdquo; is not a direct critic of telling the fake samples apart from the real ones anymore. Instead, it is trained to learn a $K$-Lipschitz continuous function to help compute Wasserstein distance. As the loss function decreases in the training, the Wasserstein distance gets smaller and the generator model\u0026rsquo;s output grows closer to the real data distribution.\nOne big problem is to maintain the $K$-Lipschitz continuity of $f_w$ during the training in order to make everything work out. The paper presents a simple but very practical trick: After every gradient update, clamp the weights $w$ to a small window, such as $[-0.01, 0.01]$, resulting in a compact parameter space $W$ and thus $f_w$ obtains its lower and upper bounds to preserve the Lipschitz continuity.\nFig. 9. Algorithm of Wasserstein generative adversarial network. (Image source: Arjovsky, Chintala, \u0026 Bottou, 2017.) Compared to the original GAN algorithm, the WGAN undertakes the following changes:\n After every gradient update on the critic function, clamp the weights to a small fixed range, $[-c, c]$. Use a new loss function derived from the Wasserstein distance, no logarithm anymore. The \u0026ldquo;discriminator\u0026rdquo; model does not play as a direct critic but a helper for estimating the Wasserstein metric between real and generated data distribution. Empirically the authors recommended RMSProp optimizer on the critic, rather than a momentum based optimizer such as Adam which could cause instability in the model training. I haven\u0026rsquo;t seen clear theoretical explanation on this point through. Sadly, Wasserstein GAN is not perfect. Even the authors of the original WGAN paper mentioned that \u0026ldquo;Weight clipping is a clearly terrible way to enforce a Lipschitz constraint\u0026rdquo; (Oops!). WGAN still suffers from unstable training, slow convergence after weight clipping (when clipping window is too large), and vanishing gradients (when clipping window is too small).\nSome improvement, precisely replacing weight clipping with gradient penalty, has been discussed in Gulrajani et al. 2017. I will leave this to a future post.\nExample: Create New Pokemons! Just for fun, I tried out carpedm20/DCGAN-tensorflow on a tiny dataset, Pokemon sprites. The dataset only has 900-ish pokemon images, including different levels of same pokemon species.\nLet\u0026rsquo;s check out what types of new pokemons the model is able to create. Unfortunately due to the tiny training data, the new pokemons only have rough shapes without details. The shapes and colors do look better with more training epoches! Hooray!\nFig. 10. Train carpedm20/DCGAN-tensorflow on a set of Pokemon sprite images. The sample outputs are listed after training epoches = 7, 21, 49. If you are interested in a commented version of carpedm20/DCGAN-tensorflow and how to modify it to train WGAN and WGAN with gradient penalty, check lilianweng/unified-gan-tensorflow.\n Cited as:\n@article{weng2017gan, title = \u0026quot;From GAN to WGAN\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-08-20-gan/\u0026quot; } OR\n@misc{weng2019gan, title={From GAN to WGAN}, author={Lilian Weng}, year={2019}, eprint={1904.08994}, archivePrefix={arXiv}, primaryClass={cs.LG} } References [1] Goodfellow, Ian, et al. \u0026ldquo;Generative adversarial nets.\u0026quot; NIPS, 2014.\n[2] Tim Salimans, et al. \u0026ldquo;Improved techniques for training gans.\u0026quot; NIPS 2016.\n[3] Martin Arjovsky and Léon Bottou. \u0026ldquo;Towards principled methods for training generative adversarial networks.\u0026quot; arXiv preprint arXiv:1701.04862 (2017).\n[4] Martin Arjovsky, Soumith Chintala, and Léon Bottou. \u0026ldquo;Wasserstein GAN.\u0026quot; arXiv preprint arXiv:1701.07875 (2017).\n[5] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville. Improved training of wasserstein gans. arXiv preprint arXiv:1704.00028 (2017).\n[6] Computing the Earth Mover\u0026rsquo;s Distance under Transformations\n[7] Wasserstein GAN and the Kantorovich-Rubinstein Duality\n[8] zhuanlan.zhihu.com/p/25071913\n[9] Ferenc Huszár. \u0026ldquo;How (not) to Train your Generative Model: Scheduled Sampling, Likelihood, Adversary?.\u0026quot; arXiv preprint arXiv:1511.05101 (2015).\n","permalink":"https://lilianweng.github.io/posts/2017-08-20-gan/","summary":"[Updated on 2018-09-30: thanks to Yoonju, we have this post translated in Korean!] [Updated on 2019-04-18: this post is also available on arXiv.]\nGenerative adversarial network (GAN) has shown great results in many generative tasks to replicate the real-world rich content such as images, human language, and music. It is inspired by game theory: two models, a generator and a critic, are competing with each other while making each other stronger at the same time.","title":"From GAN to WGAN"},{"content":"The machine learning models have started penetrating into critical areas like health care, justice systems, and financial industry. Thus to figure out how the models make the decisions and make sure the decisioning process is aligned with the ethnic requirements or legal regulations becomes a necessity.\nMeanwhile, the rapid growth of deep learning models pushes the requirement of interpreting complicated models further. People are eager to apply the power of AI fully on key aspects of everyday life. However, it is hard to do so without enough trust in the models or an efficient procedure to explain unintended behavior, especially considering that the deep neural networks are born as black-boxes.\nThink of the following cases:\n The financial industry is highly regulated and loan issuers are required by law to make fair decisions and explain their credit models to provide reasons whenever they decide to decline loan application. Medical diagnosis model is responsible for human life. How can we be confident enough to treat a patient as instructed by a black-box model? When using a criminal decision model to predict the risk of recidivism at the court, we have to make sure the model behaves in an equitable, honest and nondiscriminatory manner. If a self-driving car suddenly acts abnormally and we cannot explain why, are we gonna be comfortable enough to use the technique in real traffic in large scale? At Affirm, we are issuing tens of thousands of installment loans every day and our underwriting model has to provide declination reasons when the model rejects one\u0026rsquo;s loan application. That\u0026rsquo;s one of the many motivations for me to dig deeper and write this post. Model interpretability is a big field in machine learning. This review is never met to exhaust every study, but to serve as a starting point.\n Interpretable Models Lipton (2017) summarized the properties of an interpretable model in a theoretical review paper, \u0026ldquo;The mythos of model interpretability\u0026rdquo;: A human can repeat (\u0026ldquo;simulatability\u0026rdquo;) the computation process with a full understanding of the algorithm (\u0026ldquo;algorithmic transparency\u0026rdquo;) and every individual part of the model owns an intuitive explanation (\u0026ldquo;decomposability\u0026rdquo;).\nMany classic models have relatively simpler formation and naturally, come with a model-specific interpretation method. Meanwhile, new tools are being developed to help create better interpretable models (Been, Khanna, \u0026amp; Koyejo, 2016; Lakkaraju, Bach \u0026amp; Leskovec, 2016).\nRegression A general form of a linear regression model is:\n$$ y = w_0 + w_1 x_1 + w_2 x_2 + … + w_n x_n $$\nThe coefficients describe the change of the response triggered by one unit increase of the independent variables. The coefficients are not comparable directly unless the features have been standardized (check sklearn.preprocessing.StandardScalar and RobustScaler), since one unit of different features can refer to very different things. Without standardization, the product $w_i \\dot x_i$ can be used to quantify one feature\u0026rsquo;s contribution to the response.\nNaive Bayes Naive Bayes is named as \u0026ldquo;Naive\u0026rdquo; because it works on a very simplified assumption that features are independent of each other and each contributes to the output independently.\nGiven a feature vector $\\mathbf{x} = [x_1, x_2, \\dots, x_n]$ and a class label $c \\in \\{1, 2, \\dots, C\\}$, the probability of this data point belonging to this class is:\n $$ \\begin{aligned} p(c | x_1, x_2, \\dots, x_n) \u0026\\propto p(c, x_1, x_2, \\dots, x_n)\\\\ \u0026\\propto p(c) p(x_1 | c) p(x_2 | c) \\dots p(x_n | c)\\\\ \u0026\\propto p(c) \\prod_{i=1}^n p(x_i | c). \\end{aligned} $$ The Naive Bayes classifier is then defined as:\n$$ \\hat{y} = \\arg\\max_{c \\in 1, \\dots, C} p(c) \\prod_{i=1}^n p(x_i | c) $$\nBecause the model has learned the prior $p(x_i \\vert c)$ during the training, the contribution of an individual feature value can be easily measured by the posterior, $p(c \\vert x_i) = p(c)p(x_i \\vert c) / p(x_i)$.\nDecision Tree/Decision Lists Decision lists are a set of boolean functions, usually constructed by the syntax like if... then... else.... The if-condition contains a function involving one or multiple features and a boolean output. Decision lists are born with good interpretability and can be visualized in a tree structure. Many research on decision lists is driven by medical applications, where the interpretability is almost as crucial as the model itself.\nA few types of decision lists are briefly described below:\n Falling Rule Lists (FRL) (Wang and Rudin, 2015) has fully enforced monotonicity on feature values. One key point, for example in the binary classification context, is that the probability of prediction $Y=1$ associated with each rule decreases as one moves down the decision lists. Bayesian Rule List (BRL) (Letham et al., 2015) is a generative model that yields a posterior distribution over possible decision lists. Interpretable Decision Sets (IDS) (Lakkaraju, Bach \u0026amp; Leskovec, 2016) is a prediction framework to create a set of classification rules. The learning is optimized for both accuracy and interpretability simultaneously. IDS is closely related to the BETA method I\u0026rsquo;m gonna describe later for interpreting black-box models. Random Forests Weirdly enough, many people believe that the Random Forests model is a black box, which is not true. Considering that the output of random forests is the majority vote by a large number of independent decision trees and each tree is naturally interpretable.\nIt is not very hard to gauge the influence of individual features if we look into a single tree at a time. The global feature importance of random forests can be quantified by the total decrease in node impurity averaged over all trees of the ensemble (\u0026ldquo;mean decrease impurity\u0026rdquo;).\nFor one instance, because the decision paths in all the trees are well tracked, we can use the difference between the mean value of data points in a parent node between that of a child node to approximate the contribution of this split. Read more in this series of blog posts: Interpreting Random Forests.\nInterpreting Black-Box Models A lot of models are not designed to be interpretable. Approaches to explaining a black-box model aim to extract information from the trained model to justify its prediction outcome, without knowing how the model works in details. To keep the interpretation process independent from the model implementation is good for real-world applications: Even when the base model is being constantly upgraded and refined, the interpretation engine built on top would not worry about the changes.\nWithout the concern of keeping the model transparent and interpretable, we can endow the model with greater power of expressivity by adding more parameters and nonlinearity computation. That\u0026rsquo;s how deep neural networks become successful in tasks involving rich inputs.\nThere is no hard requirement on how the explanation should be presented, but the primary goal is mainly to answer: Can I trust this model? When we rely on the model to make a critical or life-and-death decision, we have to make sure the model is trustworthy ahead of time.\nThe interpretation framework should balance between two goals:\n Fidelity: the prediction produced by an explanation should agree with the original model as much as possible. Interpretability: the explanation should be simple enough to be human-understandable. Side Notes: The next three methods are designed for local interpretation.\n Prediction Decomposition Robnik-Sikonja and Kononenko (2008) proposed to explain the model prediction for one instance by measuring the difference between the original prediction and the one made with omitting a set of features.\nLet\u0026rsquo;s say we need to generate an explanation for a classification model $f: \\mathbf{X} \\rightarrow \\mathbf{Y}$. Given a data point $x \\in X$ which consists of $a$ individual values of attribute $A_i$, $i = 1, \\dots, a$, and is labeled with class $y \\in Y$. The prediction difference is quantified by computing the difference between the model predicted probabilities with or without knowing $A_i$:\n$$ \\text{probDiff}_i (y | x) = p(y| x) - p(y | x \\backslash A_i) $$\n(The paper also discussed on using the odds ratio or the entropy-based information metric to quantify the prediction difference.)\nProblem: If the target model outputs a probability, then great, getting $ p(y \\vert x) $ is straightforward. Otherwise, the model prediction has to run through an appropriate post-modeling calibration to translate the prediction score into probabilities. This calibration layer is another piece of complication.\nAnother problem: If we generate $x \\backslash A_i$ by replacing $A_i$ with a missing value (like None, NaN, etc.), we have to rely on the model\u0026rsquo;s internal mechanism for missing value imputation. A model which replaces these missing cases with the median should have output very different from a model which imputes a special placeholder. One solution as presented in the paper is to replace $A_i$ with all possible values of this feature and then sum up the prediction weighted by how likely each value shows in the data:\n $$ \\begin{aligned} p(y \\vert x \\backslash A_i) \u0026= \\sum_{s=1}^{m_i} p(A_i=a_s \\vert x \\backslash A_i) p(y \\vert x \\leftarrow A_i=a_s) \\\\ \u0026\\approx \\sum_{s=1}^{m_i} p(A_i=a_s) p(y \\vert x \\leftarrow A_i=a_s) \\end{aligned} $$ Where $p(y \\vert x \\leftarrow A_i=a_s)$ is the probability of getting label $y$ if we replace the feature $A_i$ with value $a_s$ in the feature vector of $x$. There are $m_i$ unique values of $A_i$ in the training set.\nWith the help of the measures of prediction difference when omitting known features, we can decompose the impact of each individual feature on the prediction.\nFig. 1. Explanations for a SVM model predicting the survival of one male adult first-class passenger in the Titanic dataset. The information difference is very similar to the probability difference, but it measures the amount of information necessary to find out $y$ is true for the given instance without the knowledge of $A\\_i$: $\\text{infDiff}\\_i (y|x) = \\log\\_2 p(y|x) - \\log\\_2 p(y|x \\backslash A\\_i)$. Explanations for particular instance are depicted with dark bars. The light shaded half-height bars are average positive and negative explanations for given attributes' values. In this case, being a male adult makes it very less likely to survive; the class level does not impact as much. Local Gradient Explanation Vector This method (Baehrens, et al. 2010) is able to explain the local decision taken by arbitrary nonlinear classification algorithms, using the local gradients that characterize how a data point has to be moved to change its predicted label.\nLet\u0026rsquo;s say, we have a Bayes Classifier which is trained on the data set $X$ and outputs probabilities over the class labels $Y$, $p(Y=y \\vert X=x)$. And one class label $y$ is drawn from the class label pool, $\\{1, 2, \\dots, C\\}$. This Bayes classifier is constructed as:\n$$ f^{*}(x) = \\arg \\min_{c \\in \\{1, \\dots, C\\}} p(Y \\neq c \\vert X = x) $$\nThe local explanation vector is defined as the derivative of the probability prediction function at the test point $x = x_0$. A large entry in this vector highlights a feature with a big influence on the model decision; A positive sign indicates that increasing the feature would lower the probability of $x_0$ assigned to $f^{*}(x_0)$.\nHowever, this approach requires the model output to be a probability (similar to the \u0026ldquo;Prediction Decomposition\u0026rdquo; method above). What if the original model (labelled as $f$) is not calibrated to yield probabilities? As suggested by the paper, we can approximate $f$ by another classifier in a form that resembles the Bayes classifier $f^{*}$:\n(1) Apply Parzen window to the training data to estimate the weighted class densities:\n$$ \\hat{p}_{\\sigma}(x, y=c) = \\frac{1}{n} \\sum_{i \\in I_c} k_{\\sigma} (x - x_i) $$\nWhere $I_c$ is the index set containing the indices of data points assigned to class $c$ by the model $f$, $I_c = \\{i \\vert f(x_i) = c\\}$. $k_{\\sigma}$ is a kernel function. Gaussian kernel is a popular one among many candidates.\n(2) Then, apply the Bayes' rule to approximate the probability $p(Y=c \\vert X=x)$ for all classes:\n $$ \\begin{aligned} \\hat{p}_{\\sigma}(y=c | x) \u0026= \\frac{\\hat{p}_{\\sigma}(x, y=c)}{\\hat{p}_{\\sigma}(x, y=c) + \\hat{p}_{\\sigma}(x, y \\neq c)} \\\\ \u0026\\approx \\frac{\\sum_{i \\in I_c} k_{\\sigma} (x - x_i)}{\\sum_i k_{\\sigma} (x - x_i)} \\end{aligned} $$ (3) The final estimated Bayes classifier takes the form:\n$$ \\hat{f}_{\\sigma} = \\arg\\min_{c \\in \\{1, \\dots, C\\}} \\hat{p}_{\\sigma}(y \\neq c \\vert x) $$\nNoted that we can generate the labeled data with the original model $f$, as much as we want, not restricted by the size of the training data. The hyperparameter $\\sigma$ is selected to optimize the chances of $\\hat{f}_{\\sigma}(x) = f(x)$ to achieve high fidelity.\nFig. 2. An example of how local gradient explanation vector is applied on simple object classification with Gaussian Processes Classifier (GPC). The GPC model outputs the probability by nature. (a) shows the training points and their labels in red (positive 1) and blue (negative -1). (b) illustrates a probability function for the positive class. (c-d) shows the local gradients and the directions of the local explanation vectors. Side notes: As you can see both the methods above require the model prediction to be a probability. Calibration of the model output adds another layer of complication.\n LIME (Local Interpretable Model-Agnostic Explanations) LIME, short for local interpretable model-agnostic explanation, can approximate a black-box model locally in the neighborhood of the prediction we are interested (Ribeiro, Singh, \u0026amp; Guestrin, 2016).\nSame as above, let us label the black-box model as $f$. LIME presents the following steps:\n(1) Convert the dataset into interpretable data representation: $x \\Rightarrow x_b$.\n Text classifier: a binary vector indicating the presence or absence of a word Image classifier: a binary vector indicating the presence or absence of a contiguous patch of similar pixels (super-pixel). Fig. 3. An example of converting an image into interpretable data representation. (Image source: www.oreilly.com/learning/introduction-to-local-interpretable-model-agnostic-explanations-lime) (2) Given a prediction $f(x)$ with the corresponding interpretable data representation $x_b$, let us sample instances around $x_b$ by drawing nonzero elements of $x_b$ uniformly at random where the number of such draws is also uniformly sampled. This process generates a perturbed sample $z_b$ which contains a fraction of nonzero elements of $x_b$.\nThen we recover $z_b$ back into the original input $z$ and get a prediction score $f(z)$ by the target model.\nUse many such sampled data points $z_b \\in \\mathcal{Z}_b$ and their model predictions, we can learn an explanation model (such as in a form as simple as a regression) with local fidelity. The sampled data points are weighted differently based on how close they are to $x_b$. The paper used a lasso regression with preprocessing to select top $k$ most significant features beforehand, named \u0026ldquo;K-LASSO\u0026rdquo;.\nFig. 4. The pink and blue areas are two classes predicted by the black-box model $f$. the big red cross is the point to be explained and other smaller crosses (predicted as pink by $f$) and dots (predicted as blue by $f$) are sampled data points. Even though the model can be very complicated, we are still able to learn a local explanation model as simple as the grey dash line. (Image source: homes.cs.washington.edu/~marcotcr/blog/lime) Examining whether the explanation makes sense can directly decide whether the model is trustworthy because sometimes the model can pick up spurious correlation or generalization. One interesting example in the paper is to apply LIME on an SVM text classifier for differentiating \u0026ldquo;Christianity\u0026rdquo; from \u0026ldquo;Atheism\u0026rdquo;. The model achieved a pretty good accuracy (94% on held-out testing set!), but the LIME explanation demonstrated that decisions were made by very arbitrary reasons, such as counting the words \u0026ldquo;re\u0026rdquo;, \u0026ldquo;posting\u0026rdquo; and \u0026ldquo;host\u0026rdquo; which have no connection with neither \u0026ldquo;Christianity\u0026rdquo; nor \u0026ldquo;Atheism\u0026rdquo; directly. After such a diagnosis, we learned that even the model gives us a nice accuracy, it cannot be trusted. It also shed lights on ways to improve the model, such as better preprocessing on the text.\nFig. 5. Illustration of how to use LIME on an image classifier. (Image source: www.oreilly.com/learning/introduction-to-local-interpretable-model-agnostic-explanations-lime) For more detailed non-paper explanation, please read this blog post by the author. A very nice read.\n Side Notes: Interpreting a model locally is supposed to be easier than interpreting the model globally, but harder to maintain (thinking about the curse of dimensionality). Methods described below aim to explain the behavior of a model as a whole. However, the global approach is unable to capture the fine-grained interpretation, such as a feature might be important in this region but not at all in another.\n Feature Selection Essentially all the classic feature selection methods (Yang and Pedersen, 1997; Guyon and Elisseeff, 2003) can be considered as ways to explain a model globally. Feature selection methods decompose the contribution of multiple features so that we can explain the overall model output by individual feature impact.\nThere are a ton of resources on feature selection so I would skip the topic in this post.\nBETA (Black Box Explanation through Transparent Approximations) BETA, short for black box explanation through transparent approximations, is closely connected to Interpretable Decision Sets (Lakkaraju, Bach \u0026amp; Leskovec, 2016). BETA learns a compact two-level decision set in which each rule explains part of the model behavior unambiguously.\nThe authors proposed an novel objective function so that the learning process is optimized for high fidelity (high agreement between explanation and the model), low unambiguity (little overlaps between decision rules in the explanation), and high interpretability (the explanation decision set is lightweight and small). These aspects are combined into one objection function to optimize for.\nFig. 6. Measures for desiderata of a good model explanation: fidelity, unambiguity, and interpretability. Given the target model is $\\mathcal{B}$, its explanation is a two level decision set $\\Re$ containing a set of rules ${(q\\_1, s\\_1, c\\_1), \\dots, (q\\_M, s\\_M, c\\_M)}$, where $q\\_i$ and $s\\_i$ are conjunctions of predicates of the form (feature, operator, value) and $c\\_i$ is a class label. Check the paper for more details. (Image source: arxiv.org/abs/1707.01154) Explainable Artificial Intelligence I borrow the name of this section from the DARPA project \u0026ldquo;Explainable Artificial Intelligence\u0026rdquo;. This Explainable AI (XAI) program aims to develop more interpretable models and to enable human to understand, appropriately trust, and effectively manage the emerging generation of artificially intelligent techniques.\nWith the progress of the deep learning applications, people start worrying about that we may never know even if the model goes bad. The complicated structure, the large number of learnable parameters, the nonlinear mathematical operations and some intriguing properties (Szegedy et al., 2014) lead to the un-interpretability of deep neural networks, creating a true black-box. Although the power of deep learning is originated from this complexity \u0026mdash; more flexible to capture rich and intricate patterns in the real-world data.\nStudies on adversarial examples (OpenAI Blog: Robust Adversarial Examples, Attacking Machine Learning with Adversarial Examples, Goodfellow, Shlens \u0026amp; Szegedy, 2015; Nguyen, Yosinski, \u0026amp; Clune, 2015) raise the alarm on the robustness and safety of AI applications. Sometimes the models could show unintended, unexpected and unpredictable behavior and we have no fast/good strategy to tell why.\nFig. 7. Illustrations of adversarial examples. (a-d) are adversarial images that are generated by adding human-imperceptible noises onto original images (Szegedy et al., 2013). A well-trained neural network model can successfully classify original ones but fail adversarial ones. (e-h) are patterns that are generated (Nguyen, Yosinski \u0026 Clune, 2015). A well-trained neural network model labels them into (e) school bus, (f) guitar, (g) peacock and (h) Pekinese respectively. (Image source: Wang, Raj \u0026 Xing, 2017) Nvidia recently developed a method to visualize the most important pixel points in their self-driving cars' decisioning process. The visualization provides insights on how AI thinks and what the system relies on while operating the car. If what the AI believes to be important agrees with how human make similar decisions, we can naturally gain more confidence in the black-box model.\nMany exciting news and findings are happening in this evolving field every day. Hope my post can give you some pointers and encourage you to investigate more into this topic :)\n Cited as:\n@article{weng2017gan, title = \u0026quot;How to Explain the Prediction of a Machine Learning Model?\u0026quot;, author = \u0026quot;Weng, Lilian\u0026quot;, journal = \u0026quot;lilianweng.github.io\u0026quot;, year = \u0026quot;2017\u0026quot;, url = \u0026quot;https://lilianweng.github.io/posts/2017-08-01-interpretation/\u0026quot; } References [1] Zachary C. Lipton. \u0026ldquo;The mythos of model interpretability.\u0026quot; arXiv preprint arXiv:1606.03490 (2016).\n[2] Been Kim, Rajiv Khanna, and Oluwasanmi O. Koyejo. \u0026ldquo;Examples are not enough, learn to criticize! criticism for interpretability.\u0026rdquo; Advances in Neural Information Processing Systems. 2016.\n[3] Himabindu Lakkaraju, Stephen H. Bach, and Jure Leskovec. \u0026ldquo;Interpretable decision sets: A joint framework for description and prediction.\u0026quot; Proc. 22nd ACM SIGKDD Intl. Conf. on Knowledge Discovery and Data Mining. ACM, 2016.\n[4] Robnik-Šikonja, Marko, and Igor Kononenko. \u0026ldquo;Explaining classifications for individual instances.\u0026quot; IEEE Transactions on Knowledge and Data Engineering 20.5 (2008): 589-600.\n[5] Baehrens, David, et al. \u0026ldquo;How to explain individual classification decisions.\u0026quot; Journal of Machine Learning Research 11.Jun (2010): 1803-1831.\n[6] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. \u0026ldquo;Why should I trust you?: Explaining the predictions of any classifier.\u0026quot; Proc. 22nd ACM SIGKDD Intl. Conf. on Knowledge Discovery and Data Mining. ACM, 2016.\n[7] Yiming Yang, and Jan O. Pedersen. \u0026ldquo;A comparative study on feature selection in text categorization.\u0026quot; Intl. Conf. on Machine Learning. Vol. 97. 1997.\n[8] Isabelle Guyon, and André Elisseeff. \u0026ldquo;An introduction to variable and feature selection.\u0026quot; Journal of Machine Learning Research 3.Mar (2003): 1157-1182.\n[9] Ian J. Goodfellow, Jonathon Shlens, and Christian Szegedy. \u0026ldquo;Explaining and harnessing adversarial examples.\u0026quot; ICLR 2015.\n[10] Christian Szegedy, Wojciech Zaremba, Ilya Sutskever, Joan Bruna, Dumitru Erhan, Ian Goodfellow, Rob Fergus. \u0026ldquo;Intriguing properties of neural networks.\u0026quot; Intl. Conf. on Learning Representations (2014)\n[11] Nguyen, Anh, Jason Yosinski, and Jeff Clune. \u0026ldquo;Deep neural networks are easily fooled: High confidence predictions for unrecognizable images.\u0026quot; Proc. IEEE Conference on Computer Vision and Pattern Recognition. 2015.\n[12] Benjamin Letham, Cynthia Rudin, Tyler H. McCormick, and David Madigan. \u0026ldquo;Interpretable classifiers using rules and Bayesian analysis: Building a better stroke prediction model.\u0026quot; The Annals of Applied Statistics 9, No. 3 (2015): 1350-1371.\n[13] Haohan Wang, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800 (2017).\n[14] OpenAI Blog: Robust Adversarial Examples\n[15] Attacking Machine Learning with Adversarial Examples\n[16] Reading an AI Car’s Mind: How NVIDIA’s Neural Net Makes Decisions\n","permalink":"https://lilianweng.github.io/posts/2017-08-01-interpretation/","summary":"The machine learning models have started penetrating into critical areas like health care, justice systems, and financial industry. Thus to figure out how the models make the decisions and make sure the decisioning process is aligned with the ethnic requirements or legal regulations becomes a necessity.\nMeanwhile, the rapid growth of deep learning models pushes the requirement of interpreting complicated models further. People are eager to apply the power of AI fully on key aspects of everyday life.","title":"How to Explain the Prediction of a Machine Learning Model?"},{"content":"In the Part 2 tutorial, I would like to continue the topic on stock price prediction and to endow the recurrent neural network that I have built in Part 1 with the capability of responding to multiple stocks. In order to distinguish the patterns associated with different price sequences, I use the stock symbol embedding vectors as part of the input.\n Dataset During the search, I found this library for querying Yahoo! Finance API. It would be very useful if Yahoo hasn’t shut down the historical data fetch API. You may find it useful for querying other information though. Here I pick the Google Finance link, among a couple of free data sources for downloading historical stock prices.\nThe data fetch code can be written as simple as:\nimport urllib2 from datetime import datetime BASE_URL = \u0026#34;https://www.google.com/finance/historical?\u0026#34; \u0026#34;output=csv\u0026amp;q={0}\u0026amp;startdate=Jan+1%2C+1980\u0026amp;enddate={1}\u0026#34; symbol_url = BASE_URL.format( urllib2.quote(\u0026#39;GOOG\u0026#39;), # Replace with any stock you are interested. urllib2.quote(datetime.now().strftime(\u0026#34;%b+%d,+%Y\u0026#34;), \u0026#39;+\u0026#39;) ) When fetching the content, remember to add try-catch wrapper in case the link fails or the provided stock symbol is not valid.\ntry: f = urllib2.urlopen(symbol_url) with open(\u0026#34;GOOG.csv\u0026#34;, \u0026#39;w\u0026#39;) as fin: print \u0026gt;\u0026gt; fin, f.read() except urllib2.HTTPError: print \u0026#34;Fetching Failed: {}\u0026#34;.format(symbol_url) The full working data fetcher code is available here.\nModel Construction The model is expected to learn the price sequences of different stocks in time. Due to the different underlying patterns, I would like to tell the model which stock it is dealing with explicitly. Embedding is more favored than one-hot encoding, because:\n Given that the train set includes $N$ stocks, the one-hot encoding would introduce $N$ (or $N-1$) additional sparse feature dimensions. Once each stock symbol is mapped onto a much smaller embedding vector of length $k$, $k \\ll N$, we end up with a much more compressed representation and smaller dataset to take care of. Since embedding vectors are variables to learn. Similar stocks could be associated with similar embeddings and help the prediction of each others, such as \u0026ldquo;GOOG\u0026rdquo; and \u0026ldquo;GOOGL\u0026rdquo; which you will see in Fig. 5. later. In the recurrent neural network, at one time step $t$, the input vector contains input_size (labelled as $w$) daily price values of $i$-th stock, $(p_{i, tw}, p_{i, tw+1}, \\dots, p_{i, (t+1)w-1})$. The stock symbol is uniquely mapped to a vector of length embedding_size (labelled as $k$), $(e_{i,0}, e_{i,1}, \\dots, e_{i,k})$. As illustrated in Fig. 1., the price vector is concatenated with the embedding vector and then fed into the LSTM cell.\nAnother alternative is to concatenate the embedding vectors with the last state of the LSTM cell and learn new weights $W$ and bias $b$ in the output layer. However, in this way, the LSTM cell cannot tell apart prices of one stock from another and its power would be largely restrained. Thus I decided to go with the former approach.\nFig. 1. The architecture of the stock price prediction RNN model with stock symbol embeddings. Two new configuration settings are added into RNNConfig:\n embedding_size controls the size of each embedding vector; stock_count refers to the number of unique stocks in the dataset. Together they define the size of the embedding matrix, for which the model has to learn embedding_size $\\times$ stock_count additional variables compared to the model in Part 1.\nclass RNNConfig(): # ... old ones embedding_size = 3 stock_count = 50 Define the Graph \u0026mdash; Let\u0026rsquo;s start going through some code \u0026mdash;\n(1) As demonstrated in tutorial Part 1: Define the Graph, let us define a tf.Graph() named lstm_graph and a set of tensors to hold input data, inputs, targets, and learning_rate in the same way. One more placeholder to define is a list of stock symbols associated with the input prices. Stock symbols have been mapped to unique integers beforehand with label encoding.\n# Mapped to an integer. one label refers to one stock symbol. stock_labels = tf.placeholder(tf.int32, [None, 1]) (2) Then we need to set up an embedding matrix to play as a lookup table, containing the embedding vectors of all the stocks. The matrix is initialized with random numbers in the interval [-1, 1] and gets updated during training.\n# NOTE: config = RNNConfig() and it defines hyperparameters. # Convert the integer labels to numeric embedding vectors. embedding_matrix = tf.Variable( tf.random_uniform([config.stock_count, config.embedding_size], -1.0, 1.0) ) (3) Repeat the stock labels num_steps times to match the unfolded version of RNN and the shape of inputs tensor during training. The transformation operation tf.tile receives a base tensor and creates a new tensor by replicating its certain dimensions multiples times; precisely the $i$-th dimension of the input tensor gets multiplied by multiples[i] times. For example, if the stock_labels is [[0], [0], [2], [1]] tiling it by [1, 5] produces [[0 0 0 0 0], [0 0 0 0 0], [2 2 2 2 2], [1 1 1 1 1]].\nstacked_stock_labels = tf.tile(stock_labels, multiples=[1, config.num_steps]) (4) Then we map the symbols to embedding vectors according to the lookup table embedding_matrix.\n# stock_label_embeds.get_shape() = (?, num_steps, embedding_size). stock_label_embeds = tf.nn.embedding_lookup(embedding_matrix, stacked_stock_labels) (5) Finally, combine the price values with the embedding vectors. The operation tf.concat concatenates a list of tensors along the dimension axis. In our case, we want to keep the batch size and the number of steps unchanged, but only extend the input vector of length input_size to include embedding features.\n# inputs.get_shape() = (?, num_steps, input_size) # stock_label_embeds.get_shape() = (?, num_steps, embedding_size) # inputs_with_embeds.get_shape() = (?, num_steps, input_size + embedding_size) inputs_with_embeds = tf.concat([inputs, stock_label_embeds], axis=2) The rest of code runs the dynamic RNN, extracts the last state of the LSTM cell, and handles weights and bias in the output layer. See Part 1: Define the Graph for the details.\nTraining Session Please read Part 1: Start Training Session if you haven\u0026rsquo;t for how to run a training session in Tensorflow.\nBefore feeding the data into the graph, the stock symbols should be transformed to unique integers with label encoding.\nfrom sklearn.preprocessing import LabelEncoder label_encoder = LabelEncoder() label_encoder.fit(list_of_symbols) The train/test split ratio remains same, 90% for training and 10% for testing, for every individual stock.\nVisualize the Graph After the graph is defined in code, let us check the visualization in Tensorboard to make sure that components are constructed correctly. Essentially it looks very much like our architecture illustration in Fig. 1.\nFig. 2. Tensorboard visualization of the graph defined above. Two modules, \"train\" and \"save\", have been removed from the main graph. Other than presenting the graph structure or tracking the variables in time, Tensorboard also supports embeddings visualization. In order to communicate the embedding values to Tensorboard, we need to add proper tracking in the training logs.\n(0) In my embedding visualization, I want to color each stock with its industry sector. This metadata should stored in a csv file. The file has two columns, the stock symbol and the industry sector. It does not matter whether the csv file has header, but the order of the listed stocks must be consistent with label_encoder.classes_.\nimport csv embedding_metadata_path = os.path.join(your_log_file_folder, \u0026#39;metadata.csv\u0026#39;) with open(embedding_metadata_path, \u0026#39;w\u0026#39;) as fout: csv_writer = csv.writer(fout) # write the content into the csv file. # for example, csv_writer.writerows([\u0026#34;GOOG\u0026#34;, \u0026#34;information_technology\u0026#34;]) (1) Set up the summary writer first within the training tf.Session.\nfrom tensorflow.contrib.tensorboard.plugins import projector with tf.Session(graph=lstm_graph) as sess: summary_writer = tf.summary.FileWriter(your_log_file_folder) summary_writer.add_graph(sess.graph) (2) Add the tensor embedding_matrix defined in our graph lstm_graph into the projector config variable and attach the metadata csv file.\nprojector_config = projector.ProjectorConfig() # You can add multiple embeddings. Here we add only one. added_embedding = projector_config.embeddings.add() added_embedding.tensor_name = embedding_matrix.name # Link this tensor to its metadata file. added_embedding.metadata_path = embedding_metadata_path (3) This line creates a file projector_config.pbtxt in the folder your_log_file_folder. TensorBoard will read this file during startup.\nprojector.visualize_embeddings(summary_writer, projector_config) Results The model is trained with top 50 stocks with largest market values in the S\u0026amp;P 500 index.\n(Run the following command within github.com/lilianweng/stock-rnn)\npython main.py --stock_count=50 --embed_size=3 --input_size=3 --max_epoch=50 --train And the following configuration is used:\nstock_count = 100 input_size = 3 embed_size = 3 num_steps = 30 lstm_size = 256 num_layers = 1 max_epoch = 50 keep_prob = 0.8 batch_size = 64 init_learning_rate = 0.05 learning_rate_decay = 0.99 init_epoch = 5 Price Prediction As a brief overview of the prediction quality, Fig. 3 plots the predictions for test data of \u0026ldquo;KO\u0026rdquo;, \u0026ldquo;AAPL\u0026rdquo;, \u0026ldquo;GOOG\u0026rdquo; and \u0026ldquo;NFLX\u0026rdquo;. The overall trends matched up between the true values and the predictions. Considering how the prediction task is designed, the model relies on all the historical data points to predict only next 5 (input_size) days. With a small input_size, the model does not need to worry about the long-term growth curve. Once we increase input_size, the prediction would be much harder.\nFig. 3. True and predicted stock prices of AAPL, MSFT and GOOG in the test set. The prices are normalized across consecutive prediction sliding windows (See Part 1: Normalization. The y-axis values get multiplied by 5 for a better comparison between true and predicted trends. Embedding Visualization One common technique to visualize the clusters in embedding space is t-SNE (Maaten and Hinton, 2008), which is well supported in Tensorboard. t-SNE, short for “t-Distributed Stochastic Neighbor Embedding, is a variation of Stochastic Neighbor Embedding (Hinton and Roweis, 2002), but with a modified cost function that is easier to optimize.\n Similar to SNE, t-SNE first converts the high-dimensional Euclidean distances between data points into conditional probabilities that represent similarities. t-SNE defines a similar probability distribution over the data points in the low-dimensional space, and it minimizes the Kullback–Leibler divergence between the two distributions with respect to the locations of the points on the map. Check this post for how to adjust the parameters, Perplexity and learning rate (epsilon), in t-SNE visualization.\nFig. 4. Visualization of the stock embeddings using t-SNE. Each label is colored based on the stock industry sector. We have 5 clusters. Interstingly, GOOG, GOOGL and FB belong to the same cluster, while AMZN and AAPL stay in another. In the embedding space, we can measure the similarity between two stocks by examining the similarity between their embedding vectors. For example, GOOG is mostly similar to GOOGL in the learned embeddings (See Fig. 5).\nFig. 5. \"GOOG\" is clicked in the embedding visualization graph and top 20 similar neighbors are highlighted with colors from dark to light as the similarity decreases. Known Problems The prediction values get diminished and flatten quite a lot as the training goes. That\u0026rsquo;s why I multiplied the absolute values by a constant to make the trend is more visible in Fig. 3., as I\u0026rsquo;m more curious about whether the prediction on the up-or-down direction right. However, there must be a reason for the diminishing prediction value problem. Potentially rather than using simple MSE as the loss, we can adopt another form of loss function to penalize more when the direction is predicted wrong. The loss function decreases fast at the beginning, but it suffers from occasional value explosion (a sudden peak happens and then goes back immediately). I suspect it is related to the form of loss function too. A updated and smarter loss function might be able to resolve the issue. The full code in this tutorial is available in github.com/lilianweng/stock-rnn.\n","permalink":"https://lilianweng.github.io/posts/2017-07-22-stock-rnn-part-2/","summary":"In the Part 2 tutorial, I would like to continue the topic on stock price prediction and to endow the recurrent neural network that I have built in Part 1 with the capability of responding to multiple stocks. In order to distinguish the patterns associated with different price sequences, I use the stock symbol embedding vectors as part of the input.\n Dataset During the search, I found this library for querying Yahoo!","title":"Predict Stock Prices Using RNN: Part 2"},{"content":"This is a tutorial for how to build a recurrent neural network using Tensorflow to predict stock market prices. The full working code is available in github.com/lilianweng/stock-rnn. If you don\u0026rsquo;t know what is recurrent neural network or LSTM cell, feel free to check my previous post.\n One thing I would like to emphasize that because my motivation for writing this post is more on demonstrating how to build and train an RNN model in Tensorflow and less on solve the stock prediction problem, I didn\u0026rsquo;t try hard on improving the prediction outcomes. You are more than welcome to take my code as a reference point and add more stock prediction related ideas to improve it. Enjoy!\n Overview of Existing Tutorials There are many tutorials on the Internet, like:\n A noob\u0026rsquo;s guide to implementing RNN-LSTM using Tensorflow TensorFlow RNN Tutorial LSTM by Example using Tensorflow How to build a Recurrent Neural Network in TensorFlow RNNs in Tensorflow, a Practical Guide and Undocumented Features Sequence prediction using recurrent neural networks(LSTM) with TensorFlow Anyone Can Learn To Code an LSTM-RNN in Python How to do time series prediction using RNNs, TensorFlow and Cloud ML Engine Despite all these existing tutorials, I still want to write a new one mainly for three reasons:\n Early tutorials cannot cope with the new version any more, as Tensorflow is still under development and changes on API interfaces are being made fast. Many tutorials use synthetic data in the examples. Well, I would like to play with the real world data. Some tutorials assume that you have known something about Tensorflow API beforehand, which makes the reading a bit difficult. After reading a bunch of examples, I would like to suggest taking the official example on Penn Tree Bank (PTB) dataset as your starting point. The PTB example showcases a RNN model in a pretty and modular design pattern, but it might prevent you from easily understanding the model structure. Hence, here I will build up the graph in a very straightforward manner.\nThe Goal I will explain how to build an RNN model with LSTM cells to predict the prices of S\u0026amp;P500 index. The dataset can be downloaded from Yahoo! Finance ^GSPC. In the following example, I used S\u0026amp;P 500 data from Jan 3, 1950 (the maximum date that Yahoo! Finance is able to trace back to) to Jun 23, 2017. The dataset provides several price points per day. For simplicity, we will only use the daily close prices for prediction. Meanwhile, I will demonstrate how to use TensorBoard for easily debugging and model tracking.\nAs a quick recap: the recurrent neural network (RNN) is a type of artificial neural network with self-loop in its hidden layer(s), which enables RNN to use the previous state of the hidden neuron(s) to learn the current state given the new input. RNN is good at processing sequential data. Long short-term memory (LSTM) cell is a specially designed working unit that helps RNN better memorize the long-term context.\nFor more information in depth, please read my previous post or this awesome post.\nData Preparation The stock prices is a time series of length $N$, defined as $p_0, p_1, \\dots, p_{N-1}$ in which $p_i$ is the close price on day $i$, $0 \\le i \u0026lt; N$. Imagine that we have a sliding window of a fixed size $w$ (later, we refer to this as input_size) and every time we move the window to the right by size $w$, so that there is no overlap between data in all the sliding windows.\nFig. 1. The S\u0026P 500 prices in time. We use content in one sliding windows to make prediction for the next, while there is no overlap between two consecutive windows. The RNN model we are about to build has LSTM cells as basic hidden units. We use values from the very beginning in the first sliding window $W_0$ to the window $W_t$ at time $t$:\n $$ \\begin{aligned} W_0 \u0026= (p_0, p_1, \\dots, p_{w-1}) \\\\ W_1 \u0026= (p_w, p_{w+1}, \\dots, p_{2w-1}) \\\\ \\dots \\\\ W_t \u0026= (p_{tw}, p_{tw+1}, \\dots, p_{(t+1)w-1}) \\end{aligned} $$ to predict the prices in the following window $w_{t+1}$:\n$$ W_{t+1} = (p_{(t+1)w}, p_{(t+1)w+1}, \\dots, p_{(t+2)w-1}) $$\nEssentially we try to learn an approximation function, $f(W_0, W_1, \\dots, W_t) \\approx W_{t+1}$.\nFig. 2 The unrolled version of RNN. Considering how back propagation through time (BPTT) works, we usually train RNN in a “unrolled” version so that we don\u0026rsquo;t have to do propagation computation too far back and save the training complication.\nHere is the explanation on num_steps from Tensorflow\u0026rsquo;s tutorial:\n By design, the output of a recurrent neural network (RNN) depends on arbitrarily distant inputs. Unfortunately, this makes backpropagation computation difficult. In order to make the learning process tractable, it is common practice to create an \u0026ldquo;unrolled\u0026rdquo; version of the network, which contains a fixed number (num_steps) of LSTM inputs and outputs. The model is then trained on this finite approximation of the RNN. This can be implemented by feeding inputs of length num_steps at a time and performing a backward pass after each such input block.\n The sequence of prices are first split into non-overlapped small windows. Each contains input_size numbers and each is considered as one independent input element. Then any num_steps consecutive input elements are grouped into one training input, forming an \u0026ldquo;un-rolled\u0026rdquo; version of RNN for training on Tensorfow. The corresponding label is the input element right after them.\nFor instance, if input_size=3 and num_steps=2, my first few training examples would look like:\n $$ \\begin{aligned} \\text{Input}_1 \u0026= [[p_0, p_1, p_2], [p_3, p_4, p_5]]\\quad\\text{Label}_1 = [p_6, p_7, p_8] \\\\ \\text{Input}_2 \u0026= [[p_3, p_4, p_5], [p_6, p_7, p_8]]\\quad\\text{Label}_2 = [p_9, p_{10}, p_{11}] \\\\ \\text{Input}_3 \u0026= [[p_6, p_7, p_8], [p_9, p_{10}, p_{11}]]\\quad\\text{Label}_3 = [p_{12}, p_{13}, p_{14}] \\end{aligned} $$ Here is the key part for formatting the data:\nseq = [np.array(seq[i * self.input_size: (i + 1) * self.input_size]) for i in range(len(seq) // self.input_size)] # Split into groups of `num_steps` X = np.array([seq[i: i + self.num_steps] for i in range(len(seq) - self.num_steps)]) y = np.array([seq[i + self.num_steps] for i in range(len(seq) - self.num_steps)]) The complete code of data formatting is here.\nTrain / Test Split Since we always want to predict the future, we take the latest 10% of data as the test data.\nNormalization The S\u0026amp;P 500 index increases in time, bringing about the problem that most values in the test set are out of the scale of the train set and thus the model has to predict some numbers it has never seen before. Sadly and unsurprisingly, it does a tragic job. See Fig. 3.\nFig. 3 A very sad example when the RNN model have to predict numbers out of the scale of the training data. To solve the out-of-scale issue, I normalize the prices in each sliding window. The task becomes predicting the relative change rates instead of the absolute values. In a normalized sliding window $W'_t$ at time $t$, all the values are divided by the last unknown price\u0026mdash;the last price in $W_{t-1}$:\n$$ W'_t = (\\frac{p_{tw}}{p_{tw-1}}, \\frac{p_{tw+1}}{p_{tw-1}}, \\dots, \\frac{p_{(t+1)w-1}}{p_{tw-1}}) $$\nHere is a data archive stock-data-lilianweng.tar.gz of S \u0026amp; P 500 stock prices I crawled up to Jul, 2017. Feel free to play with it :)\nModel Construction Definitions lstm_size: number of units in one LSTM layer. num_layers: number of stacked LSTM layers. keep_prob: percentage of cell units to keep in the dropout operation. init_learning_rate: the learning rate to start with. learning_rate_decay: decay ratio in later training epochs. init_epoch: number of epochs using the constant init_learning_rate. max_epoch: total number of epochs in training input_size: size of the sliding window / one training data point batch_size: number of data points to use in one mini-batch. The LSTM model has num_layers stacked LSTM layer(s) and each layer contains lstm_size number of LSTM cells. Then a dropout mask with keep probability keep_prob is applied to the output of every LSTM cell. The goal of dropout is to remove the potential strong dependency on one dimension so as to prevent overfitting.\nThe training requires max_epoch epochs in total; an epoch is a single full pass of all the training data points. In one epoch, the training data points are split into mini-batches of size batch_size. We send one mini-batch to the model for one BPTT learning. The learning rate is set to init_learning_rate during the first init_epoch epochs and then decay by $\\times$ learning_rate_decay during every succeeding epoch.\n# Configuration is wrapped in one object for easy tracking and passing. class RNNConfig(): input_size=1 num_steps=30 lstm_size=128 num_layers=1 keep_prob=0.8 batch_size = 64 init_learning_rate = 0.001 learning_rate_decay = 0.99 init_epoch = 5 max_epoch = 50 config = RNNConfig() Define Graph A tf.Graph is not attached to any real data. It defines the flow of how to process the data and how to run the computation. Later, this graph can be fed with data within a tf.session and at this moment the computation happens for real.\n\u0026mdash; Let\u0026rsquo;s start going through some code \u0026mdash;\n(1) Initialize a new graph first.\nimport tensorflow as tf tf.reset_default_graph() lstm_graph = tf.Graph() (2) How the graph works should be defined within its scope.\nwith lstm_graph.as_default(): (3) Define the data required for computation. Here we need three input variables, all defined as tf.placeholder because we don\u0026rsquo;t know what they are at the graph construction stage.\n inputs: the training data X, a tensor of shape (# data examples, num_steps, input_size); the number of data examples is unknown, so it is None. In our case, it would be batch_size in training session. Check the input format example if confused. targets: the training label y, a tensor of shape (# data examples, input_size). learning_rate: a simple float. # Dimension = ( # number of data examples, # number of input in one computation step, # number of numbers in one input # ) # We don\u0026#39;t know the number of examples beforehand, so it is None. inputs = tf.placeholder(tf.float32, [None, config.num_steps, config.input_size]) targets = tf.placeholder(tf.float32, [None, config.input_size]) learning_rate = tf.placeholder(tf.float32, None) (4) This function returns one LSTMCell with or without dropout operation.\ndef _create_one_cell(): return tf.contrib.rnn.LSTMCell(config.lstm_size, state_is_tuple=True) if config.keep_prob \u0026lt; 1.0: return tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=config.keep_prob) (5) Let\u0026rsquo;s stack the cells into multiple layers if needed. MultiRNNCell helps connect sequentially multiple simple cells to compose one cell.\ncell = tf.contrib.rnn.MultiRNNCell( [_create_one_cell() for _ in range(config.num_layers)], state_is_tuple=True ) if config.num_layers \u0026gt; 1 else _create_one_cell() (6) tf.nn.dynamic_rnn constructs a recurrent neural network specified by cell (RNNCell). It returns a pair of (model outpus, state), where the outputs val is of size (batch_size, num_steps, lstm_size) by default. The state refers to the current state of the LSTM cell, not consumed here.\nval, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32) (7) tf.transpose converts the outputs from the dimension (batch_size, num_steps, lstm_size) to (num_steps, batch_size, lstm_size). Then the last output is picked.\n# Before transpose, val.get_shape() = (batch_size, num_steps, lstm_size) # After transpose, val.get_shape() = (num_steps, batch_size, lstm_size) val = tf.transpose(val, [1, 0, 2]) # last.get_shape() = (batch_size, lstm_size) last = tf.gather(val, int(val.get_shape()[0]) - 1, name=\u0026#34;last_lstm_output\u0026#34;) (8) Define weights and biases between the hidden and output layers.\nweight = tf.Variable(tf.truncated_normal([config.lstm_size, config.input_size])) bias = tf.Variable(tf.constant(0.1, shape=[config.input_size])) prediction = tf.matmul(last, weight) + bias (9) We use mean square error as the loss metric and the RMSPropOptimizer algorithm for gradient descent optimization.\nloss = tf.reduce_mean(tf.square(prediction - targets)) optimizer = tf.train.RMSPropOptimizer(learning_rate) minimize = optimizer.minimize(loss) Start Training Session (1) To start training the graph with real data, we need to start a tf.session first.\nwith tf.Session(graph=lstm_graph) as sess: (2) Initialize the variables as defined.\ntf.global_variables_initializer().run() (0) The learning rates for training epochs should have been precomputed beforehand. The index refers to the epoch index.\nlearning_rates_to_use = [ config.init_learning_rate * ( config.learning_rate_decay ** max(float(i + 1 - config.init_epoch), 0.0) ) for i in range(config.max_epoch)] (3) Each loop below completes one epoch training.\nfor epoch_step in range(config.max_epoch): current_lr = learning_rates_to_use[epoch_step] # Check https://github.com/lilianweng/stock-rnn/blob/master/data_wrapper.py # if you are curious to know what is StockDataSet and how generate_one_epoch() # is implemented. for batch_X, batch_y in stock_dataset.generate_one_epoch(config.batch_size): train_data_feed = { inputs: batch_X, targets: batch_y, learning_rate: current_lr } train_loss, _ = sess.run([loss, minimize], train_data_feed) (4) Don\u0026rsquo;t forget to save your trained model at the end.\nsaver = tf.train.Saver() saver.save(sess, \u0026#34;your_awesome_model_path_and_name\u0026#34;, global_step=max_epoch_step) The complete code is available here.\nUse TensorBoard Building the graph without visualization is like drawing in the dark, very obscure and error-prone. Tensorboard provides easy visualization of the graph structure and the learning process. Check out this hand-on tutorial, only 20 min, but it is very practical and showcases several live demos.\nBrief Summary\n Use with [tf.name_scope](https://www.tensorflow.org/api_docs/python/tf/name_scope)(\u0026quot;your_awesome_module_name\u0026quot;): to wrap elements working on the similar goal together. Many tf.* methods accepts name= argument. Assigning a customized name can make your life much easier when reading the graph. Methods like tf.summary.scalar and tf.summary.histogram help track the values of variables in the graph during iterations. In the training session, define a log file using tf.summary.FileWriter. with tf.Session(graph=lstm_graph) as sess: merged_summary = tf.summary.merge_all() writer = tf.summary.FileWriter(\u0026#34;location_for_keeping_your_log_files\u0026#34;, sess.graph) writer.add_graph(sess.graph) Later, write the training progress and summary results into the file.\n_summary = sess.run([merged_summary], test_data_feed) writer.add_summary(_summary, global_step=epoch_step) # epoch_step in range(config.max_epoch) Fig. 4a The RNN graph built by the example code. The \"train\" module has been \"removed from the main graph\", as it is not a real part of the model during the prediction time. Fig. 4b Click the \"output_layer\" module to expand it and check the structure in details. The full working code is available in github.com/lilianweng/stock-rnn.\nResults I used the following configuration in the experiment.\nnum_layers=1 keep_prob=0.8 batch_size = 64 init_learning_rate = 0.001 learning_rate_decay = 0.99 init_epoch = 5 max_epoch = 100 num_steps=30 (Thanks to Yury for cathcing a bug that I had in the price normalization. Instead of using the last price of the previous time window, I ended up with using the last price in the same window. The following plots have been corrected.)\nOverall predicting the stock prices is not an easy task. Especially after normalization, the price trends look very noisy.\nFig. 5a Predictoin results for the last 200 days in test data. Model is trained with input_size=1 and lstm_size=32. Fig. 5b Predictoin results for the last 200 days in test data. Model is trained with input_size=1 and lstm_size=128. Fig. 5c Predictoin results for the last 200 days in test data. Model is trained with input_size=5, lstm_size=128 and max_epoch=75 (instead of 50). The example code in this tutorial is available in github.com/lilianweng/stock-rnn:scripts.\n(Updated on Sep 14, 2017) The model code has been updated to be wrapped into a class: LstmRNN. The model training can be triggered by main.py, such as:\npython main.py --stock_symbol=SP500 --train --input_size=1 --lstm_size=128 ","permalink":"https://lilianweng.github.io/posts/2017-07-08-stock-rnn-part-1/","summary":"This is a tutorial for how to build a recurrent neural network using Tensorflow to predict stock market prices. The full working code is available in github.com/lilianweng/stock-rnn. If you don\u0026rsquo;t know what is recurrent neural network or LSTM cell, feel free to check my previous post.\n One thing I would like to emphasize that because my motivation for writing this post is more on demonstrating how to build and train an RNN model in Tensorflow and less on solve the stock prediction problem, I didn\u0026rsquo;t try hard on improving the prediction outcomes.","title":"Predict Stock Prices Using RNN: Part 1"},{"content":"(The post was originated from my talk for WiMLDS x Fintech meetup hosted by Affirm.)\nI believe many of you have watched or heard of the games between AlphaGo and professional Go player Lee Sedol in 2016. Lee has the highest rank of nine dan and many world championships. No doubt, he is one of the best Go players in the world, but he lost by 1-4 in this series versus AlphaGo. Before this, Go was considered to be an intractable game for computers to master, as its simple rules lay out an exponential number of variations in the board positions, many more than what in Chess. This event surely highlighted 2016 as a big year for AI. Because of AlphaGo, much attention has been attracted to the progress of AI.\nMeanwhile, many companies are spending resources on pushing the edges of AI applications, that indeed have the potential to change or even revolutionize how we are gonna live. Familiar examples include self-driving cars, chatbots, home assistant devices and many others. One of the secret receipts behind the progress we have had in recent years is deep learning.\nWhy Does Deep Learning Work Now? Deep learning models, in simple words, are large and deep artificial neural nets. A neural network (\u0026ldquo;NN\u0026rdquo;) can be well presented in a directed acyclic graph: the input layer takes in signal vectors; one or multiple hidden layers process the outputs of the previous layer. The initial concept of a neural network can be traced back to more than half a century ago. But why does it work now? Why do people start talking about them all of a sudden?\nFig. 1. A three-layer artificial neural network. (Image source: http://cs231n.github.io/convolutional-networks/#conv) The reason is surprisingly simple:\n We have a lot more data. We have much powerful computers. A large and deep neural network has many more layers + many more nodes in each layer, which results in exponentially many more parameters to tune. Without enough data, we cannot learn parameters efficiently. Without powerful computers, learning would be too slow and insufficient.\nHere is an interesting plot presenting the relationship between the data scale and the model performance, proposed by Andrew Ng in his \u0026ldquo;Nuts and Bolts of Applying Deep Learning\u0026rdquo; talk. On a small dataset, traditional algorithms (Regression, Random Forests, SVM, GBM, etc.) or statistical learning does a great job, but once the data scale goes up to the sky, the large NN outperforms others. Partially because compared to a traditional ML model, a neural network model has many more parameters and has the capability to learn complicated nonlinear patterns. Thus we expect the model to pick the most helpful features by itself without too much expert-involved manual feature engineering.\nFig. 2. The data scale versus the model performance. (Recreated based on: https://youtu.be/F1ka6a13S9I) Deep Learning Models Next, let\u0026rsquo;s go through a few classical deep learning models.\nConvolutional Neural Network Convolutional neural networks, short for \u0026ldquo;CNN\u0026rdquo;, is a type of feed-forward artificial neural networks, in which the connectivity pattern between its neurons is inspired by the organization of the visual cortex system. The primary visual cortex (V1) does edge detection out of the raw visual input from the retina. The secondary visual cortex (V2), also called prestriate cortex, receives the edge features from V1 and extracts simple visual properties such as orientation, spatial frequency, and color. The visual area V4 handles more complicated object attributes. All the processed visual features flow into the final logic unit, inferior temporal gyrus (IT), for object recognition. The shortcut between V1 and V4 inspires a special type of CNN with connections between non-adjacent layers: Residual Net (He, et al. 2016) containing \u0026ldquo;Residual Block\u0026rdquo; which supports some input of one layer to be passed to the component two layers later.\nFig. 3. Illustration of the human visual cortex system. (Image source: Wang \u0026 Raj 2017) Convolution is a mathematical term, here referring to an operation between two matrices. The convolutional layer has a fixed small matrix defined, also called kernel or filter. As the kernel is sliding, or convolving, across the matrix representation of the input image, it is computing the element-wise multiplication of the values in the kernel matrix and the original image values. Specially designed kernels can process images for common purposes like blurring, sharpening, edge detection and many others, fast and efficiently.\nFig. 4. The LeNet architecture consists of two sets of convolutional, activation, and pooling layers, followed by a fully-connected layer, activation, another fully-connected layer, and finally a softmax classifier (Image source: http://deeplearning.net/tutorial/lenet.html) Convolutional and pooling (or \u0026ldquo;sub-sampling\u0026rdquo; in Fig. 4) layers act like the V1, V2 and V4 visual cortex units, responding to feature extraction. The object recognition reasoning happens in the later fully-connected layers which consume the extracted features.\nRecurrent Neural Network A sequence model is usually designed to transform an input sequence into an output sequence that lives in a different domain. Recurrent neural network, short for \u0026ldquo;RNN\u0026rdquo;, is suitable for this purpose and has shown tremendous improvement in problems like handwriting recognition, speech recognition, and machine translation (Sutskever et al. 2011, Liwicki et al. 2007).\nA recurrent neural network model is born with the capability to process long sequential data and to tackle tasks with context spreading in time. The model processes one element in the sequence at one time step. After computation, the newly updated unit state is passed down to the next time step to facilitate the computation of the next element. Imagine the case when an RNN model reads all the Wikipedia articles, character by character, and then it can predict the following words given the context.\nFig. 5. A recurrent neural network with one hidden unit (left) and its unrolling version in time (right). The unrolling version illustrates what happens in time: $s\\_{t-1}$, $s\\_{t}$, and $s\\_{t+1}$ are the same unit with different states at different time steps $t-1$, $t$, and $t+1$. (Image source: LeCun, Bengio, and Hinton, 2015; Fig. 5) However, simple perceptron neurons that linearly combine the current input element and the last unit state may easily lose the long-term dependencies. For example, we start a sentence with \u0026ldquo;Alice is working at \u0026hellip;\u0026rdquo; and later after a whole paragraph, we want to start the next sentence with \u0026ldquo;She\u0026rdquo; or \u0026ldquo;He\u0026rdquo; correctly. If the model forgets the character\u0026rsquo;s name \u0026ldquo;Alice\u0026rdquo;, we can never know. To resolve the issue, researchers created a special neuron with a much more complicated internal structure for memorizing long-term context, named \u0026ldquo;Long-short term memory (LSTM)\u0026quot; cell. It is smart enough to learn for how long it should memorize the old information, when to forget, when to make use of the new data, and how to combine the old memory with new input. This introduction is so well written that I recommend everyone with interest in LSTM to read it. It has been officially promoted in the Tensorflow documentation ;-)\nFig. 6. The structure of a LSTM cell. (Image source: http://colah.github.io/posts/2015-08-Understanding-LSTMs) To demonstrate the power of RNNs, Andrej Karpathy built a character-based language model using RNN with LSTM cells. Without knowing any English vocabulary beforehand, the model could learn the relationship between characters to form words and then the relationship between words to form sentences. It could achieve a decent performance even without a huge set of training data.\nFig. 7. A character-based recurrent neural network model writes like a Shakespeare. (Image source: http://karpathy.github.io/2015/05/21/rnn-effectiveness) RNN: Sequence-to-Sequence Model The sequence-to-sequence model is an extended version of RNN, but its application field is distinguishable enough that I would like to list it in a separated section. Same as RNN, a sequence-to-sequence model operates on sequential data, but particularly it is commonly used to develop chatbots or personal assistants, both generating meaningful response for input questions. A sequence-to-sequence model consists of two RNNs, encoder and decoder. The encoder learns the contextual information from the input words and then hands over the knowledge to the decoder side through a \u0026ldquo;context vector\u0026rdquo; (or \u0026ldquo;thought vector\u0026rdquo;, as shown in Fig 8.). Finally, the decoder consumes the context vector and generates proper responses.\nFig. 8. A sequence-to-sequence model for generating Gmail auto replies. (Image source: https://research.googleblog.com/2015/11/computer-respond-to-this-email.html) Autoencoders Different from the previous models, autoencoders are for unsupervised learning. It is designed to learn a low-dimensional representation of a high-dimensional data set, similar to what Principal Components Analysis (PCA) does. The autoencoder model tries to learn an approximation function $ f(x) \\approx x $ to reproduce the input data. However, it is restricted by a bottleneck layer in the middle with a very small number of nodes. With limited capacity, the model is forced to form a very efficient encoding of the data, that is essentially the low-dimensional code we learned.\nFig. 9. An autoencoder model has a bottleneck layer with only a few neurons. (Image source: Geoffrey Hinton’s Coursera class \"Neural Networks for Machine Learning\" - Week 15) Hinton and Salakhutdinov used autoencoders to compress documents on a variety of topics. As shown in Fig 10, when both PCA and autoencoder were applied to reduce the documents onto two dimensions, autoencoder demonstrated a much better outcome. With the help of autoencoder, we can do efficient data compression to speed up the information retrieval including both documents and images.\nFig. 10. The outputs of PCA (left) and autoencoder (right) when both try to compress documents into two numbers. (Image source: Hinton \u0026 Salakhutdinov 2006) Reinforcement (Deep) Learning Since I started my post with AlphaGo, let us dig a bit more on why AlphaGo worked out. Reinforcement learning (\u0026ldquo;RL\u0026rdquo;) is one of the secrets behind its success. RL is a subfield of machine learning which allows machines and software agents to automatically determine the optimal behavior within a given context, with a goal to maximize the long-term performance measured by a given metric.\nFig. 11. AlphaGo neural network training pipeline and architecture. (Image source: Silver et al. 2016) The AlphaGo system starts with a supervised learning process to train a fast rollout policy and a policy network, relying on the manually curated training dataset of professional players' games. It learns what is the best strategy given the current position on the game board. Then it applies reinforcement learning by setting up self-play games. The RL policy network gets improved when it wins more and more games against previous versions of the policy network. In the self-play stage, AlphaGo becomes stronger and stronger by playing against itself without requiring additional external training data.\nGenerative Adversarial Network Generative adversarial network, short for \u0026ldquo;GAN\u0026rdquo;, is a type of deep generative models. GAN is able to create new examples after learning through the real data. It is consist of two models competing against each other in a zero-sum game framework. The famous deep learning researcher Yann LeCun gave it a super high praise: Generative Adversarial Network is the most interesting idea in the last ten years in machine learning. (See the Quora question: \u0026ldquo;What are some recent and potentially upcoming breakthroughs in deep learning?\u0026quot;)\nFig. 12. The architecture of a generative adversarial network. (Image source: http://www.kdnuggets.com/2017/01/generative-adversarial-networks-hot-topic-machine-learning.html) In the original GAN paper, GAN was proposed to generate meaningful images after learning from real photos. It comprises two independent models: the Generator and the Discriminator. The generator produces fake images and sends the output to the discriminator model. The discriminator works like a judge, as it is optimized for identifying the real photos from the fake ones. The generator model is trying hard to cheat the discriminator while the judge is trying hard not to be cheated. This interesting zero-sum game between these two models motivates both to develop their designed skills and improve their functionalities. Eventually, we take the generator model for producing new images.\nToolkits and Libraries After learning all these models, you may start wondering how you can implement the models and use them for real. Fortunately, we have many open source toolkits and libraries for building deep learning models. Tensorflow is fairly new but has attracted a lot of popularity. It turns out, TensorFlow was the most forked Github project of 2015. All that happened in a period of 2 months after its release in Nov 2015.\nHow to Learn? If you are very new to the field and willing to devote some time to studying deep learning in a more systematic way, I would recommend you to start with the book Deep Learning by Ian Goodfellow, Yoshua Bengio, and Aaron Courville. The Coursera course \u0026ldquo;Neural Networks for Machine Learning\u0026rdquo; by Geoffrey Hinton (Godfather of deep learning!). The content for the course was prepared around 2006, pretty old, but it helps you build up a solid foundation for understanding deep learning models and expedite further exploration.\nMeanwhile, maintain your curiosity and passion. The field is making progress every day. Even classical or widely adopted deep learning models may just have been proposed 1-2 years ago. Reading academic papers can help you learn stuff in depth and keep up with the cutting-edge findings.\nUseful resources Google Scholar: http://scholar.google.com arXiv cs section: https://arxiv.org/list/cs/recent Unsupervised Feature Learning and Deep Learning Tutorial Tensorflow Tutorials Data Science Weekly KDnuggets Tons of blog posts and online tutorials Related Cousera courses awesome-deep-learning-papers Blog posts mentioned Explained Visually: Image Kernels Understanding LSTM Networks The Unreasonable Effectiveness of Recurrent Neural Networks Computer, respond to this email. Interesting blogs worthy of checking www.wildml.com colah.github.io karpathy.github.io blog.openai.com Papers mentioned [1] He, Kaiming, et al. \u0026ldquo;Deep residual learning for image recognition.\u0026quot; Proc. IEEE Conf. on computer vision and pattern recognition. 2016.\n[2] Wang, Haohan, Bhiksha Raj, and Eric P. Xing. \u0026ldquo;On the Origin of Deep Learning.\u0026quot; arXiv preprint arXiv:1702.07800, 2017.\n[3] Sutskever, Ilya, James Martens, and Geoffrey E. Hinton. \u0026ldquo;Generating text with recurrent neural networks.\u0026quot; Proc. of the 28th Intl. Conf. on Machine Learning (ICML). 2011.\n[4] Liwicki, Marcus, et al. \u0026ldquo;A novel approach to on-line handwriting recognition based on bidirectional long short-term memory networks.\u0026quot; Proc. of 9th Intl. Conf. on Document Analysis and Recognition. 2007.\n[5] LeCun, Yann, Yoshua Bengio, and Geoffrey Hinton. \u0026ldquo;Deep learning.\u0026quot; Nature 521.7553 (2015): 436-444.\n[6] Hochreiter, Sepp, and Jurgen Schmidhuber. \u0026ldquo;Long short-term memory.\u0026quot; Neural computation 9.8 (1997): 1735-1780.\n[7] Cho, Kyunghyun. et al. \u0026ldquo;Learning phrase representations using RNN encoder-decoder for statistical machine translation.\u0026quot; Proc. Conference on Empirical Methods in Natural Language Processing 1724–1734 (2014).\n[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. \u0026ldquo;Reducing the dimensionality of data with neural networks.\u0026quot; science 313.5786 (2006): 504-507.\n[9] Silver, David, et al. \u0026ldquo;Mastering the game of Go with deep neural networks and tree search.\u0026quot; Nature 529.7587 (2016): 484-489.\n[10] Goodfellow, Ian, et al. \u0026ldquo;Generative adversarial nets.\u0026quot; NIPS, 2014.\n","permalink":"https://lilianweng.github.io/posts/2017-06-21-overview/","summary":"(The post was originated from my talk for WiMLDS x Fintech meetup hosted by Affirm.)\nI believe many of you have watched or heard of the games between AlphaGo and professional Go player Lee Sedol in 2016. Lee has the highest rank of nine dan and many world championships. No doubt, he is one of the best Go players in the world, but he lost by 1-4 in this series versus AlphaGo.","title":"An Overview of Deep Learning for Curious People"},{"content":"","permalink":"https://lilianweng.github.io/faq/","summary":"","title":"FAQ"}] \ No newline at end of file diff --git a/index.xml b/index.xml index 9028d31..c6699d2 100644 --- a/index.xml +++ b/index.xml @@ -14,8 +14,7 @@ https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/ Prompt Engineering, also known as In-Context Prompting, refers to methods for how to communicate with LLM to steer its behavior for desired outcomes without updating the model weights. It is an empirical science and the effect of prompt engineering methods can vary a lot among models, thus requiring heavy experimentation and heuristics. -Useful resources: - OpenAI Cookbook has many in-depth examples for how to utilize LLM efficiently. Prompt Engineering Guide repo contains a pretty comprehensive collection of education materials on prompt engineering. +This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models. diff --git a/posts/2023-03-15-prompt-engineering/index.html b/posts/2023-03-15-prompt-engineering/index.html index 0b34242..996cfbc 100644 --- a/posts/2023-03-15-prompt-engineering/index.html +++ b/posts/2023-03-15-prompt-engineering/index.html @@ -8,8 +8,7 @@ Prompt Engineering | Lil'Log +This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models."> @@ -72,8 +71,7 @@ } +This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models." /> @@ -82,8 +80,7 @@ +This post only focuses on prompt engineering for autoregressive language models, so nothing with Cloze tests, image generation or multimodality models."/>