<![CDATA[rj45's learning exhaust]]>https://learning-exhaust.hashnode.devRSS for NodeTue, 15 Oct 2024 07:09:21 GMT60<![CDATA[Is it the model or the data that's low rank?]]>https://learning-exhaust.hashnode.dev/is-it-the-model-or-the-data-thats-low-rankhttps://learning-exhaust.hashnode.dev/is-it-the-model-or-the-data-thats-low-rankSun, 11 Aug 2024 05:19:24 GMT<![CDATA[<p>I mentioned this little bit of analysis that I recently did during the <a target="_blank" href="https://lu.ma/llm-paper-club">Latent Space Paper Club</a>, and got a lot of positive feedback, so I did a quick writeup.</p><p>The recently released <a target="_blank" href="https://arxiv.org/abs/2407.21075">Apple Intelligence Foundation Language Models</a> paper has sparked a lot of interest in their strategy for quantization loss recovery, in which the researchers quantized their models and then trained a LoRA adapter using full precision weights with the purpose of recovering the quantization loss. If you've been reading my blog, you'll know that researchers recently named this approach (or something very similar) Quantization Aware LoRA, or QA-LoRA.</p><p>Here are two papers that are close to what the Apple researchers seem to have done:</p><ul><li><p><a target="_blank" href="https://arxiv.org/abs/2310.03270v4">EfficientDM: Efficient Quantization-Aware Fine-Tuning of Low-Bit Diffusion Models</a></p></li><li><p><a target="_blank" href="https://arxiv.org/abs/2309.14717">QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models</a></p></li></ul><p>I also wonder if some of the efficiency gains we've seen with the "turbo" versions of GPT-4 and Claude 3 Haiku might also come from similar techniques.</p><h1 id="heading-apples-10-billion-tokens">Apple's 10 billion tokens</h1><p>Apple trained their 3B parameter model with around 3.1T tokens (this is a big simplification; you should read the paper if you haven't already. It's a great read!), so about 1,000 tokens per parameter.
Man, we've come a long way since the 20 tokens-per-parameter <a target="_blank" href="https://arxiv.org/abs/2203.15556">Chinchilla</a> days! In any event, they trained their quantization recovery LoRA using 10B tokens of the same pre-training data they used to train the model (although they don't mention how they sampled from the original pre-training data). This represents about 0.15% of the original pre-training data, but it is not clear if this same ratio will work on a model trained on a smaller training set.</p><h1 id="heading-what-about-svd">What about SVD?</h1><p>My first thought was: did Apple do way more than they needed to? (The guys and gals at Apple are obviously very smart, so this is more of a rhetorical question.) How about we take the residual from the quantized weights and find a low rank approximation of it using SVD (this can be done efficiently for low ranks)? That would be a much more direct approach than post-quantization fine-tuning for 10 billion tokens (!)</p><p>From a mathematical perspective,</p><p>$$SVD(W) = SVD(W_q + W_r) \neq SVD(W_q) + SVD(W_r)$$</p><p>That is, SVD (like matrix inversion) is not a linear transformation. This suggests that whatever low rank structure the weight matrices happen to have may be lost in quantization.
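</p><p>To make this concrete, here is roughly what the residual-SVD idea looks like in numpy. This is just a sketch with made-up sizes and a simple absmax quantizer (a stand-in, not the quantizer Apple actually used):</p><pre><code class="lang-python">import numpy as np

def absmax_quantize(w, num_bits=4):
    # simple symmetric absmax quantizer, a stand-in for the real thing
    qmax = 2 ** (num_bits - 1) - 1
    s = qmax / max(np.abs(w).max(), 1e-5)
    return np.round(w * s).clip(-qmax, qmax) / s

rng = np.random.default_rng(0)
w = rng.standard_normal((512, 512)).astype(np.float32)

w_q = absmax_quantize(w)  # quantized weights
w_r = w - w_q             # the residual we want to approximate

# best rank-16 approximation of the residual via truncated SVD
u, sv, vt = np.linalg.svd(w_r, full_matrices=False)
rank = 16
w_r_lowrank = (u[:, :rank] * sv[:rank]) @ vt[:rank]

# fraction of the residual's energy captured by rank 16
captured = 1 - np.linalg.norm(w_r - w_r_lowrank) / np.linalg.norm(w_r)
</code></pre><p>The <code>captured</code> number tells us how much of the residual a rank-16 adapter could even represent, which is exactly the question the singular value plots below try to answer.</p><p>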
But obviously Apple is finding a low rank approximation of the residual.</p><p>In order to get a sense of this, I created a <a target="_blank" href="https://github.com/honicky/quantization-experiments/blob/main/Rank_experiements.ipynb">notebook</a> that iterates through a few layers of <a target="_blank" href="https://huggingface.co/EleutherAI/pythia-160m-deduped">EleutherAI/pythia-160m-deduped</a> and shows the top singular values, comparing the original weight matrix, the quantized weight matrix and the residual matrix.</p><p><img src="https://lh7-rt.googleusercontent.com/docsz/AD_4nXdknrY0eEVYzVt14g0WWVjALw8Vg0EpwjGnSaeFLNBJ3TMZ8i4POOVAiXa6FKdLsR8sGu8OpgMXOB7XZY5jWK27tbiiJWEASSGWzj9vAYeylefw-cxuRlcyVDvF1H6MEABxATTbaontB6t2zV5V3_7dabQ?key=fkJ4_ZPwJgNTY8uuEm5pTw" alt /></p><p>Here we see the cumulative sum of the first 200 singular values for the first two layers. Blue is unquantized, green is quantized and red is the residual. <strong>The shape of the curve doesn't seem to be dramatically affected by quantization</strong>, so the first takeaway is that we might be able to recover some loss using a low rank approximation of the residual. When red is the top line, the residual kinda-sorta has lower rank than the original matrix, and a LoRA <strong>might</strong> be more effective at recovering the loss from quantization, but it is unclear by how much.</p><p>On the other hand, we have this very different plot:</p><p><img src="https://lh7-rt.googleusercontent.com/docsz/AD_4nXfHERtYRn-lCyHLjq4eWV6k4xpnVAARlp8nrIv2lPNI9Y3EAZ0Z7Yao3KqIpx7NVoo6RZHNgZ7-xuuCwO8zlfwu7h9FyydViPB45NaNnGd-zM1BIfo9IXgz34udy0JXqhXeWPLOFcrHhk8Km2TUlPvUBy09?key=fkJ4_ZPwJgNTY8uuEm5pTw" alt class="image--center mx-auto" /></p><p>The above plot is what a truly rank 4 matrix would look like.
So at first glance, it looks like low rank approximations of our weights <strong>shouldn't work at all</strong>, since the weight matrices don't seem to have low rank.</p><h1 id="heading-but-lora-does-work">But LoRA does work</h1><p>On the other hand, LoRA does actually work quite well, both for general purpose fine-tuning and, apparently, for quantization loss recovery. So what's going on?</p><p>This suggested to me that the <strong>span of the data itself is low rank</strong>, so a low rank approximation is enough to fine-tune or recover quantization loss. Indeed, when I took a look, <a target="_blank" href="https://arxiv.org/abs/2012.13255">the</a> <a target="_blank" href="https://arxiv.org/abs/1804.08838">papers</a> about this that the <a target="_blank" href="https://arxiv.org/abs/2106.09685">LoRA paper</a> cites talk about the data and problem space, not the models. This means that we will probably need to actually train a LoRA to recover the quantization loss, although it might help to initialize the weight matrices using SVD.</p><p>This was surprising to me: I thought that the model weights were low rank, but they don't seem to be. It's the "intrinsic dimension" of the problem itself! This is very interesting, and impacts how I think about this problem.</p><p>Fortunately, <a target="_blank" href="https://github.com/EleutherAI/pythia">Pythia</a> is a set of awesome research LLMs from EleutherAI that provide both the weights and the training data (plus everything else you need to recreate the weights).
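</p><p>As an aside, if we do try the SVD initialization idea from above, it might look something like the sketch below. This is only one reasonable choice (splitting the singular values evenly between the two factors), and all of the names are mine:</p><pre><code class="lang-python">import numpy as np

def lora_init_from_residual(w, w_q, r=8):
    # build LoRA factors B (out, r) and A (r, in) from the top-r singular
    # triples of the residual, so the adapter starts out as the best
    # rank-r approximation of W - W_q
    u, s, vt = np.linalg.svd(w - w_q, full_matrices=False)
    b = u[:, :r] * np.sqrt(s[:r])
    a = np.sqrt(s[:r])[:, None] * vt[:r]
    return a, b

rng = np.random.default_rng(1)
w = rng.standard_normal((256, 128))
w_q = np.round(w * 4) / 4  # crude stand-in for a real quantizer

a, b = lora_init_from_residual(w, w_q, r=8)
w_hat = w_q + b @ a  # quantized weights plus the adapter's correction
</code></pre><p>By construction, <code>w_hat</code> is at least as close to <code>w</code> as <code>w_q</code> alone, so the adapter gets the best possible rank-8 head start before any training.</p><p>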
We can try quantizing and then training a loss recovery adapter against the original training data (~0.15% of it?).</p><p>I've already gotten started, stay tuned!</p>]]>https://cdn.hashnode.com/res/hashnode/image/upload/v1723352488556/1a80e567-4b3b-48db-bb85-f90fed1cc0c0.png<![CDATA[Can we improve quantization by fine tuning?]]>https://learning-exhaust.hashnode.dev/can-we-improve-quantization-by-fine-tuninghttps://learning-exhaust.hashnode.dev/can-we-improve-quantization-by-fine-tuningFri, 12 Jul 2024 04:55:43 GMT<![CDATA[<p>As a follow-up to my previous post <a target="_blank" href="https://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bits">Are All Large Language Models Really in 1.58 Bits?</a>, I've been wondering if we could apply the same ideas to post-training quantization. The authors of <a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> trained models from scratch and found that they were able to train models that perform at the same level of quality as full precision models using a few tricks, including ternary (-1, 0, 1) weights, weight-only quantization (i.e. don't quantize embeddings, activations, biases or other parameters), and "passthrough" weight updates (i.e.
use full precision in the backward pass during training).</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">Using quantization in the forward pass and (usually) a passthrough weight update in the backward pass is called Quantization Aware Training.</div></div><p>The passthrough weight update means that the technique they describe in the paper has to be applied during training, so we can't apply it to the existing models (e.g. Llama 3, Mixtral, Phi-3) that we know and love. On the other hand, if we continue training a model on data from the original training distribution, but with quantization during the forward pass, then perhaps we can reach a minimum with respect to the loss function that is close to the unquantized minimum.</p><h1 id="heading-post-training-quantization">Post training quantization</h1><p>Post training quantization does poorly at more aggressive (lower bit) quantization levels. From <a target="_blank" href="https://arxiv.org/abs/2310.11453">BitNet: Scaling 1-bit Transformers for Large Language Models</a> (by the same authors):</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXfmDqcth0dwKzr6faYQL5shCQfAEgtL-SZKxb98pI6-dwz6CmcWvoEzoq7YBmjHI5LvLsIbadUu6gnSHOoNfiiVfPhsD3VSXlUIFZwXC0jxSgjoD3vo1FUtxSfCkAfI87VF0YEf_6iX5PGcByoi6YKuE7pX?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>This table doesn't really reflect the state of the art, however. The current state of the art in post-training quantization seems to be <a target="_blank" href="https://github.com/intel/auto-round">auto-round</a>, which uses an optimizer to improve how weights are rounded during quantization.
It does well for 4-bit quantization on a <a target="_blank" href="https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard">low_bit_open_llm_leaderboard</a> that the authors created, but performance falls off at lower bit widths.</p><p>This is encouraging: it suggests that there are local minima in quantized weight space that are near the unquantized minima. The authors of the paper describing the algorithm used by auto-round, <a target="_blank" href="https://arxiv.org/abs/2309.05516">Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs</a>, state that</p><blockquote><p>...we still observed a noticeable gap in the accuracy performance of ultra-low bit quantization compared to the original model, such as 2-bit quantization. This challenge could potentially be addressed by exploring non-uniform quantization and mixed-precision quantization</p></blockquote><p>It seems plausible (to me) that if the techniques they mention can find a good minimum in quantized weight space, then perhaps "just" continued pre-training with a quantized forward pass might work as well.</p><h1 id="heading-passthrough-gradient">Passthrough gradient</h1><p>Again, one of the important tricks the 1.58-bit authors play is to quantize during the forward pass but maintain full precision in the backward pass (i.e. the so-called passthrough gradient), as mentioned above.</p><p>I'm very inexperienced, so before I jump in with both feet, I want to get familiar with the tools and ideas using a toy model that is easy to visualize. Instead of using a language model, let's try quantizing using the same technique, but with a simple model that finds a polynomial fit to a set of random points.
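</p><p>The passthrough trick itself is tiny. Here is a minimal sketch in PyTorch (my own illustration, not code from the paper):</p><pre><code class="lang-python">import torch

def ste_round(x):
    # forward: round(x); backward: the gradient of the identity
    return x + (x.round() - x).detach()

x = torch.tensor([0.2, 0.7, 1.4], requires_grad=True)
y = ste_round(x).sum()  # the forward pass sees the rounded values: 0.0 + 1.0 + 1.0
y.backward()
# x.grad is all ones: the rounding is invisible to the backward pass
</code></pre><p>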
Here is a notebook that I used to do some quick experiments: <a target="_blank" href="https://github.com/honicky/quantization-experiments/blob/main/bitnet_1_58b_experiments.ipynb">bitnet_1_58b_experiments.ipynb</a>.</p><h1 id="heading-toy-problem-polynomial-fit">Toy problem: polynomial fit</h1><p>Here is an example polynomial fit, normalized so that both the domain and range are [0,1]:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXfLpZxYSw7mybVLUNGU1T9H_NlkGjkB7Ucb5LNHk0GRLqYQ1Wznz7bA5-Q-z47750dLj9u2SHofQsKmK4hNioPkGOjc-GI_d3xFZFXsnSvNTuyZA_TRFQc0zfThLCef33N-f88pr1GRAQPQgAvgHl_dgkU?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>It takes about 6 minutes to generate 2 million of these with which to train the network (using unoptimized python).</p><h2 id="heading-a-simple-network">A simple network</h2><p>I'm going to learn to map from 7 random points to 100 outputs representing the polynomial fit. I used GELU for stability during training, and three hidden layers, because ...it seems to work fine.</p><pre><code class="lang-python">PolynomialFitNetwork(
  (linear_gelu_stack): Sequential(
    (0): Linear(in_features=14, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): GELU(approximate='none')
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): GELU(approximate='none')
    (8): Linear(in_features=512, out_features=200, bias=True)
  )
)
</code></pre><p>Notice that we have 200 output features. We will learn both the <code>x</code> and <code>y</code> output values, even though the <code>x</code> values are just grid points, and easily learned. I decided to do this because I wanted some easy parameters to learn, so that there is some "sparsity" in the network, meaning that the network is compressible and can therefore lose precision in the weights and still work well.</p><p>I trained for 41 epochs and got the test loss down to about 0.0045 (FWIW).</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720300790751/947fdabf-8571-435d-bbff-bd01f678f668.png" alt class="image--center mx-auto" /></p><p>Let's take a look at some test examples:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXcHiQzCEMaBHUndNdLeaZJ-TsTIGcOwhYZQWJIqGqZBDuMMlnSAVTWuQBezGpAQgLkQARjYZnmbrV9JW5WuSi-b4DBBJyu5cwU2tsg2C-7Cw5W33dvZEeqEDpS4cfWipT7McmJhx_V3SkWCiuJ9ceuD2RF3?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>Some things I notice about these examples:</p><ol><li><p>The blue line (the output from our model) doesn't necessarily pass through the points the function is fit to. This makes sense, since it is not a criterion in our loss function.</p></li><li><p>The model seems to struggle most when the variance of the input is high, and is good at predicting the values if the swing is very wide.
Surprisingly, the model predicts these very wide swings quite precisely.</p></li></ol><p>None of this is particularly important, except to illustrate that the problem is non-trivial for this neural network to solve.</p><h1 id="heading-now-quantize">Now quantize!</h1><p>So now, let's define a quantized model, load the weights and see how we do.</p><p>Here's the code I used for <code>BitLinear</code>, adapted from <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py">https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py</a>:</p><pre><code class="lang-python">def weight_quant_158b(weight, num_bits=1):
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)

def weight_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)

def activation_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)

class BitLinear(nn.Linear):
    """ RMSNorm is placed outside BitLinear """

    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        if self.weight_bits == 1.58:
            quant = weight_quant_158b
        else:
            quant = weight_quant
        # quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (quant(self.weight, self.weight_bits) - self.weight).detach()
        # out = nn.functional.linear(quant_input, quant_weight)
        out = nn.functional.linear(input, quant_weight)
        if not self.bias is None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
</code></pre><p>I combined support for quantization by any natural number of bits and ternary (1.58 bits). I also disabled the activation (input) quantization, because I want to focus on the impact of weight quantization for this toy model, and I think that the activation quantization was peripheral in the paper.</p><p>Let's see how well we do with 1.58-bit quantization. First we load the weights into a quantized model:</p><pre><code class="lang-python">bit_model = BitPolynomialFitNetwork(weight_bits=1.58).to(device)
bit_model.load_state_dict(model.state_dict())
</code></pre><p>Let's see what our polynomials look like before training:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXcgDVYJQsd-TvcHNUe5Qu0kuFDseBqf9pNkuVJyw2Xr59Y3qEf9JkBlY4hy_U_58OieUtGv0ITKZ9rvG0x3R0B62VXt6EMbW2msFBzyE8x0ShmrGBsMlh6yFhFHbjIo9nTcIyxKOzBO3jqSrNVOO0oCj_HM?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>Oh man, that's pretty bad. OK, let's try continued pre-training on the quantized model. We can try learning rates starting at 1e-3 and 1e-4 with 1.58 bits and 8 bits for 10 epochs each to get a sense of how learning rate impacts the test loss. Here is a table:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720670326964/18373b82-de45-4586-bdd1-b67fc78100eb.png" alt class="image--center mx-auto" /></p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">Sadly, HashNode doesn't want to format my table well when I paste it in :(.
Here is a link to the spreadsheet if you want to check out the Weights and Biases links: <a target="_blank" href="https://docs.google.com/spreadsheets/d/1olFu-fiGu2OS3Kzgy1IbfJkGDTrFGKgGd0fWPVBpnzE/edit?usp=sharing">https://docs.google.com/spreadsheets/d/1olFu-fiGu2OS3Kzgy1IbfJkGDTrFGKgGd0fWPVBpnzE/edit?usp=sharing</a></div></div><p>I've added 41-epoch runs for 1.58 bits and 8 bits so that we can get a sense of the best we can expect to do at a given level of quantization. At 8 bits, we actually do slightly better than at 16 bits (probably not significant), but at 1.58 bits, we only get down to 0.007482, or about 67% greater loss. So maybe <a target="_blank" href="https://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bits">All Polynomial Fit Models Are Not In 1.58 Bits?</a></p><p>In any event, the 1.58b model that started at learning rate 1e-3 (lyric-serenity-67) seems to have reached a reasonably low loss of 0.008855 after 10 epochs; much better than the corresponding random-init run at 0.010385. The effect is even more pronounced at 8 bits, with the continued pre-training model starting at learning rate 1e-4 (swift-wildflower-64) reaching a loss of 0.004603 and basically hitting the lower limit we estimated, whereas the corresponding random-init model only reaches 0.006500.</p><h2 id="heading-learning-rate-and-convergence">Learning rate and convergence</h2><p>In many cases, I observed that higher learning rates caused instability (as is often the case).
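</p><p>For reference, these runs boil down to a loop with a quantized forward pass and a full-precision backward pass. Below is a self-contained miniature with made-up layer sizes, not the actual notebook code (a real run warm-starts from the trained full-precision weights via <code>load_state_dict</code>, as shown earlier):</p><pre><code class="lang-python">import torch
import torch.nn as nn

def weight_quant_158b(w):
    # ternary quantization, as in the BitLinear code above
    s = 1 / w.abs().mean().clamp(min=1e-5)
    return (w * s).round().clamp(-1, 1) / s

class TinyBitNet(nn.Module):
    # a toy stand-in for BitPolynomialFitNetwork, small enough to run anywhere
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(7, 64)
        self.l2 = nn.Linear(64, 100)

    def forward(self, x):
        # passthrough quantization on each weight matrix
        w1 = self.l1.weight + (weight_quant_158b(self.l1.weight) - self.l1.weight).detach()
        w2 = self.l2.weight + (weight_quant_158b(self.l2.weight) - self.l2.weight).detach()
        h = nn.functional.gelu(nn.functional.linear(x, w1, self.l1.bias))
        return nn.functional.linear(h, w2, self.l2.bias)

torch.manual_seed(0)
x, y = torch.rand(256, 7), torch.rand(256, 100)

net = TinyBitNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

start = loss_fn(net(x), y).item()
for _ in range(200):
    opt.zero_grad()
    loss = loss_fn(net(x), y)
    loss.backward()
    opt.step()
end = loss_fn(net(x), y).item()  # the loss decreases: gradients flow through the STE
</code></pre><p>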
For example, from <a target="_blank" href="https://wandb.ai/honicky/bitnet%201.58b%20retraining%20experiments/runs/4lycgh9o">rural-lake-71</a>, it looks like the instability is impacting the time to convergence.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720673859070/2a1660a4-831b-40ac-b4a4-8d144372d2dd.png" alt class="image--center mx-auto" /></p><p>Presumably we would continue to improve if we continued training, but the point here is to converge quickly.</p><h1 id="heading-takeaways">Takeaways</h1><p>Here are some things I learned from this experiment:</p><ol><li><p>At least for this toy model, fine-tuning (continued pre-training) a quantized model seems to be faster than starting from scratch.</p></li><li><p>The training process seems to have different sensitivity to learning rate depending on the level of quantization. This is also consistent with the paper, in which the authors mention that more aggressive quantization needed higher learning rates at the beginning of the training process.</p></li></ol><p>This is quite encouraging!</p><h2 id="heading-but-wait-qalora">But wait... QALoRA?</h2><p>Oh man! Here is an exciting paper: <a target="_blank" href="https://arxiv.org/abs/2310.03270v4">EfficientDM: Efficient Quantization-Aware Fine-Tuning of Low-Bit Diffusion Models</a>. The authors have done something similar to my proposal, only:</p><ol><li><p>they fine-tuned a quantized <strong>diffusion</strong> model instead of an LLM. Some of the techniques they use may not be relevant, but</p></li><li><p>they used a quantized LoRA during fine-tuning instead of fine-tuning the whole model... brilliant! The LoRA can find the local minimum, and also</p></li><li><p>they use the unquantized model to generate in-distribution data for training - another idea I had planned to use!</p></li></ol><p>Awesome! They've paved the way!
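</p><p>Idea (3) is easy to sketch on the toy problem: the full-precision model labels fresh in-distribution inputs, and the quantized model then fine-tunes against those labels. Everything below is a hypothetical stand-in of my own, not the EfficientDM recipe:</p><pre><code class="lang-python">import torch
import torch.nn as nn

def ternary(w):
    # same ternary quantizer as in BitLinear above
    s = 1 / w.abs().mean().clamp(min=1e-5)
    return (w * s).round().clamp(-1, 1) / s

torch.manual_seed(0)
teacher = nn.Linear(7, 100)  # stands in for the full-precision model

# the teacher labels fresh in-distribution inputs "for free"
with torch.no_grad():
    x = torch.rand(1024, 7)
    targets = teacher(x)

# the quantized student starts from the same weights; before any fine-tuning
# its outputs disagree with the teacher by exactly the quantization error
with torch.no_grad():
    student_out = nn.functional.linear(x, ternary(teacher.weight), teacher.bias)
    gap = nn.functional.mse_loss(student_out, targets).item()
</code></pre><p>Fine-tuning would then minimize that gap on the (<code>x</code>, <code>targets</code>) pairs, with no need to regenerate the original polynomial dataset.</p><p>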
They call their technique Quantization Aware LoRA, or QALoRA.</p><p>I'm excited to try this idea on my toy model to see how much performance increase we get, and then implement it on a small language model.</p><p>The <a target="_blank" href="https://github.com/EleutherAI/pythia">pythia</a> family of models have <a target="_blank" href="https://huggingface.co/collections/EleutherAI/pythia-scaling-suite-64fb5dfa8c21ebb3db7ad2e1">Huggingface transformers versions</a> that we can test these ideas on by making some minor changes to the <a target="_blank" href="https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L71">GPTNeoXPreTrainedModel</a> implementation to use our <code>BitLinear</code> implementation (with RMSNorm too, as in the paper).</p><p>I'll get working on it and write it up ASAP!</p>]]><![CDATA[<p>As a followup to my previous post <a target="_blank" href="https://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bits">Are All Large Language Models Really in 1.58 Bits?</a>, I've been wondering if we could apply the same ideas to post-training quantization. The authors trained models from scratch in <a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> and found that they were able to train models that perform at the same level of quality as full precision models using a few tricks, including ternary (-1,0,1) weights, weight-only quantization (e.g. don't quantize embeddings, activations, biases or other parameters), and "passthrough" weight updates (e.g. 
use full precision in the backward pass during training).</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">Using quantization in the forward pass and (usually) a passthrough weight update in the backward pass is called Quantization Aware Training</div></div><p>The passthrough weight update means that the technique they describe in the paper has to be applied during training, so we can't apply it to existing models (e.g. Llama 3, Mixtral, Phi-3) that we know and love. On the other hand, if we continue training a model on data from the distribution of the original training data, but with quantization during the forward pass, then perhaps we can reach a minimum with respect to the loss function that is close to the un-quantized minimum.</p><h1 id="heading-post-training-quantization">Post training quantization</h1><p>Post-training quantization does poorly at more aggressive quantization levels. From <a target="_blank" href="https://arxiv.org/abs/2310.11453">BitNet: Scaling 1-bit Transformers for Large Language Models</a> (by the same authors):</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXfmDqcth0dwKzr6faYQL5shCQfAEgtL-SZKxb98pI6-dwz6CmcWvoEzoq7YBmjHI5LvLsIbadUu6gnSHOoNfiiVfPhsD3VSXlUIFZwXC0jxSgjoD3vo1FUtxSfCkAfI87VF0YEf_6iX5PGcByoi6YKuE7pX?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>This table doesn't really reflect the state of the art, however. The current state of the art in post-training quantization seems to be <a target="_blank" href="https://github.com/intel/auto-round">auto-round</a>, which uses an optimizer to improve the rounding decisions made during quantization.
It does well for 4-bit quantization on a <a target="_blank" href="https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard">low_bit_open_llm_leaderboard</a> that the authors created, but performance falls off at more aggressive quantization levels.</p><p>This is encouraging: it suggests that there are local minima in quantized space that are near the unquantized minima. The authors of the paper about the algorithms used by auto-round, <a target="_blank" href="https://arxiv.org/abs/2309.05516">Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs</a>, state that</p><blockquote><p>...we still observed a noticeable gap in the accuracy performance of ultra-low bit quantization compared to the original model, such as 2-bit quantization. This challenge could potentially be addressed by exploring non-uniform quantization and mixed-precision quantization</p></blockquote><p>It seems likely (to me) that if the techniques they mention can find a good minimum in quantized weight space, then perhaps "just" continued pre-training with a quantized forward pass might work as well.</p><h1 id="heading-passthrough-gradient">Passthrough gradient</h1><p>Again, one of the important tricks the 1.58-bit authors play is to quantize during the forward pass but maintain full precision in the backward pass (i.e. the so-called passthrough gradient), as mentioned above.</p><p>I'm very inexperienced, so before I jump in with both feet, I want to get familiar with the tools and ideas using a toy model that is easy to visualize. Instead of using a language model, let's try quantizing with the same technique, but using a simple model that finds a polynomial fit to a set of random points.
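<p>The passthrough (straight-through) trick itself is only a few lines of PyTorch. Here is a minimal, self-contained sketch using the ternary quantizer described above; <code>ste_quantize</code> is my own name for illustration, and it mirrors the <code>BitLinear</code> code later in this post:</p><pre><code class="lang-python">import torch

def ste_quantize(w: torch.Tensor) -> torch.Tensor:
    """Ternary (1.58-bit) forward pass with a passthrough gradient."""
    s = 1 / w.abs().mean().clamp(min=1e-5)   # per-tensor scale
    w_q = (w * s).round().clamp(-1, 1) / s   # forward pass sees ternary weights
    # (w_q - w).detach() carries no gradient, so the backward pass sees
    # the identity function: gradients flow as if w were never quantized
    return w + (w_q - w).detach()

w = torch.randn(8, 8, requires_grad=True)
ste_quantize(w).sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))  # full-precision gradient
</code></pre><p>The quantization error term is detached from the autograd graph, which is why the gradient of the sum is exactly all ones even though the forward output is ternary.</p>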
Here is a notebook that I used to do some quick experiments: <a target="_blank" href="https://github.com/honicky/quantization-experiments/blob/main/bitnet_1_58b_experiments.ipynb">bitnet_1_58b_experiments.ipynb</a>.</p><h1 id="heading-toy-problem-polynomial-fit">Toy problem: polynomial fit</h1><p>Here is an example polynomial fit, which I've normalized so that both the domain and range are in [0,1]:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXfLpZxYSw7mybVLUNGU1T9H_NlkGjkB7Ucb5LNHk0GRLqYQ1Wznz7bA5-Q-z47750dLj9u2SHofQsKmK4hNioPkGOjc-GI_d3xFZFXsnSvNTuyZA_TRFQc0zfThLCef33N-f88pr1GRAQPQgAvgHl_dgkU?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>It takes about 6 minutes to generate 2 million of these with which to train the network (using unoptimized Python).</p><h2 id="heading-a-simple-network">A simple network</h2><p>I'm going to learn to map from 7 random points to 100 outputs representing the polynomial fit. I used GELU for stability during training, and three hidden layers, because ...it seems to work fine.</p><pre><code class="lang-python">PolynomialFitNetwork(
  (linear_gelu_stack): Sequential(
    (0): Linear(in_features=14, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): GELU(approximate='none')
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): GELU(approximate='none')
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): GELU(approximate='none')
    (8): Linear(in_features=512, out_features=200, bias=True)
  )
)</code></pre><p>Notice that we have 200 output features. We will learn both the <code>x</code> and <code>y</code> output values, even though the <code>x</code> values are just grid points and easily learned. I decided to do this because I wanted some easy parameters to learn, so that there is some "sparsity" in the network, meaning that the network is compressible and can therefore lose precision in its weights and still work well.</p><p>I trained for 41 epochs and got the test loss down to about 0.0045 (FWIW).</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720300790751/947fdabf-8571-435d-bbff-bd01f678f668.png" alt class="image--center mx-auto" /></p><p>Let's take a look at some test examples:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXcHiQzCEMaBHUndNdLeaZJ-TsTIGcOwhYZQWJIqGqZBDuMMlnSAVTWuQBezGpAQgLkQARjYZnmbrV9JW5WuSi-b4DBBJyu5cwU2tsg2C-7Cw5W33dvZEeqEDpS4cfWipT7McmJhx_V3SkWCiuJ9ceuD2RF3?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>Some things I notice about these examples:</p><ol><li><p>The blue line (the output from our model) doesn't necessarily pass through the points the function is fit to. This makes sense, since it is not a criterion in our loss function.</p></li><li><p>The model seems to struggle most when the variance of the input is high.
Surprisingly, the model is quite precise at predicting very wide swings.</p></li></ol><p>None of this is particularly important, except to illustrate that the problem is non-trivial for this neural network to solve.</p><h1 id="heading-now-quantize">Now quantize!</h1><p>So now, let's define a quantized model, load the weights, and see how we do.</p><p>Here's the code I used for <code>BitLinear</code>, adapted from <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py">https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py</a>:</p><pre><code class="lang-python">def weight_quant_158b(weight, num_bits=1):
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)


def weight_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)


def activation_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)


class BitLinear(nn.Linear):
    """RMSNorm is placed outside BitLinear"""

    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        if self.weight_bits == 1.58:
            quant = weight_quant_158b
        else:
            quant = weight_quant
        # quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (quant(self.weight, self.weight_bits) - self.weight).detach()
        # out = nn.functional.linear(quant_input, quant_weight)
        out = nn.functional.linear(input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out</code></pre><p>I combined support for quantization to any natural number of bits with ternary (1.58-bit) quantization. I also disabled the activation (input) quantization, because I want to focus on the impact of weight quantization for this toy model, and I think the activation quantization was peripheral in the paper.</p><p>Let's see how well we do with 1.58-bit quantization. First we load the weights into a quantized model:</p><pre><code class="lang-python">bit_model = BitPolynomialFitNetwork(weight_bits=1.58).to(device)
bit_model.load_state_dict(model.state_dict())</code></pre><p>Let's see what our polynomials look like before training:</p><p><img src="https://lh7-us.googleusercontent.com/docsz/AD_4nXcgDVYJQsd-TvcHNUe5Qu0kuFDseBqf9pNkuVJyw2Xr59Y3qEf9JkBlY4hy_U_58OieUtGv0ITKZ9rvG0x3R0B62VXt6EMbW2msFBzyE8x0ShmrGBsMlh6yFhFHbjIo9nTcIyxKOzBO3jqSrNVOO0oCj_HM?key=Kxhsul3lVQpnBkC-v6tdXA" alt /></p><p>Oh man, that's pretty bad. OK, let's try continued pre-training on the quantized model. We can try learning rates starting at 1e-3 and 1e-4, with 1.58 bits and 8 bits, for 10 epochs each, to get a sense of how learning rate impacts the test loss. Here is a table:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720670326964/18373b82-de45-4586-bdd1-b67fc78100eb.png" alt class="image--center mx-auto" /></p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">Sadly, HashNode doesn't want to format my table well when I paste it in :(.
Here is a link to the spreadsheet if you want to check out the Weights and Biases links: <a target="_blank" href="https://docs.google.com/spreadsheets/d/1olFu-fiGu2OS3Kzgy1IbfJkGDTrFGKgGd0fWPVBpnzE/edit?usp=sharing">https://docs.google.com/spreadsheets/d/1olFu-fiGu2OS3Kzgy1IbfJkGDTrFGKgGd0fWPVBpnzE/edit?usp=sharing</a></div></div><p>I've added 41-epoch runs for 1.58 bits and 8 bits so that we can get a sense for the best we can expect to do at a given level of quantization. At 8 bits, we actually do slightly better than at 16 bits (probably not significant), but at 1.58 bits, we only get down to 0.007482, or about 67% greater loss. So maybe <a target="_blank" href="https://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bits">All Polynomial Fit Models Are Not In 1.58 Bits?</a></p><p>In any event, the 1.58b model that started at learning rate 1e-3 (lyric-serenity-67) seems to have reached a reasonably low loss of 0.008855 after 10 epochs; much better than the corresponding random-init run at 0.010385. The effect is even more pronounced at 8 bits, with the continued pre-training model starting at learning rate 1e-4 (swift-wildflower-64) reaching a loss of 0.004603, basically the lower limit we estimated, whereas the corresponding random-init model only reaches 0.006500.</p><h2 id="heading-learning-rate-and-convergence">Learning rate and convergence</h2><p>In many cases, I observed that higher learning rates caused instability (as is often the case).
For example, from <a target="_blank" href="https://wandb.ai/honicky/bitnet%201.58b%20retraining%20experiments/runs/4lycgh9o">rural-lake-71</a>, it looks like the instability is impacting the time to convergence.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1720673859070/2a1660a4-831b-40ac-b4a4-8d144372d2dd.png" alt class="image--center mx-auto" /></p><p>Presumably we would continue to improve if we continued training, but the point here is to converge quickly.</p><h1 id="heading-takeaways">Takeaways</h1><p>Here are some things I learned from this experiment:</p><ol><li><p>At least for this toy model, fine-tuning (continued pre-training) a quantized model seems to be faster than starting from scratch.</p></li><li><p>The training process seems to have different sensitivity to learning rate depending on the level of quantization. This is also consistent with the paper, in which they mentioned that more aggressive quantization needed higher learning rates in the beginning of the training process.</p></li></ol><p>This is quite encouraging!</p><h2 id="heading-but-wait-qalora">But wait... QALoRA?</h2><p>Oh man! Here is an exciting paper: <a target="_blank" href="https://arxiv.org/abs/2310.03270v4">EfficientDM: Efficient Quantization-Aware Fine-Tuning of Low-Bit Diffusion Models</a>. The authors have done something similar to my proposal, only</p><ol><li><p>they fine-tuned a quantized <strong>diffusion</strong> model instead of an LLM. Some of the techniques they use may not be relevant, but</p></li><li><p>they used a quantized LoRA during fine-tuning instead of fine-tuning the whole model... brilliant! The LoRA can find the local minimum, and also</p></li><li><p>they use the un-quantized model to generate in-distribution data for training - another idea I had planned to use!</p></li></ol><p>Awesome! They've paved the way!
They call their technique Quantization Aware LoRA, or QALoRA.</p><p>I'm excited to try this idea on my toy model to see how much of a performance increase we get, and then implement it on a small language model.</p><p>The <a target="_blank" href="https://github.com/EleutherAI/pythia">pythia</a> family of models has <a target="_blank" href="https://huggingface.co/collections/EleutherAI/pythia-scaling-suite-64fb5dfa8c21ebb3db7ad2e1">Huggingface transformers versions</a> that we can test these ideas on by making some minor changes to the <a target="_blank" href="https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L71">GPTNeoXPreTrainedModel</a> implementation to use our <code>BitLinear</code> implementation (with RMSNorm too, as in the paper).</p><p>I'll get working on it and write it up ASAP!</p>]]>https://cdn.hashnode.com/res/hashnode/image/upload/v1720760034712/87d725c8-c000-4a81-9baf-fb688f814005.png<![CDATA[A toy problem and a bunch of models]]>https://learning-exhaust.hashnode.dev/a-toy-problem-and-a-bunch-of-modelshttps://learning-exhaust.hashnode.dev/a-toy-problem-and-a-bunch-of-modelsWed, 29 May 2024 03:33:06 GMT<![CDATA[<p>I have a ton of ideas bouncing around in my head about interesting things to do with or to LLMs and other types of models. I've been using a children's story co-authoring and publishing tool (<a target="_blank" href="https://github.com/honicky/story-time">https://github.com/honicky/story-time</a> and <a target="_blank" href="https://www.storytime.glass/">https://www.storytime.glass/</a>) as a toy project to learn about generative AI. One thing I've discovered is that consistency across images (<a target="_blank" href="https://openai.com/index/sora/">Sora</a> notwithstanding) is a tough problem.
This actually matters quite a bit in a children's story, since so much of the story is told visually.</p><p>One idea I've had some success with is to use identical language in the image prompts for each character and location. The story generator first generates a story and then generates an image prompt for each paragraph to go along with the text (checking out an example at <a target="_blank" href="https://www.storytime.glass/">https://www.storytime.glass/</a> might help here). This means that I need to be able to identify all of the characters in the story, and also the characters in a given paragraph and image.</p><p>Since my god-model story generator is pretty slow, the app is difficult to interact with right now. That sounds like an excuse to learn about the best, fastest, cheapest way to figure out who the characters are in a story and a scene. I could obviously ask GPT-4o or claude-3-opus who the characters in the story are, but</p><ol><li><p>the generation loop is already slow, and</p></li><li><p>where's the fun in that?</p></li></ol><p>In this post, I will play around with several options for extracting character names from stories and paragraphs, ranging from a 60M-parameter BERT-based NER model (more below) all the way to my own fine-tuned <code>flan-t5</code>. I have links to notebooks for each experiment.</p><h2 id="heading-tldr">TL;DR</h2><p>In case you're mostly just interested in the results, skip to the table and graphs at the bottom. In a nutshell: <code>haiku</code> is the top performer, and my <code>flan-t5-large</code> fine-tune is the fastest and also the lowest cost, but only if we assume 50% occupancy on the GPU, which is a high bar to cross.</p><h2 id="heading-the-space-to-explore">The space to explore</h2><p>There are a few dimensions across which I would like to understand the behavior of different models. As I did the experiments below, I had as many new questions as answers.
I've limited myself to a few representative examples, so this is obviously not a rigorous survey, but it's enough to get a feel for how different models behave on this problem.</p><h3 id="heading-architecture">Architecture</h3><p>I find it most useful to think about the different Transformer-based LLMs in terms of their input and output during training and inference. In very broad strokes, Transformer-based LLMs come in three main flavors:</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">There are tons of diagrams online illustrating the detailed differences between the different architectures, but I find they don't give a good intuition for what each type will be good at. If you already have a good intuition, feel free to skip this section.</div></div><ol><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716787966107/132824f0-d8ee-4655-9e65-ff34b59c28c7.jpeg" alt class="image--center mx-auto" /></p><p> <strong>encoder-only</strong> - BERT is an encoder-only model. These models take in text and output a fixed set of probabilities. Exactly what those probabilities mean, how many of them there are, etc. depends on the specific problem they are designed to solve. With BERT, you can replace the last layer of the network with a layer that is appropriate for your problem (and then fine-tune it). For example, if you were classifying text into three categories, then maybe you would have three output nodes, one representing each class. In our case, we will have one output node per input token, per class of token (is it part of a name?). The key characteristic of this type of model is that the output is a fixed, pre-determined number of numerical outputs.
You have to squeeze your problem into this type of output in order to use an encoder-only model.</p></li><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716952087152/2c5d9348-b83e-4eb4-8fdb-beb0e8014d63.jpeg" alt class="image--center mx-auto" /></p><p> <strong>decoder-only</strong> - the GPT-type models are decoder-only models. The problems they solve must be framed in terms of a stream of text: given some starting text, what is its "completion"? This is how you are probably used to using ChatGPT: 'What are the names of all the characters in the following story: "Once upon a time..." ' and we expect a response like "The characters are RJ, Keivan and Aashray." </p><p> Decoder-only models are "auto-regressive," meaning that we predict a token (part of a word) that follows the text so far, and then we feed the whole text back into the model and predict the next-next token, as if the token we just predicted was part of the prompt we provided. </p><p> We train models like this by providing lots of examples of text in the form that we want the model to copy. For question answering like this, we would have a whole bunch of questions followed by answers stuck together (with an End-Of-Sequence token in between each question/answer pair to tell the LLM that it doesn't need to generate more tokens after it has answered).</p></li><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716787925026/93312af8-cd60-443a-8f16-7b4353502ce0.jpeg" alt class="image--center mx-auto" /></p><p> <strong>encoder-decoder</strong> - T5 is an example of this type of model. This was the original Transformer architecture, and was originally designed for translating from one language to another. The input is a sequence of tokens, such as some text in Korean, and the output is another sequence of tokens, such as some text in English.
These models are often called sequence-to-sequence models.</p><p> One of the key differences between an encoder-decoder Transformer and a decoder-only Transformer is that you train an encoder-decoder Transformer with separate input and output sequences. (This is why decoders are the dominant architecture: you can use unsupervised training on tons of text to pretrain decoder-only LMs.) For an encoder-decoder architecture, you're teaching the model to map one sequence of tokens to another, rather than predicting what comes next in a sequence.</p></li></ol><p>Obviously we can map many problems to any of these architectures. We can map our character-extraction problem in the following ways:</p><ul><li><p>encoder-only: For each token in the input, output whether it is part of a character name. We can then go through the labels, put together successive character-name tokens, and deduplicate, and we will have a list of character names.</p></li><li><p>decoder-only: prompt as above</p></li><li><p>encoder-decoder: Input is a story, output is a list of names</p></li></ul><p>The encoder-decoder architecture seems like a good match for this problem, but it is unclear to me which architecture is the most efficient use of parameters for this problem vs. a decoder-only architecture that can dedicate more resources to understanding the text of the story, since the output format is so simple.</p><h3 id="heading-other-model-considerations">Other model considerations</h3><p><strong>Parameter count:</strong> Obviously <em>parameter count</em> will matter, but maybe not in the way that we might expect. We will discuss "inverse-scaling" problems below, of which this is probably one.</p><p><strong>"Local" vs API</strong> will have big operational consequences. To run locally, we must buy or rent a GPU and have enough inference workload to keep it busy. If we use serverless GPU services, we will pay a latency penalty on startup while the model loads into the GPU.
If we use an API, then <em>maybe</em> things get expensive at large scale? Or maybe the API providers' economies of scale will be enough to make up for the markup, even at some scale (I think so).</p><p><strong>Proprietary vs. open source</strong> matters inasmuch as we either care about open source intrinsically, or want to keep our data very private, and therefore want things to actually be on-prem (I guess this is basically the same as "local").</p><p>For my character-extractor problem, our data is not private, our scale is small, and latency matters. All of these things push us towards API inference, but let's find out how things shake out.</p><h2 id="heading-the-models-to-try">The models to try</h2><p>I'm trying to learn about the space of options, so I decided to generate some data and do an experiment to figure out the cost, speed, and inference quality of various options:</p><ul><li><p>an existing old(ish)-school Named Entity Recognition (NER) model called <code>DistilBERT-NER</code></p></li><li><p>open source LLMs: <code>mistral-7B</code>, <code>llama-3-8B</code> and <code>phi-3-mini</code></p></li><li><p>proprietary: <code>gpt-3.5-turbo</code>, <code>claude-3-haiku</code>, <code>claude-3-sonnet</code>, <code>mistral-small</code> (the dark horse)</p></li><li><p>a fine-tune of <code>flan-t5</code></p></li></ul><h2 id="heading-evaluation-metrics">Evaluation metrics</h2><p>For our problem, accuracy (well, not <em>exactly</em> accuracy...) is pretty important. Missing characters, multiple characters that represent the same person, hallucinated characters, etc.
will mean images that don't match the story.</p><p>On the other hand, we will also want to extract names from text that the user enters <em>somewhat</em> interactively, so the performance of the model needs to be high.</p><p>Cost will matter more if there are ever lots of users, but it also serves as a good comparator, since I used different hardware for different models, not to mention the proprietary models.</p><p>The metrics I used to evaluate the different models are pretty standard. If you squint a little bit, this actually is a retrieval problem (the query is the story, the response is the characters in the story). So we will use some standard retrieval metrics:</p><ul><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/Precision_and_recall"><strong>precision</strong></a> - of all the characters I "retrieved" from the story, how many were actually characters?</p></li><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/Precision_and_recall"><strong>recall</strong></a> - of the actual characters in the story, how many did I "retrieve"?</p></li><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/F-score"><strong>f1</strong></a> - the <a target="_blank" href="https://en.wikipedia.org/wiki/Harmonic_mean">harmonic mean</a> of precision and recall. I think of the harmonic mean as a "soft min," meaning it tends to be close to the minimum of the precision and recall.</p></li></ul><p>I calculate each of these metrics for each story, and then calculate the mean of the metrics over the validation or test set.</p><h2 id="heading-distilbert-ner">DistilBERT-NER</h2><p><a target="_blank" href="https://en.wikipedia.org/wiki/Named_entity">Named Entity Recognition</a> is a "classic" problem in NLP. NER usually means recognizing any proper noun (including place names, organizations, countries, etc., as well as people).
BERT and its reduced-size cousin DistilBERT are previous-generation language models that are still in use because their design (encoder-only) works well for some tasks like text classification and... NER.</p><p>Fortunately, lots of people have fine-tuned *BERT for NER, so we can just try out a model. <a target="_blank" href="https://huggingface.co/dslim/distilbert-NER"><code>dslim/distilbert-NER</code></a> on HuggingFace looks both small (60M parameters) and accurate, so I tried it out on Colab:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_DistilBERT_NER.ipynb"><code>Character_Extractor_DistilBERT_NER.ipynb</code></a></li></ul><p>Here are the results from the run:</p><div class="hn-table"><table><thead><tr><td>Metric</td><td>Value</td></tr></thead><tbody><tr><td>precision</td><td>0.733699</td></tr><tr><td>recall</td><td>0.736375</td></tr><tr><td>f1</td><td>0.735035</td></tr></tbody></table></div><p>Here is an example where the model did not have a perfect score:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716658320675/43be54cc-d0ee-4c35-841f-54d30a68d3e5.png" alt class="image--center mx-auto" /></p><p>The NER model misses the "Mr." in "Mr. Delivery", and it sometimes breaks up "Mrs.
Smarty Pants" in different ways.</p><p>Here's another example:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716658512148/44d67fc1-19f8-4439-9491-b73bb1c43668.png" alt class="image--center mx-auto" /></p><p>The NER model also seems to have trouble with epithets like "Sammy the Smart Squirrel" and "Patty the Polite Parrot".</p><p>Perhaps I could add some heuristics and the model could do better, but maybe a larger, more general model with a decoder could do better without much work by me, so let's try that.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I also tried a larger BERT-based NER finetune called <a target="_blank" href="https://huggingface.co/dslim/bert-large-NER"><code>dslim/bert-large-NER</code></a>, but the model outputs the predictions in a different way from <a target="_blank" href="https://huggingface.co/dslim/distilbert-NER"><code>dslim/distilbert-NER</code></a>, and I don't think the encoder-only approach is going to be adequate for the more complex semantic names and epithets, so I gave up and moved on ¯\_(ツ)_/¯</div></div><h2 id="heading-open-source-models">Open source models</h2><p>Three small models have been at the front of the hype train recently: <code>mistral-7B</code>, <code>llama-3-8B</code> and <code>phi-3-mini</code>. I have heard that smaller models actually tend to do better than larger ones on some easier tasks. Jason Wei from OpenAI has an interesting explanation of phenomena like this in a <a target="_blank" href="https://youtu.be/3gb-ZkVRemQ?si=pY56qAdxa-DfYpGn">Stanford CS25 V4 Lecture</a> (this is an excellent series, BTW!).
I want to see how these smaller models do with NER.</p><p>Here is my notebook:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_open_source_local_models.ipynb">Character_Extractor_open_source_local_models.ipynb</a></li></ul><h3 id="heading-json-output">JSON output</h3><p>One thing I have noticed about small models is that they tend to have trouble following instructions like "Output a comma-separated list of characters in this story. <strong>Don't output any other text, such as explanatory text</strong>." They tend to really, really want to output text like "Here is a list of characters: Fred, George, Michael. I hope this is helpful for your task," despite <a target="_blank" href="https://www.reddit.com/r/ChatGPT/comments/1894n1y/apparently_chatgpt_gives_you_better_responses_if/">offering tips</a> and warning of dead puppies.</p><p>We can fix this by using <a target="_blank" href="https://github.com/outlines-dev/outlines">outlines</a>, a library for enforcing a schema on the output of an LLM. It does this by altering the log-probabilities of the outputs at each step from the LLM to only allow valid JSON (or any other context-free grammar). The paper <a target="_blank" href="https://arxiv.org/abs/2307.09702">Efficient Guided Generation for Large Language Models</a> describes how they do this efficiently.</p><h3 id="heading-results">Results</h3><div class="hn-table"><table><thead><tr><td>Metric</td><td>Phi-3-mini 3.8B</td><td>Mistral-7B v0.3</td><td>Llama-3 8B</td></tr></thead><tbody><tr><td>precision</td><td>0.800610</td><td><strong>0.846937</strong></td><td>0.616764</td></tr><tr><td>recall</td><td>0.837257</td><td><strong>0.860567</strong></td><td>0.679253</td></tr><tr><td>f1</td><td>0.808432</td><td><strong>0.845475</strong></td><td>0.624396</td></tr></tbody></table></div><p>Oof! Llama 3 does very poorly on this task, maybe related to "inverse scaling"?
Mistral-7B is hardly smaller, but does way better on this task. Llama-3 was trained on 15T tokens (!!!). It is unclear how many tokens Mistral 7B was trained on, but it seems likely that it is less than Llama-3 since Meta has been pushing the boundary on the number of training tokens (<a target="_blank" href="https://www.reddit.com/r/singularity/comments/16tmun0/mistral_ai_releases_mistral_7b_trained_on_8/">one rumor</a> puts it at 8T tokens). I wonder if inverse scaling could be related to token count as well as parameter count, since it is really related to capabilities, rather than parameters per se?</p><h2 id="heading-proprietary-models">Proprietary models</h2><p>OpenAI and Anthropic put a lot of money into building GPT-3.5/4/4o and Anthropic Claude 1/2/3, so presumably those models have good performance. They also have huge economies of scale, so presumably they also have good economics.</p><p>Since we are using the Mistral open-source model <code>mistral-7b-v0.3</code>, I'm also curious about their proprietary offering <code>mistral-small</code>. The cost per token is exactly twice that of gpt-3.5 at the time of writing, so it's not likely they'll win on cost, but how does their performance stack up?</p><p>OpenAI has only <code>gpt-3.5-turbo</code> for their current generation small model offering (nobody knows for sure how big it is). Anthropic has <code>claude-3-haiku</code> which might be a 3B model (3 lines per poem) and <code>claude-3-sonnet</code> which might be a 14B model (14 lines per poem). The size of <code>mistral-small</code> is unclear, except that I would guess it is some sort of finetune or refinement of <code>mistral-7b</code>.</p><p>All three APIs offer "tool-use mode," which is presumably something similar to <code>outlines</code>, so we will use that. As it turns out, <code>sonnet</code> doesn't seem to work at all with "tool-use" mode, and haiku's performance (against our metrics) is lower in "tool-use" mode than just using a regular prompt (!).
On top of that, the model uses more tokens in "tool-use" (perhaps because it is rejecting and retrying tokens?)</p><p>Here is my notebook for the proprietary models:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_proprietary_models.ipynb">Character_Extractor_proprietary_models.ipynb</a></li></ul><p>Here are the results:</p><div class="hn-table"><table><thead><tr><td><strong>model</strong></td><td><strong>precision</strong></td><td><strong>recall</strong></td><td><strong>f1</strong></td></tr></thead><tbody><tr><td>gpt-3.5</td><td>0.885641</td><td>0.906672</td><td>0.890498</td></tr><tr><td>haiku</td><td><strong>0.902815</strong></td><td>0.936999</td><td><strong>0.912284</strong></td></tr><tr><td>haiku-tool</td><td>0.875685</td><td><strong>0.938288</strong></td><td>0.898414</td></tr><tr><td>sonnet</td><td>0.800982</td><td>0.897480</td><td>0.836882</td></tr><tr><td>mistral-small</td><td>0.868994</td><td>0.913717</td><td>0.883758</td></tr></tbody></table></div><p>Interestingly, <code>sonnet</code> does worse than <code>haiku</code>, which again supports the "inverse scaling" theory mentioned above. Also, note that <code>haiku-tool</code> does worse than <code>haiku</code> except on recall. I don't have a good theory about why that is.</p><p><code>gpt-3.5</code> and <code>mistral-small</code> both have good performance.</p><h2 id="heading-t5-lora">T5 LoRA</h2><p>I know this is a lot to keep track of, and there are some plots and a cost analysis below, but before we review that, let's see how well a small fine-tuned model will do.
I chose <a target="_blank" href="https://huggingface.co/google/flan-t5-large">google/flan-t5-large</a>, with 770M parameters, which is quite a bit smaller than the other models we have been testing besides the NER model we tried first.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text"><a target="_blank" href="https://arxiv.org/pdf/2106.09685">LoRA</a> stands for Low-Rank Adaptation, and is a way to fine-tune models by inserting a small number of parameters into the model that "adapt" the model to behave in a way specific to the fine-tuning. This mechanism is much easier to train, and the adapter is also easy to share, since it contains only a small number of parameters.</div></div><p>T5 is an encoder-decoder model (like the original Transformer models), which roughly means that it is designed to translate one type of text into another. This could mean language-to-language, but it can also mean from story to list of characters.</p><p>Unlike a decoder-only model (like <code>GPT</code>), the output is separate from the input, so we don't need to extract the characters from the output text.</p><p>I'm not sure how much the model architecture will impact its performance on this task vs.
a decoder-only model.</p><p>Here are the artifacts from fine-tuning:</p><ul><li><p>Notebook: <a target="_blank" href="https://colab.research.google.com/github/honicky/character-extraction/blob/main/Character_Extractor_T5_LoRA.ipynb">Character_Extractor_T5_LoRA.ipynb</a></p></li><li><p>Model: <a target="_blank" href="https://huggingface.co/honicky/t5-large-lora-character-extraction">honicky/t5-large-lora-character-extraction</a></p></li><li><p>Weights and Biases: <a target="_blank" href="https://wandb.ai/honicky/t5_target_finetune_for_character_extraction/runs/mx57gh45">t5_target_finetune_for_character_extraction</a></p></li></ul><p>And the results:</p><div class="hn-table"><table><thead><tr><td>Metric</td><td>Value</td></tr></thead><tbody><tr><td>precision</td><td>0.875659</td></tr><tr><td>recall</td><td>0.886578</td></tr><tr><td>f1</td><td>0.874909</td></tr></tbody></table></div><p>Fine-tuning was obviously more work than the other methods, but a big chunk of the work was actually just figuring out how to use the <code>transformers</code> and <code>peft</code> libraries to fine-tune the particular model in the particular way I wanted to, and setting up the tokenizer, data collator, etc. I got caught by cut-and-paste errors a couple of times, and a mismatched collator caused tricky memory problems. Lots of other details such as <code>fp16</code> not working for Seq2Seq models in the <code>transformers</code> library in some circumstances meant a lot of debugging to get the fine-tune to actually happen.</p><p>In the end, however, we got pretty good results, and very fast performance, and I learned a lot, so I'll take the "W".</p><h2 id="heading-analysis">Analysis</h2><p>For my use case, I have a couple of considerations besides the metrics we have been tracking.
Obviously inference time and cost are important, but the operational complexity of using a model, and the sensitivity to traffic volume are also things I need to consider.</p><h3 id="heading-local-vs-api">Local vs. API</h3><p>The open-source models can be served by inference providers like <a target="_blank" href="https://replicate.com/">replicate</a>, but their off-the-shelf offerings don't include JSON mode, which was important for our evals. <a target="_blank" href="https://fireworks.ai/">fireworks.ai</a> has the advantage that it also provides "json-mode" for the open-source models it serves, so I have included that in the evaluation, although I just assume we are using a particular number of tokens and that the performance metrics are the same as doing local inference.</p><h3 id="heading-cost-of-local-inference">Cost of local inference</h3><p>In order to compare the cost of the proprietary models to local ones, I have assumed 50% occupancy on the local GPU. This obviously doesn't hold for a little demo, but maybe for a real deployment???</p><p>Since I have done some of the evaluations of the smaller models on smaller, older GPUs like T4s and L4s, I used an aggregate metric I found to create a performance ratio <a target="_blank" href="https://technical.city/en/video/Tesla-T4-vs-RTX-A4000">between T4 and RTX A4000</a>, and <a target="_blank" href="https://technical.city/en/video/RTX-A4000-vs-L4">between L4 and RTX A4000</a>, since RTX A4000s are widely available for inference on RunPod and other GPU-as-a-service providers, and offer good cost/performance ratios and similar memory.</p><h3 id="heading-finally-the-results">Finally, the results</h3><p>Here is the analysis notebook:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_Performance_Analysis.ipynb">Character_Extractor_Performance_Analysis.ipynb</a></li></ul><p>And the results:</p><p><img
src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716700942209/9cbb95f4-f2b8-4a1c-a84b-9f5c52495ac8.png" alt class="image--center mx-auto" /></p><p>There's so much to talk about! Let's take a look at cost-per-story vs F1 score to start. Green means local, blue means hosted.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716740964313/3573ea4b-ba58-40a7-9aad-a7ec22f0a92d.png" alt class="image--center mx-auto" /></p><p><code>haiku</code> is the clear winner for F1, and our T5 wins on cost. The other open source models are mostly trailing the proprietary models on performance and hosted generally also wins on cost, except for our very small models. <code>sonnet</code> and <code>Llama-3-8B</code> are sad outliers :(</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I may have screwed something up setting up inference with <code>Llama-3-8B</code> that made it slow. I didn't see anything obvious, but there are probably ways to make it faster if it was performing well on the metrics. Since it's not, there's no need to figure that out.</div></div><p>I added the approximate cost for <code>phi-3-mini</code> and <code>mistral-7B</code> on <a target="_blank" href="https://fireworks.ai/">fireworks.ai</a>, since fireworks.ai supports serving open source models using JSON mode.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716741028232/77722508-db21-4363-b5e0-9173a154892d.png" alt class="image--center mx-auto" /></p><p>Time and cost are mostly the same thing, so unsurprisingly, the plot is very similar. One thing that stands out is that <code>mistral-small</code> is very fast.
Too bad it's so expensive!</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716741056096/013a6726-6aec-4eef-9e96-e0fb26b5d8ee.png" alt class="image--center mx-auto" /></p><p>Finally, when we just look at precision and recall, the proprietary models stand out as doing strictly better with recall. This means that they tend to get all of the characters in the story, but may throw in characters that GPT-4 didn't think were main characters. Maybe this is ok for our use case.</p><p>In summary:</p><ul><li><p>T5 LoRA has the lowest cost per story assuming volume, but DistilBERT-NER might do better if we used a better GPU <em>and</em> a lot of volume</p></li><li><p>T5 LoRA performs pretty close to the proprietary models, despite only having 770M parameters</p></li><li><p>Llama-3-8B has low performance and high cost, as does sonnet</p></li><li><p>Phi-3-mini disappointed, and broke the inverse scaling "law": worse performance despite smaller size</p></li><li><p>haiku is the all-round winner</p><ul><li><p>highest performer</p></li><li><p>easy to use - no need to rent GPUs, suffer slow startup for serverless, or maintain models</p></li><li><p>close to lowest cost at scale, lowest at low scale</p></li></ul></li><li><p>T5 LoRA did not disappoint! Great cost / performance ratio, and we might be able to improve performance with better fine-tuning.</p></li></ul><h2 id="heading-id-still-like-to-try">I'd still like to try...</h2><p>I wonder if <code>gpt-2-large</code>, which also has 770M parameters, would perform better after fine-tuning than our T5 LoRA. It seems like having a sophisticated decoder to output a few names is probably a waste of parameters, and would be better spent on a single chain that looks at the input.</p><h2 id="heading-feedback-always-appreciated">Feedback always appreciated</h2><p>I'm writing this to learn. I'm sure I've said some dumb things, or missed something important.
Please comment if you have any feedback, questions or comments!</p>]]><![CDATA[<p>I have a ton of ideas bouncing around in my head about interesting things to do with or to LLMs and other types of models. I've been using a children's story co-authoring and publishing tool (<a target="_blank" href="https://github.com/honicky/story-time">https://github.com/honicky/story-time</a> and <a target="_blank" href="https://www.storytime.glass/">https://www.storytime.glass/</a>) as a toy project to learn about generative AI. One thing I've discovered is that consistency across images (<a target="_blank" href="https://openai.com/index/sora/">Sora</a> notwithstanding) is a tough problem. This actually matters quite a bit in a children's story since so much of the story is told visually.</p><p>One idea I've had some success with is to use identical language in image prompts for each character and location. The story generator first generates a story and then generates an image prompt for each paragraph to go along with the text (checking out an example at <a target="_blank" href="https://www.storytime.glass/">https://www.storytime.glass/</a> might help here). This means that I need to be able to identify all of the characters in the story, and also the characters in a given paragraph and image.</p><p>Since my god-model story generator is pretty slow, the app is difficult to interact with right now. That sounds like an excuse to learn about the best, fastest, cheapest way to figure out who are the characters in a story and a scene. I could obviously ask GPT-4o or claude-3-opus who are the characters in the story, but</p><ol><li><p>the generation loop is already slow and</p></li><li><p>where's the fun in that?</p></li></ol><p>In this post, I will play around with several options for extracting the character names from stories and paragraphs, including a 60M parameter BERT-based NER model (more below), all the way to my own fine-tuned <code>flan-t5</code>.
I have links to notebooks for each experiment.</p><h2 id="heading-tldr">TL;DR</h2><p>In case you're mostly just interested in the results, skip to the tables and graphs at the bottom. The nutshell: <code>haiku</code> is the top performer, and my <code>flan-t5-large</code> fine-tune is the fastest, and also the lowest cost, but only if we assume 50% occupancy on the GPU, which is a high bar to cross.</p><h2 id="heading-the-space-to-explore">The space to explore</h2><p>There are a few dimensions across which I would like to understand the behavior of different models. As I have done the experiments below, I've had as many new questions as answers. I've limited myself to a few representative examples, so this is obviously not a rigorous survey, but it's enough to get a feel for how different models behave with this problem.</p><h3 id="heading-architecture">Architecture</h3><p>I find it most useful to think about the different Transformer-based LLMs in terms of their input and output during training and inference. In very broad strokes, Transformer-based LLMs come in three main flavors:</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">There are tons of diagrams online illustrating the detailed differences between the different architectures, but I find they don't give a good intuition for what each type will be good at. If you already have a good intuition, feel free to skip this section.</div></div><ol><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716787966107/132824f0-d8ee-4655-9e65-ff34b59c28c7.jpeg" alt class="image--center mx-auto" /></p><p> <strong>encoder-only</strong> - BERT is an encoder-only model. These models take in text and output a fixed set of probabilities. Exactly what those probabilities mean, how many of them there are, etc. depends on the specific problem they are designed to solve.
With BERT, you can replace the last layer of the network with a layer that is appropriate for your problem (and then fine-tune it). For example, if you were classifying text into three categories, then maybe you would have three output nodes, one representing each class. In our case we will have one output node per input-token, per class of token (is it part of a name?). The key characteristic of this type of model is that the output is a fixed, pre-determined number of numerical outputs. You have to squeeze your problem into this type of output in order to use an encoder-only model.</p></li><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716952087152/2c5d9348-b83e-4eb4-8fdb-beb0e8014d63.jpeg" alt class="image--center mx-auto" /></p><p> <strong>decoder-only</strong> - the GPT-type models are decoder-only models. The problems that they solve must be framed in terms of a stream of text. Given some starting text, what is the "completion" of it? This is how you are probably used to using ChatGPT: 'What are the names of all the characters in the following story: "Once upon a time..." ' and we expect a response like "The characters are RJ, Keivan and Aashray." </p><p> decoder-only models are "auto-regressive," meaning that we predict a token (part of a word) that follows the text so far, and then we feed the whole text back into the model and predict the next-next token, as if the token we just predicted was part of the prompt we provided. </p><p> We train models like this by providing lots of examples of the sequence of text in the form that we want the model to copy.
For question answering like this, we would have a whole bunch of questions followed by answers stuck together (with an End-Of-Sequence token between each question/answer pair to tell the LLM that it doesn't need to generate more tokens after it has answered).</p></li><li><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716787925026/93312af8-cd60-443a-8f16-7b4353502ce0.jpeg" alt class="image--center mx-auto" /></p><p> <strong>encoder-decoder</strong> - T5 is an example of this type of model. This was the original Transformer architecture, and was originally designed for translating from one language to another. The input is a sequence of tokens, such as some text in Korean, and the output is another sequence of tokens such as some text in English. These models are often called sequence-to-sequence models.</p><p> One of the key differences between an encoder-decoder Transformer and a decoder-only transformer is that you train an encoder-decoder transformer with separate input and output sequences. (This is why decoders are the dominant architecture: you can use unsupervised training on tons of text to pretrain decoder-only LMs) For an encoder-decoder architecture, you're teaching the model to map one sequence of tokens to another, rather than predicting what comes next in a sequence.</p></li></ol><p>Obviously we can map many problems to any of these architectures. We can map our character extraction problem in the following way:</p><ul><li><p>encoder-only: For each token in the input, output whether it is a part of a character name.
We can then go through the labels, stitch successive character-name tokens together, and deduplicate, and we will have a list of character names.</p></li><li><p>decoder-only: prompt as above</p></li><li><p>encoder-decoder: Input is a story, output is a list of names</p></li></ul><p>The encoder-decoder architecture seems like a good match for this problem, but it is unclear to me which architecture is the most efficient use of parameters for this problem vs. a decoder-only architecture that can dedicate more resources to understanding the text of the story, since the output format is so simple.</p><h3 id="heading-other-model-considerations">Other model considerations</h3><p><strong>Parameter Count:</strong> Obviously <em>parameter count</em> will matter, but maybe not in the way that we might expect. We will discuss "inverse-scaling" problems below, of which this is probably one.</p><p><strong>"Local" vs API</strong> will have big operational consequences. We must buy or rent a GPU, and have enough inference workload to keep it busy. If we use serverless GPU services, we will pay a latency penalty on startup while the model loads into the GPU. If we use an API, then <em>maybe</em> things get expensive at large scale? Or maybe economies of scale for the API providers will be enough to make up for the markup, even at some scale (I think so).</p><p><strong>Proprietary vs. open source</strong> matters inasmuch as we either care about open-source intrinsically, or want to keep our data very private, and therefore want things to be actually on-prem (I guess this is basically the same as "local").</p><p>For my character extractor problem, our data is not private, our scale is small, and latency matters.
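The encoder-only post-processing described above (collect successive name-labeled tokens, then deduplicate) is simple enough to sketch directly. This is a toy illustration with word-level tokens and a simplified BIO label scheme; a real *BERT NER model emits labels per subword token, so there is an extra merging step in practice:

```python
def extract_names(tokens, labels):
    """Group consecutive B-PER/I-PER tokens into names, then deduplicate."""
    names, current = [], []
    for tok, lab in zip(tokens, labels):
        if lab == "B-PER":                     # a new name starts
            if current:
                names.append(" ".join(current))
            current = [tok]
        elif lab == "I-PER" and current:       # the current name continues
            current.append(tok)
        else:                                  # "O" (or a stray I-PER) closes any open name
            if current:
                names.append(" ".join(current))
            current = []
    if current:
        names.append(" ".join(current))
    return list(dict.fromkeys(names))          # dedupe while preserving order

tokens = ["Once", "Sammy", "the", "Squirrel", "met", "Sammy", "the", "Squirrel", "and", "Patty", "."]
labels = ["O", "B-PER", "I-PER", "I-PER", "O", "B-PER", "I-PER", "I-PER", "O", "B-PER", "O"]
print(extract_names(tokens, labels))  # -> ['Sammy the Squirrel', 'Patty']
```

The deduplication step matters because a story mentions its characters many times; we only want each name once.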
All of these things push us towards API inference, but let's find out how things shake out.</p><h2 id="heading-the-models-to-try">The models to try</h2><p>I'm trying to learn about the space of options, so I decided to generate some data and do an experiment to figure out the cost, speed, and inference quality of various options:</p><ul><li><p>an existing old(ish) school Named Entity Recognition (NER) model called <code>DistilBERT-NER</code></p></li><li><p>open source LLMs: <code>mistral-7B</code>, <code>llama-3-8B</code> and <code>phi-3-mini</code></p></li><li><p>proprietary: <code>gpt-3.5-turbo</code>, <code>claude-3-haiku</code>, <code>claude-3-sonnet</code>, <code>mistral-small</code> (the dark horse)</p></li><li><p>a fine-tune of <code>flan-t5</code></p></li></ul><h2 id="heading-evaluation-metrics">Evaluation metrics</h2><p>For our problem, accuracy (well not <em>exactly</em> accuracy...) is pretty important. Missing characters, multiple characters that represent the same person, hallucinated characters, etc. will mean images that don't match the story.</p><p>On the other hand, we will also want to extract the names from text that the user enters <em>somewhat</em> interactively, so the performance of the model needs to be high.</p><p>Cost will matter more if there are ever lots of users, but it also serves as a good comparator since I used different hardware for different models, not to mention the proprietary models.</p><p>The metrics I used to evaluate the different models are pretty standard. If you squint a little bit, this actually is a retrieval problem (the query is the story, the response is the characters in the story).
So we will use some standard retrieval metrics:</p><ul><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/Precision_and_recall"><strong>precision</strong></a> - of all the characters I "retrieved" from the story, how many were actually characters?</p></li><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/Precision_and_recall"><strong>recall</strong></a> - of the actual characters in the story, how many did I "retrieve"?</p></li><li><p><a target="_blank" href="https://en.wikipedia.org/wiki/F-score"><strong>f1</strong></a> - the <a target="_blank" href="https://en.wikipedia.org/wiki/Harmonic_mean">harmonic mean</a> of precision and recall. I think about harmonic mean as a "soft-min," meaning it tends to be close to the min of the precision and recall</p></li></ul><p>I calculate each of these metrics for each story, and then calculate the mean of the metrics over the validation or test set.</p><h2 id="heading-distilbert-ner">DistilBERT-NER</h2><p><a target="_blank" href="https://en.wikipedia.org/wiki/Named_entity">Named Entity Recognition</a> is a "classic" problem in NLP. NER usually means recognizing any proper noun (including place names, organizations, countries, etc. as well as people). BERT and its reduced-size cousin DistilBERT are previous generation language models that are still in use because their design (encoder-only) works well for some tasks like text classification and... NER.</p><p>Fortunately, lots of people have fine-tuned *BERT for NER, so we can just try out a model.
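A minimal sketch of the per-story calculation (the character lists here are hypothetical, just for illustration): treat the predicted and ground-truth character lists as sets, compute precision, recall, and F1 per story, then average each metric over the eval set.

```python
def story_metrics(predicted, actual):
    """Set-based precision/recall/F1 for one story's character list."""
    pred, act = set(predicted), set(actual)
    tp = len(pred & act)                                   # correctly "retrieved" characters
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(act) if act else 0.0
    # F1 is the harmonic mean -- a "soft min" that stays near the worse of the two
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# hypothetical story: the model found 2 of the 3 real characters plus one spurious name,
# so precision = recall = 2/3, and their harmonic mean is also 2/3
p, r, f1 = story_metrics(["RJ", "Keivan", "Bob"], ["RJ", "Keivan", "Aashray"])
```

In the eval loop, each of the three per-story scores is averaged over all stories in the validation or test set.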
<a target="_blank" href="https://huggingface.co/dslim/distilbert-NER"><code>dslim/distilbert-NER</code></a> on HuggingFace looks both small (60M parameters) and accurate, so I tried it out on Colab:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_DistilBERT_NER.ipynb"><code>Character_Extractor_DistilBERT_NER.ipynb</code></a></li></ul><p>Here are the results from the run:</p><div class="hn-table"><table><thead><tr><td>Metric</td><td>Value</td></tr></thead><tbody><tr><td>precision</td><td>0.733699</td></tr><tr><td>recall</td><td>0.736375</td></tr><tr><td>f1</td><td>0.735035</td></tr></tbody></table></div><p>Here is an example where the model did not have a perfect score:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716658320675/43be54cc-d0ee-4c35-841f-54d30a68d3e5.png" alt class="image--center mx-auto" /></p><p>The NER model misses the "Mr." in "Mr. Delivery", and it sometimes breaks up "Mrs. Smarty Pants" in different ways.</p><p>Here's another example:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716658512148/44d67fc1-19f8-4439-9491-b73bb1c43668.png" alt class="image--center mx-auto" /></p><p>The NER model also seems to have trouble with epithets like "Sammy the Smart Squirrel" and "Patty the Polite Parrot".</p><p>Perhaps I could add some heuristics and the model could do better, but maybe a larger, more general model with a decoder could do better without much work by me, so let's try that.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I also tried a larger BERT-based NER finetune called <a target="_blank" href="https://huggingface.co/dslim/bert-large-NER"><code>dslim/bert-large-NER</code></a>, but the model outputs the predictions in a different way from <a target="_blank" href="https://huggingface.co/dslim/distilbert-NER"><code>dslim/distilbert-NER</code></a>, and I don't think the
encoder-only approach is going to be adequate for the more complex semantic names and epithets, so I gave up and moved on ¯\_(ツ)_/¯</div></div><h2 id="heading-open-source-models">Open source models</h2><p>Three small models have been at the front of the hype-train recently: <code>mistral-7B</code>, <code>llama-3-8B</code> and <code>phi-3-mini</code>. I have heard that smaller models actually tend to do better than larger ones on some easier tasks. Jason Wei from OpenAI has an interesting explanation of phenomena like this in a <a target="_blank" href="https://youtu.be/3gb-ZkVRemQ?si=pY56qAdxa-DfYpGn">Stanford CS25 V4 Lecture</a> (this is an excellent series, BTW!) I want to see how these smaller models do with NER.</p><p>Here is my notebook:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_open_source_local_models.ipynb">Character_Extractor_open_source_local_models.ipynb</a></li></ul><h3 id="heading-json-output">JSON output</h3><p>One thing I have noticed about small models is that they tend to have trouble following instructions like "Output a comma-separated list of characters in this story. <strong>Don't output any other text, such as explanatory text</strong>." They tend to really, really want to output text like "Here is a list of characters: Fred, George, Michael. I hope this is helpful for your task," despite <a target="_blank" href="https://www.reddit.com/r/ChatGPT/comments/1894n1y/apparently_chatgpt_gives_you_better_responses_if/">offering tips</a> and warning of dead puppies.</p><p>We can fix this by using <a target="_blank" href="https://github.com/outlines-dev/outlines">outlines</a>, a library for enforcing a schema on the output of an LLM. It does this by altering the log-probabilities of the outputs at each step from the LLM to only allow valid JSON (or any other context-free grammar).
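The core idea can be shown with a toy sketch. This is not outlines' actual implementation (the paper's contribution is compiling the grammar into a finite-state machine over the tokenizer's vocabulary so the masking is cheap at every step); it just illustrates how masking log-probabilities restricts what the model can emit:

```python
import math

def mask_logits(logits, allowed):
    """Set the log-probability of every grammar-disallowed token to -inf."""
    return {tok: (lp if tok in allowed else -math.inf) for tok, lp in logits.items()}

# toy vocabulary and model log-probs at the first decoding step
logits = {'{': -0.5, 'Here': -0.1, '"': -1.2, 'is': -0.3}

# a JSON grammar only allows '{' as the first token, so even though the model
# would prefer to start with "Here is a list of characters...", that
# continuation becomes impossible after masking
masked = mask_logits(logits, allowed={'{'})
best = max(masked, key=masked.get)
print(best)  # -> '{'
```

Because every disallowed token has probability zero after masking, greedy decoding and sampling alike can only ever produce output that matches the schema.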
The paper <a target="_blank" href="https://arxiv.org/abs/2307.09702">Efficient Guided Generation for Large Language Models</a> describes how they do this efficiently.</p><h3 id="heading-results">Results</h3><div class="hn-table"><table><thead><tr><td>Metric</td><td>Phi-3-mini 3.8B</td><td>Mistral-7B v0.3</td><td>Llama-3 8B</td></tr></thead><tbody><tr><td>precision</td><td>0.800610</td><td><strong>0.846937</strong></td><td>0.616764</td></tr><tr><td>recall</td><td>0.837257</td><td><strong>0.860567</strong></td><td>0.679253</td></tr><tr><td>f1</td><td>0.808432</td><td><strong>0.845475</strong></td><td>0.624396</td></tr></tbody></table></div><p>Oof! Llama 3 does very poorly on this task, maybe related to "inverse scaling"? Mistral-7B is hardly smaller, but does way better on this task. Llama-3 was trained on 15T tokens (!!!). It is unclear how many tokens Mistral 7B was trained on, but it seems likely that it is less than Llama-3 since Meta has been pushing the boundary on the number of training tokens (<a target="_blank" href="https://www.reddit.com/r/singularity/comments/16tmun0/mistral_ai_releases_mistral_7b_trained_on_8/">one rumor</a> puts it at 8T tokens). I wonder if inverse scaling could be related to token count as well as parameter count, since it is really related to capabilities, rather than parameters per se?</p><h2 id="heading-proprietary-models">Proprietary models</h2><p>OpenAI and Anthropic put a lot of money into building GPT-3.5/4/4o and Anthropic Claude 1/2/3, so presumably those models have good performance. They also have huge economies of scale, so presumably they also have good economics.</p><p>Since we are using the Mistral open-source model <code>mistral-7b-v0.3</code>, I'm also curious about their proprietary offering <code>mistral-small</code>.
The cost per token is exactly twice that of gpt-3.5 at the time of writing, so it's not likely they'll win on cost, but how does their performance stack up?</p><p>OpenAI has only <code>gpt-3.5-turbo</code> for their current generation small model offering (nobody knows for sure how big it is). Anthropic has <code>claude-3-haiku</code> which might be a 3B model (3 lines per poem) and <code>claude-3-sonnet</code> which might be a 14B model (14 lines per poem). The size of <code>mistral-small</code> is unclear, except that I would guess it is some sort of finetune or refinement of <code>mistral-7b</code>.</p><p>All three APIs offer "tool-use mode," which is presumably something similar to <code>outlines</code>, so we will use that. As it turns out, <code>sonnet</code> doesn't seem to work at all with "tool-use" mode, and haiku's performance (against our metrics) is lower in "tool-use" mode than just using a regular prompt (!). On top of that, the model uses more tokens in "tool-use" (perhaps because it is rejecting and retrying tokens?)</p><p>Here is my notebook for the proprietary models:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_proprietary_models.ipynb">Character_Extractor_proprietary_models.ipynb</a></li></ul><p>Here are the results:</p><div class="hn-table"><table><thead><tr><td><strong>model</strong></td><td><strong>precision</strong></td><td><strong>recall</strong></td><td><strong>f1</strong></td></tr></thead><tbody><tr><td>gpt-3.5</td><td>0.885641</td><td>0.906672</td><td>0.890498</td></tr><tr><td>haiku</td><td><strong>0.902815</strong></td><td>0.936999</td><td><strong>0.912284</strong></td></tr><tr><td>haiku-tool</td><td>0.875685</td><td><strong>0.938288</strong></td><td>0.898414</td></tr><tr><td>sonnet</td><td>0.800982</td><td>0.897480</td><td>0.836882</td></tr><tr><td>mistral-small</td><td>0.868994</td><td>0.913717</td><td>0.883758</td></tr></tbody></table></div><p>Interestingly, <code>sonnet</code> does
worse than <code>haiku</code>, which again supports the "inverse scaling" theory mentioned above. Also, note that <code>haiku-tool</code> does worse than <code>haiku</code> except on recall. I don't have a good theory about why that is.</p><p><code>gpt-3.5</code> and <code>mistral-small</code> both have good performance.</p><h2 id="heading-t5-lora">T5 LoRA</h2><p>I know this is a lot to keep track of, and there are some plots and a cost analysis below, but before we review that, let's see how well a small fine-tuned model will do. I chose <a target="_blank" href="https://huggingface.co/google/flan-t5-large">google/flan-t5-large</a>, with 770M parameters, which is quite a bit smaller than the other models we have been testing, besides the NER model we tried first.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text"><a target="_blank" href="https://arxiv.org/pdf/2106.09685">LoRA</a> stands for Low-Rank Adaptation, and is a way to fine-tune models by inserting a small number of parameters into the model that "adapt" the model to behave in a way specific to the fine-tuning. These adapters are much cheaper to train, and they are also easy to share, since each adapter has only a small number of parameters.</div></div><p>T5 is an encoder-decoder model (like the original Transformer models), which roughly means that it is designed to translate one type of text into another. This could mean language-to-language, but it can also mean from story to list of characters.</p><p>Unlike a decoder-only model (like <code>gpt</code>), the output is separate from the input, so we don't need to extract the characters from the output text.</p><p>I'm not sure how much the model architecture will impact its performance on this task vs.
a decoder-only model.</p><p>Here are the artifacts from fine-tuning:</p><ul><li><p>Notebook: <a target="_blank" href="https://colab.research.google.com/github/honicky/character-extraction/blob/main/Character_Extractor_T5_LoRA.ipynb">Character_Extractor_T5_LoRA.ipynb</a></p></li><li><p>Model: <a target="_blank" href="https://huggingface.co/honicky/t5-large-lora-character-extraction">honicky/t5-large-lora-character-extraction</a></p></li><li><p>Weights and Biases: <a target="_blank" href="https://wandb.ai/honicky/t5_target_finetune_for_character_extraction/runs/mx57gh45">t5_target_finetune_for_character_extraction</a></p></li></ul><p>And the results:</p><div class="hn-table"><table><thead><tr><td>Metric</td><td>Value</td></tr></thead><tbody><tr><td>precision</td><td>0.875659</td></tr><tr><td>recall</td><td>0.886578</td></tr><tr><td>f1</td><td>0.874909</td></tr></tbody></table></div><p>Fine-tuning was obviously more work than the other methods, but a big chunk of the work was actually just figuring out how to use the <code>transformers</code> and <code>peft</code> libraries to fine-tune the particular model in the particular way I wanted, and setting up the tokenizer, data collator, etc. I got caught by cut-and-paste errors a couple of times, and a mismatched collator caused tricky memory problems. Lots of other details, such as <code>fp16</code> not working for Seq2Seq models in the <code>transformers</code> library in some circumstances, meant a lot of debugging to get the fine-tune to actually happen.</p><p>In the end, however, we got pretty good results and very fast performance, and I learned a lot, so I'll take the "W".</p><h2 id="heading-analysis">Analysis</h2><p>For my use case, I have a couple of considerations besides the metrics we have been tracking.
Obviously inference time and cost are important, but the operational complexity of using a model and its sensitivity to traffic volume are also things I need to consider.</p><h3 id="heading-local-vs-api">Local vs. API</h3><p>The open-source models can be served by inference providers like <a target="_blank" href="https://replicate.com/">replicate</a>, but their off-the-shelf offerings don't include JSON mode, which was important for our evals. <a target="_blank" href="https://fireworks.ai/">fireworks.ai</a> has the advantage that it also provides "json-mode" for the open-source models it serves, so I have included it in the evaluation, although I just assume we are using a particular number of tokens and that the performance metrics are the same as doing local inference.</p><h3 id="heading-cost-of-local-inference">Cost of local inference</h3><p>In order to compare the cost of the proprietary models to local ones, I have assumed 50% occupancy on the local GPU. This obviously doesn't hold for a little demo, but maybe for a real deployment???</p><p>Since I have done some of the evaluations of the smaller models on smaller, older GPUs like T4s and L4s, I used an aggregate metric I found to create a performance ratio <a target="_blank" href="https://technical.city/en/video/Tesla-T4-vs-RTX-A4000">between T4 and RTX A4000</a>, and <a target="_blank" href="https://technical.city/en/video/RTX-A4000-vs-L4">between L4 and RTX A4000</a>, since RTX A4000s are widely available for inference on RunPod and other GPU-as-a-service providers, and offer good cost/performance ratios and similar memory.</p><h3 id="heading-finally-the-results">Finally, the results</h3><p>Here is the analysis notebook:</p><ul><li><a target="_blank" href="https://github.com/honicky/character-extraction/blob/main/Character_Extractor_Performance_Analysis.ipynb">Character_Extractor_Performance_Analysis.ipynb</a></li></ul><p>And the results:</p><p><img
src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716700942209/9cbb95f4-f2b8-4a1c-a84b-9f5c52495ac8.png" alt class="image--center mx-auto" /></p><p>There's so much to talk about! Let's take a look at cost-per-story vs. F1 score to start. Green means local, blue means hosted.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716740964313/3573ea4b-ba58-40a7-9aad-a7ec22f0a92d.png" alt class="image--center mx-auto" /></p><p><code>haiku</code> is the clear winner for F1, and our T5 wins on cost. The other open-source models mostly trail the proprietary models on performance, and hosted generally also wins on cost, except for our very small models. <code>sonnet</code> and <code>Llama-3-8B</code> are sad outliers :(</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I may have screwed something up setting up inference for <code>Llama-3-8B</code> that made it slow. I didn't see anything obvious, but there are probably ways to make it faster if it were performing well on the metrics. Since it's not, there's no need to figure that out.</div></div><p>I added the approximate cost for <code>phi-3-mini</code> and <code>mistral-7B</code> on <a target="_blank" href="https://fireworks.ai/">fireworks.ai</a>, since fireworks.ai supports serving open-source models using JSON mode.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716741028232/77722508-db21-4363-b5e0-9173a154892d.png" alt class="image--center mx-auto" /></p><p>Time and cost are mostly the same thing, so, not surprisingly, the plot is very similar. One thing that stands out is that <code>mistral-small</code> is very fast.
Too bad it's so expensive!</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1716741056096/013a6726-6aec-4eef-9e96-e0fb26b5d8ee.png" alt class="image--center mx-auto" /></p><p>Finally, when we just look at precision and recall, the proprietary models stand out as doing strictly better on recall. This means that they tend to get all of the characters in the story, but may throw in characters that GPT-4 didn't think were main characters. Maybe this is ok for our use case.</p><p>In summary:</p><ul><li><p>T5 LoRA has the lowest cost per story assuming volume, but DistilBERT-NER might do better if we used a better GPU <em>and</em> had a lot of volume</p></li><li><p>T5 LoRA performs pretty close to the proprietary models, despite only having 770M parameters</p></li><li><p>Llama-3-8B has low performance and high cost, as does sonnet</p></li><li><p>Phi-3-mini disappointed, and broke the inverse scaling "law": worse performance despite smaller size</p></li><li><p>haiku is the all-round winner</p><ul><li><p>highest performer</p></li><li><p>easy to use - no need to rent GPUs, suffer slow startup for serverless, or maintain models</p></li><li><p>close to lowest cost at scale, lowest at low scale</p></li></ul></li><li><p>T5 LoRA did not disappoint! Great cost/performance ratio, and we might be able to improve performance with better fine-tuning.</p></li></ul><h2 id="heading-id-still-like-to-try">I'd still like to try...</h2><p>I wonder if <code>gpt-2-large</code>, which also has 770M parameters, would perform better after fine-tuning than our T5 LoRA. It seems like having a sophisticated decoder to output a few names is probably a waste of parameters, which would be better spent on a single chain that looks at the input.</p><h2 id="heading-feedback-always-appreciated">Feedback always appreciated</h2><p>I'm writing this to learn. I'm sure I've said some dumb things, or missed something important.
Please comment if you have any feedback, questions or comments!</p>]]>https://cdn.hashnode.com/res/hashnode/image/upload/v1716784067759/3bd8f2b8-66df-4a19-b425-ce6e5ecae9a5.webp<![CDATA[Are All Large Language Models Really in 1.58 Bits?]]>https://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bitshttps://learning-exhaust.hashnode.dev/are-all-large-language-models-really-in-158-bitsFri, 12 Apr 2024 18:01:37 GMT<![CDATA[<h2 id="heading-introduction">Introduction</h2><p>This post is my <a target="_blank" href="https://www.swyx.io/learn-in-public">learning exhaust</a> from reading an exciting pre-print paper titled <a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> about very efficient representations of high-performing LLMs. I am trying to come up to speed on deep learning after having been heads down in telecommunications systems and data at a startup since leaving research 13 years ago (!!!). This was a really fun way to take a deep dive, so I'm <a target="_blank" href="https://www.swyx.io/learn-in-public">learning in public</a>: this post is the guide to the paper I wish I had before I read it :)</p><details><summary>1.58 bits?</summary><div data-type="detailsContent">1.58 bits is the number of bits needed to represent a base 3 (e.g. -1, 0, 1) digit. We can compute this with <code>log2(3) ~= 1.58</code> (sorry for the bad notation). It's weird to have a fractional number of bits.
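A quick numerical sanity check of this arithmetic (just a sketch):

```python
import math

assert abs(math.log2(3) - 1.58) < 0.01        # bits per ternary digit
assert math.floor(16 / math.log2(3)) == 10    # 10 trits fit in 16 bits, since 3**10 = 59049 <= 2**16

def pack(trits):
    """Pack ternary digits (-1, 0, 1) into one integer, base 3."""
    n = 0
    for t in trits:
        n = n * 3 + (t + 1)                   # map each trit into {0, 1, 2}
    return n

def unpack(n, count=10):
    """Recover the ternary digits from the packed integer."""
    out = []
    for _ in range(count):
        out.append(n % 3 - 1)
        n //= 3
    return out[::-1]

w = [1, -1, 0, 0, 1, -1, 1, 0, -1, 1]
assert pack(w) < 2**16 and unpack(pack(w)) == w
```
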
Intuitively we could fit <code>floor(16/1.58) = 10</code> properly encoded ternary digits in a 16-bit encoding, but in practice it's probably easier to just use a 2-bit encoding, waste one value, and take advantage of the highly optimized GPU hardware that is good at whole numbers of bits.</div></details><p>I dug into the details of this paper, and I'm hoping to help someone who is curious about the paper understand some of its nuances without digging through the references and reading up on all of the sub-topics.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">In order to understand the paper, you really need to also read the authors' previous paper, <a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a> since the authors reference that paper both explicitly and implicitly. Even some of the high level conclusions, such as discussion of scaling laws, are only briefly mentioned, and not even explicitly referenced, although the BitNet paper treats this topic in much greater depth. This paper is a short pre-print, and doesn't yet stand on its own.</div></div><p>In this post, I've offered some questions and criticisms of the paper, but please keep in mind that this is a pre-print about a great result, and so we should be reading it that way. There will, of course, be issues with a pre-print like this, and I'm overall very excited about the work the authors have done!</p><h3 id="heading-the-punchline">The punchline</h3><p>At a very high level, the proposed architecture uses ternary weights during inference, which allows a dramatic increase in performance. Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary (-1, 0, 1) as opposed to full 16-bit FP weights.
Here are plots from the paper:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712198470217/84fb50ed-5338-4423-af34-428a71d367d9.png" alt="Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary as opposed to full 16-bit FP weights" class="image--center mx-auto" /></p><p>The punchline, however, is that the paper shows evidence that if we make a few tweaks to how we train ternary LLMs, <em>they have the same performance as LLMs trained with 16-bit weights</em>. The implication of this (if it is accurate) is that we can dramatically improve the performance of inference (but not training) without sacrificing quality.</p><p>This result is surprising (to me). Intuitively, by changing from a floating point representation to a ternary representation for the weights, we are dramatically reducing the amount of internal state that is available for the network to learn a mapping from input to output. On the other hand, researchers have spent a lot of effort investigating sparsity, with <a target="_blank" href="https://arxiv.org/pdf/2102.00554.pdf">lots of results</a> indicating that the active part of a neural network is often much smaller than the entire weight space (e.g. a lot of the weights are inactive, even in latent space). Maybe we pay for using fewer bits with denser networks? Or maybe higher precision in the weights is just unimportant because we add up a bunch of high precision activations, so we don't really need precise coefficients to blend the features from the previous layers?</p><p>Before we discuss the why, though, let's talk about the what and the how.</p><h2 id="heading-the-key-algorithmic-contributions">The key algorithmic contributions</h2><p>The concept of a binary weight neural network is not new at all. I'm not an expert, so I won't point you to a particular survey, but <a target="_blank" href="https://www.google.com/search?q=binary+neuaral+network">papers abound</a>.
Ternary weights have also been studied a lot because they allow (as the authors note in their introduction) the weights to mask features from previous layers by using the 0 values for some weights.</p><p>Regardless, the authors made some important decisions in the design that seem to have contributed to the model's effectiveness. Let's take a close look at how it works, particularly how it differs from regular, floating-point-weight models.</p><p>Most of the details about the model itself actually come from their previous paper, called <a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a>. The authors follow the same architecture in both papers, with the primary difference being the use of ternary vs. binary weights.</p><p>Here is a block diagram from the BitNet paper that explains the modified Transformer architecture at a high level.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712199049619/a30e81b0-1fda-425a-87bf-8af4fdf71212.png" alt class="image--center mx-auto" /></p><p>In other words, they have replaced the standard linear blocks with the BitLinear block from the paper. Let's take a look at the way that BitLinear works in the context of the forward and backward passes.</p><h3 id="heading-the-forward-pass">The forward pass</h3><p><img src="https://lh7-us.googleusercontent.com/0M5WyGxzsAEVC61bOHeDp1uwt-SXcU3zWYLpOupin07Py0D6lXc2hDsEObD8OaU4facqfHrdY693TTWu5qAYngqtLtSFVH9vJcTAlFWUxiqg4dyRZseg2sivsQkcw1ooL8mxva-qltvbPpKP4a2zatw" alt /></p><p>I reproduced this diagram from the BitNet paper. It shows the BitLinear block (see the previous diagram for context). I added the various formulas from the BitNet paper and some of the references from <em>that</em> paper. I also unified the notational differences between the various papers.</p><p>The first important architectural decision the authors made is to only use ternary (or binary) weights; i.e.
don't use binary/ternary activations. The consequence of this decision is that information propagates through the network at higher fidelity (actually at 8-bit precision), and <em>only the MLP transform layers (including attention layers) are ternary</em>.</p><p>In other words, we can add, subtract or ignore features from the previous layers, but not change their "strength". For example, the \(i\)th activation at layer \(L\) might look like this</p><p>$$a^{(L)}_i= a^{(L-1)}_1-a^{(L-1)}_2 +a^{(L-1)}_4 + \cdots$$</p><p>instead of something like</p><p>$$a^{(L)}_i= w^{(L)}_{1,i} a^{(L-1)}_{1} + w^{(L)}_{2,i} a^{(L-1)}_{2} + \cdots$$</p><p>where \(w_{j,i}\) is the weight between \(a_j^{(L-1)}\) and \(a_i^{(L)}\).</p><p>Here is a diagram of attention weights from an <a target="_blank" href="https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1">article</a> that describes the <a target="_blank" href="https://github.com/jessevig/bertviz?tab=readme-ov-file#model-view">bertviz</a> tool. You can play with great interactive visualizations in the <a target="_blank" href="https://colab.research.google.com/drive/1hXIQ77A4TYS4y3UthWF-Ci7V7vVUoxmQ?usp=sharing#scrollTo=TG-dQt3NOlub">notebook</a> they link. (Note: these are not the weights from the BitNet models. I would love to follow up with a visualization like this of the weights from the actual models)</p><p><img src="https://miro.medium.com/v2/resize:fit:1400/1*Ak1_htrg0jctCVEqeMgaxQ.png" alt /></p><p>Notice that many of the plots are somewhat sparse, and that the intensities tend to be either on or off. This is probably partly due to the way the weights are visualized, but what stands out for me is that <em>some are on, some are off, with not a lot in between</em> for many weights. Perhaps this visual intuition goes some way to explaining why ternary weights are good enough?
Maybe we just need to activate or not activate upstream features, and not be too concerned with <em>how much</em> we activate them (although the sign of the activation seems to matter too)?</p><p>A recent paper called <a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a> provides some more evidence about this. The paper shows that common Transformer architectures have a total information capacity of about 2 bits per parameter. They also show that this holds even in modified models that don't have any MLP (linear) layers, which suggests that information can move into other parts of the model during training. <strong>This raises the question of whether a ternary-weight model will reach its information capacity sooner than a model with 2 or more bits per parameter, since 1.58 < 2 (to state the obvious).</strong> Thanks to <a target="_blank" href="https://twitter.com/picocreator">@picocreator</a> for pointing me to this paper!</p><h4 id="heading-other-details">Other details</h4><p>The LayerNorm layer is designed to preserve variance from the input in the output. The BitNet paper refers to <a target="_blank" href="https://www.semanticscholar.org/reader/97fb4e3d45bb098e27e0071448b6152217bd35a5">Layer Normalization</a>, which explores how this normalization improves model performance during training.</p><p>The quantization itself works basically as you would expect, and the formulas are in the diagram above, again from the BitNet paper. The one important thing to note is that during quantization, the weights and activations get scaled to the total range of the quantization (\(\gamma_w\) and \(\gamma_x\) respectively). We use these values to re-scale the outputs to the same scale as the inputs.
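</p><p>The quantize/re-scale round trip can be sketched in a few lines of numpy. This is a sketch, not the authors' code: the absmean ternary rule for the weights is the one given in the BitNet b1.58 paper, the absmax 8-bit rule for the activations follows BitNet, the LayerNorm step is omitted, and the ε handling is my assumption:</p>

```python
import numpy as np

def quantize_weights_ternary(W, eps=1e-6):
    """Absmean quantization of weights to {-1, 0, 1}."""
    gamma_w = np.abs(W).mean()                        # scale = mean |W|
    W_q = np.clip(np.round(W / (gamma_w + eps)), -1, 1)
    return W_q, gamma_w

def quantize_activations_int8(x, eps=1e-6, Q=128):
    """Absmax quantization of activations to 8-bit integers."""
    gamma_x = np.abs(x).max()
    x_q = np.clip(np.round(x * Q / (gamma_x + eps)), -Q, Q - 1)
    return x_q, gamma_x

def bitlinear_forward(x, W, Q=128):
    """y ~= x @ W.T computed with quantized values, then re-scaled to the input scale."""
    W_q, gamma_w = quantize_weights_ternary(W)
    x_q, gamma_x = quantize_activations_int8(x, Q=Q)
    return (x_q @ W_q.T) * (gamma_x / Q) * gamma_w    # undo both scalings
```

<p>The matmul in the middle touches only ternary weights and int8 activations, which is where the claimed efficiency comes from; the scales \(\gamma_w\) and \(\gamma_x\) bring the result back to the input's scale.</p><p>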
This re-scaling reportedly improves stability during training.</p><p>Additionally, because the quantization requires calculating \(\gamma_x\) and LayerNorm requires calculating the \(\mathbb{E}[x]\) and \(\sigma_x\), the naive implementation would require extra inter-GPU communication in order to split the model across GPUs using model parallelism. The BitNet paper introduced "group quantization," which basically just uses the approximate \(\gamma_x\), \(\mathbb{E}[x]\) and \(\sigma_x\) values calculated over the activations local to the GPU. They found that it has minimal impact on performance.</p><h3 id="heading-the-backwards-pass">The backwards pass</h3><p>One of the most interesting things about the BitNet b1.58 models is that they actually use full FP16 precision weights during the backwards pass in order to ensure that we have a reasonable gradient for our optimizer. If we used ternary weights in our optimizer, most optimization steps would not have any gradient for most weights, which would make it difficult to figure out how to change our weights in order to improve our loss.</p><p>The consequence of this decision is that we are keeping <em>both</em> a full precision and a ternary copy of the weights during training. This makes training <em>less</em> memory efficient than in a standard Transformer. Since large models are bottlenecked on memory bandwidth, this has a significant impact on training performance.</p><p>I would have liked to see the authors explore this tradeoff more. We'll get to that in the Scaling Laws section below.</p><p>The authors also claim in their <a target="_blank" href="https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf">supplemental notes</a>, that "1.58 bit models can save a substantial amount memory and communication bandwidth for distributed training," but don't elaborate much on why.
I don't understand their argument here, since they presumably would need to transfer the full precision weights to other nodes.</p><p>Finally, the authors mention briefly that they use a "straight-through" estimator of the weight gradients, which means that they just ignore the quantization when calculating the gradient. This is because the quantization makes the function non-differentiable. Previous research they cite has found that the straight-through estimator does about as well as more complex methods, and is trivial to implement.</p><h3 id="heading-post-training-quantization">Post training quantization</h3><p>In the original BitNet paper, the authors compared 1-bit quantization using BitNet to various methods of post-training quantization. The key difference between BitNet / BitNet b1.58 and post-training quantization is that the forward pass is done using the quantized binary / ternary weights, so that the weights and activations are trained to be accurate in the quantized state-space.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712770676255/111ac265-65f8-41dc-ad18-f08df45c96e7.png" alt class="image--center mx-auto" /></p><p>The blue line shows the accuracy of a FP16 transformer trained using the same training regime as the 1-bit BitNet across various sizes. Three post-training quantizations on the FP16 transformer are also shown.
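</p><p>The straight-through estimator mentioned above is easy to see in a toy scalar example (a sketch, not the authors' implementation). The true derivative of the round-and-clip quantizer is zero almost everywhere, so we pretend the quantizer is the identity on the backward pass:</p>

```python
def quantize(w, gamma=1.0):
    """Round-and-clip a weight to the ternary values {-1, 0, 1}."""
    return max(-1, min(1, round(w / gamma)))

# Toy model: y_hat = quantize(w) * x, with squared-error loss (y_hat - y)**2.
w, x, y = 0.4, 2.0, 1.0
y_hat = quantize(w) * x
assert y_hat == 0.0     # 0.4 rounds to 0, so the quantized weight is 0

# The true gradient dL/dw through round() is 0 almost everywhere, so the
# weight could never move.  The straight-through estimator instead treats
# quantize() as the identity on the backward pass: dL/dw ~= dL/dy_hat * x.
grad_ste = 2 * (y_hat - y) * x
assert grad_ste == -4.0  # a usable learning signal
```

<p>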
The authors do not explain exactly how they measure accuracy, but presumably it is the average of the benchmarks they ran on these models.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712770645739/898329b1-5315-495e-9596-6ae0ce06035d.png" alt class="image--center mx-auto" /></p><p>This table shows perplexity and accuracy on other data sets:</p><ul><li><p>Winograd - a benchmark of questions that measure the ability to reason about complex statements</p></li><li><p>Winogrande - a larger and more difficult version of Winograd</p></li><li><p>Storycloze - choose the correct ending of a four sentence story</p></li><li><p>Hellaswag - choose the correct ending of a story, curated by adversarial filtering to identify examples that are hard for NLP algorithms</p></li></ul><p>Similar to the accuracies in the plot above, BitNet performs slightly more poorly than the FP16 transformer, whereas the post-quantization methods degrade significantly as we remove bits. Note that at 1 bit, Absmax and SmoothQuant both perform about the same as random chance.</p><p>I think the authors really should have repeated this comparison more explicitly for the BitNet b1.58 paper. The sizes of the models in the two papers are not comparable except for the 3B model, so it's hard to compare the results.</p><h1 id="heading-model-inference-quality">Model inference quality</h1><p>As noted in the introduction, the model performance is the surprising part of this paper. The authors trained a full FP16 LLaMA-architecture LLM (e.g. the same architecture without the quantized weights and activations) to compete with their BitNet b1.58 version at three sizes: 700M, 1.3B, and 3B parameters. They train for 100B tokens using the RedPajama data set. They also trained a BitNet b1.58 model at 3.9B parameters.
Presumably they chose this size in order to illustrate that a small delta in parameter count allows them to beat the 3B model on all of the benchmarks they ran while requiring significantly less memory and lower inference latency.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712446241880/be73b182-31ff-44db-8f85-aa53ac8931ae.png" alt class="image--center mx-auto" /></p><p>This table comes from the BitNet b1.58 paper. The perplexity for the BitNet b1.58 3B and 3.9B models is slightly lower than for the 3B LLaMA model. This is certainly interesting and surprising because it is not the case for the smaller models. Since perplexity is a measure of the confidence of the model, if it were smaller for all of the BitNet b1.58 models then we might guess that it is simply overconfident due to some artifact of the quantization and/or training process.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712448211545/96739f0c-8d85-40da-ac77-381e0e716058.png" alt class="image--center mx-auto" /></p><p>These plots come from the supplemental notes. The drop in loss at 50B tokens is due to a change in the learning rate and weight decay schedules at that time. As the caption notes, the gap between the training losses of the models shrinks as the size of the models gets larger. This suggests that the lower perplexity is not due to overconfidence. On the other hand, the authors did not show the validation loss, so perhaps the lower loss is due to overfitting?</p><p><img src="https://lh7-us.googleusercontent.com/YWznfcK5wcanQ3Z2twEKT6lvd5syO5-ZREYEqgMeoKquitI6PI_JTJvcLnarRfPZW3mRngpkoW6Q8f9tG-Psjg7b51fl00RI8VdSF6d3DSOzdOqAKr2AiU6CmUO_AtN7DobglMta4DRziS7okYjJu-8" alt /></p><p>This table, from the BitNet b1.58 paper, shows model performance across a range of inference quality benchmarks. The average performance of the 3B model is higher than the FP16 model, although this is probably in the noise.
In fact, the reproduced results (shown below) show the average results being very slightly worse for the 3B model. I don't think this impacts the result: the inference quality of the model is very close to the full precision model, especially considering that LLM benchmarks are notorious for not reflecting real inference quality anyway.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I have highlighted a few entries in the table because they do not follow the standard convention that the best comparable result in a column is bold. Instead, the authors seem to have simply highlighted the results from their own models. This is confusing, and distracts from the result. I hope the authors fix this as they flesh out their final publication.</div></div><h4 id="heading-other-details-1">Other details</h4><p>As we would expect, the memory usage and latency are better with the BitNet b1.58 models. Note that the paper says "The 2-bit kernel from Ladder is also integrated for BitNet b1.58," which references a poster at a conference. I was unable to find any other information on Ladder, so we have to presume that they are using a kernel that can represent ternary integers as 2-bits and efficiently load and operate on them. I hope the authors publish the Ladder 2-bit kernel in their final publication.</p><h3 id="heading-reproduced-results">Reproduced results</h3><p>These results sort of <a target="_blank" href="https://huggingface.co/papers/2402.17764#65df84e81aaeb2f4aca3a587">seem too good to be true</a>. How can we get the same performance as a full precision model with such low precision weights? Maybe the researchers just made a mistake like leaking validation data?</p><p>Fortunately, someone (<a target="_blank" href="https://nousresearch.com/">Nous Research</a>?)
was able to replicate the core results and published the <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">models and a summary of their findings</a>.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712460746109/051720c1-bca1-4fa8-9826-0e3d62e57330.png" alt class="image--center mx-auto" /></p><p>The replicated results are very close to the reported results in almost every case, which is very promising. This rules out basic mistakes such as data leaks. It does not necessarily rule out methodological or fundamental epistemological errors, but the people who replicated the results were presumably thinking about these as well, and would presumably have reported any flaws they discovered.</p><h3 id="heading-2t-token-model">2T token model</h3><p>The authors also trained a 2T token model in order to validate scalability, by which I think they mean to validate that a model with ternary weights does not stop improving earlier than a FP16-weight model. They followed the same data recipe as <a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T</a>, which basically repeated a 1T token data set 4x. They seem to have used the data from the continuous validation dynamics section of the technical report to compare the halfway point in the StableLM-3B-4E1T training (i.e. at 2T tokens) to their own results.</p><p>While this seems like a reasonable comparison, this is not at all clear in the text of the paper, and should be made explicit. It does raise some minor concerns, since the mid-training results lack a cooldown period, and the "validation dynamics across our hold-out set" of the StableLM-3B-4E1T training run are unclear.
There may be other problems with this methodology that someone more experienced would notice as well.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712464077249/16b78f5a-226b-4acc-9432-25492e424d4b.png" alt class="image--center mx-auto" /></p><p>In any event, the above table shows that BitNet b1.58 3B performs comparably to StableLM-3B after 2T tokens of training. I don't think we can presume that BitNet b1.58 3B would actually perform better than StableLM-3B after 2T tokens of training if the Stability AI team had planned to stop training after 2T tokens, but the result does seem to suggest that BitNet b1.58 3B does indeed continue to improve in a similar way to a FP16-weight model.</p><p>Again, the authors should be more explicit about their methodology, since the ambiguity raises uncertainty about the validity of their results.</p><h2 id="heading-inference-performance">Inference performance</h2><p>The validation of on-par performance of the ternary-weight models goes hand in hand with the dramatic improvement in inference latency and memory consumption. These two results together, if correct, represent a significant advance in the cost/accuracy Pareto front.
Particularly interesting is the observation that the improvement over a FP16 model increases as the size of the model increases (see the plot of inference latency and memory consumption vs. model size above), since the Linear layers increasingly dominate the inference cost.</p><p>The paper also discusses energy consumption since this is a better metric of computational cost than FLOPs for this study, considering that the BitNet b1.58 models get some of their advantage by eliminating the need for most of the FLOPs in the weight multiplications.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712468527640/b936069d-66cd-483c-b1f0-5897c22015d8.png" alt class="image--center mx-auto" /></p><p>While these numbers are impressive, they mostly just reiterate, in combination with the latency numbers above, that the bottleneck for inference is not compute.</p><h3 id="heading-new-hardware">New hardware?</h3><p>The abstract and conclusion mention new hardware. As we discussed above, it seems likely that the most significant improvements would come from improvements in memory utilization by using special circuitry to pack, unpack and operate on ternary weights. Since an efficient encoding would enable 10 weights per 16 bits, whereas a 2-bit representation would only allow 8 weights per 16 bits, the improvement would be limited to 25%.</p><p>Inasmuch as memory rather than compute is the bottleneck for LLM inference, this does not seem worth the enormous expense of specialized ternary circuitry, especially considering that it would limit the chip's usefulness for non-LLM applications.</p><h1 id="heading-scaling-laws">Scaling laws</h1><p>Following from the discussion in the inference performance section above, the improvement in inference cost will be driven by memory utilization, or perhaps total latency. 
Since previous scaling law papers have used FLOPs as an approximate proxy for cost, we can do a back-of-the-envelope approximation of the impact of BitNet b1.58 on cost by just using the latency or memory deltas to adjust the number of inference "FLOPs" in the scaling law.</p><p>For the unfamiliar, the <a target="_blank" href="https://arxiv.org/pdf/2203.15556.pdf">Chinchilla paper</a> found that the cost-optimal training regime, without consideration for inference costs, is about 20 training tokens per model parameter.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712516518183/980fc5dd-cc36-4ae4-85e8-72f6b0c212bd.png" alt class="image--center mx-auto" /></p><p>The above table, reproduced from the Chinchilla paper, shows the number of training tokens for various model sizes needed to optimize training costs.</p><p>Since the release of the Chinchilla paper, attention has shifted towards inference cost as companies have put models into production.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712516801300/c05b8fb7-5906-4a51-b115-68c13658f92a.png" alt class="image--center mx-auto" /></p><p>The above equation, from <a target="_blank" href="https://arxiv.org/pdf/2401.00448.pdf">Beyond Chinchilla-Optimal: Accounting for Inference in Language Model Scaling Laws</a>, uses the common rule of thumb that Transformer-architecture LLMs need \(6N\) FLOPs per training token and \(2N\) FLOPs per inference token. In this equation, the \(N\) values represent the number of model parameters, and the \(D\) values represent the number of tokens.</p><p>Note that if we don't do any inference with our model, then this equation reduces to the Chinchilla-optimal cost. 
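<p>To make the accounting concrete, here is a toy sketch (my own illustration; the function name and example numbers are assumptions, not from either paper) of the cost formula with an efficiency factor applied to the inference term:</p>

```python
# Hedged sketch of the "Beyond Chinchilla-Optimal" accounting:
# total cost ~ 6*N*D_train + 2*N*D_inf, with the inference term
# discounted by an efficiency factor to mimic cheaper BitNet inference.
def total_cost(n_params, d_train, d_inf, inf_speedup=1.0):
    train_term = 6 * n_params * d_train            # training "FLOPs"
    inf_term = 2 * n_params * d_inf / inf_speedup  # discounted inference "FLOPs"
    return train_term + inf_term

n = 70e9  # a 70B-parameter model
# With no inference tokens, the formula reduces to the Chinchilla training cost:
assert total_cost(n, 1.4e12, 0) == 6 * n * 1.4e12
# Discounting inference by the 70B latency factor (4.1x) lowers the total:
assert total_cost(n, 1.4e12, 1e13, inf_speedup=4.1) < total_cost(n, 1.4e12, 1e13)
```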
So, intuitively, we are just scaling the second term as our inference becomes more efficient, and getting closer to the Chinchilla-optimal token/parameter ratio.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712198470217/84fb50ed-5338-4423-af34-428a71d367d9.png" alt="Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary as opposed to full 16-bit FP weights" class="image--center mx-auto" /></p><p>So for our back-of-the-envelope calculation, we can adjust the second term in the equation above by our memory or latency factor from the very first plot in this post, reproduced above for clarity.</p><p>There's a lot to keep in our heads here, so as an example, assuming a 70B model, we would divide the inference tokens term in the equation above by 4.1 if we assume cost scales with latency, or 7.16 if we assume that cost scales with memory utilization.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712773527558/984b29cf-d9da-44dc-a0d3-f88f5a372433.jpeg" alt class="image--center mx-auto" /></p><p>The above plot, also from the inference scaling law paper, shows the adjustment factor to apply to the Chinchilla-optimal model configuration for a given target training loss and a given number of inference tokens.</p><p>Again, this is a lot to keep in our heads, so let's continue with our example. Since training loss is highly dependent on all sorts of factors, we'll treat it as "funny money" and just pick something that seems reasonable, like 2. This seems reasonable based on the loss curves from the <a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T Technical Report</a>.</p><p>Now, let's take a look at the contour plot (the ratio of cost-optimal pre-training tokens to Chinchilla optimal tokens). 
If we look at the 2.0 training loss line, we can look up some total number of inference tokens that we expect to pay for. Let's call it \(10^{13}\) = 10T. Now if we adjust the number of inference tokens by the latency reduction factor of 4.10, then we need to look \(\log_{10}(4.10)\approx0.6\) ticks on the graph to the left of the \(10^{13}\) line (green arrow). This puts us right around the 2x contour. So this tells us that we should use about 2x the Chinchilla-optimal number of training tokens, which is about 1400B x 2 = 2800B training tokens.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">In the original post, I missed a zero, and said 140B x 2 = 280B. Ooops!</div></div><p>We can do a similar exercise if we have a given number of training tokens and need to know how to pick a cost-optimal model size at a given loss (see the paper for details).</p><p>This is pretty awesome! In essence, we need fewer training tokens to get to cost optimal for a known number of inference tokens. Of course, the number of inference tokens is hard to project for a proprietary model, let alone an open source model. This scaling law will also run out of steam in the limits of model sizes and token counts, but it's not clear from the paper where those limits are. But it is certainly useful to develop an intuition for the impact of the BitNet b1.58 models on scaling.</p><h1 id="heading-tying-it-up">Tying it up</h1><p>Someone trained a <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">set of models</a> using the recipe from the authors. These models have the full FP16 weights, which is great because we can use them to study the relationship between the full precision weights and the quantized ones. 
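<p>Getting the ternary weights back out of the FP16 release is mechanical; here is a sketch of the absmean quantization rule from the papers (the exact epsilon handling here is my assumption):</p>

```python
import numpy as np

def absmean_ternary(w, eps=1e-5):
    # Scale by gamma = mean |W|, then round and clip to {-1, 0, 1},
    # following the BitNet b1.58 weight quantization rule.
    gamma = np.abs(w).mean()
    w_q = np.clip(np.round(w / (gamma + eps)), -1, 1)
    return w_q.astype(np.int8), gamma

w = np.array([[0.9, -0.05, 0.4], [-1.2, 0.02, 0.7]], dtype=np.float32)
w_q, gamma = absmean_ternary(w)
# w_q now holds only values in {-1, 0, 1}; gamma is kept to re-scale outputs
```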
It does mean, however, that if we want to get the performance benefits of the quantization, we have to actually quantize the FP16 weights down to ternary and update the inference code to take advantage of it.</p><p>Fortunately, HuggingFace user <a target="_blank" href="https://huggingface.co/kousw">kousw</a> has <a target="_blank" href="https://huggingface.co/kousw/bitnet_b1_58-3B_quantized">quantized the 3B model</a> and written the code to do inference on it! I have a demo notebook here: <a target="_blank" href="https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9?usp=sharing">https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9</a></p><p>In any event, this paper does a great job of telling us about an exciting result, but leaves a lot to the reader to dig up for themselves. This is fine for a pre-print; I just hope they will flesh out the paper a bit more before publishing it. I would have liked to see</p><ul><li><p>a more direct comparison of the quantization results</p></li><li><p>more explicit exposition of the scaling law claims</p></li><li><p>a better explanation of their arguments about hardware</p></li></ul><p>Perhaps more importantly, we got a lot of what, a bit of how, and very little why.</p><ul><li><p>Why does quantization-aware training work better than post-training quantization? They have a hypothesis in their supplemental notes ("<em>This is because PTQ can only quantize model weights around the full-precision values, which may not necessarily be where the global minima of the low-precision landscape lie</em>.") The authors should design an experiment to support their hypothesis, or cite research on it more clearly.</p></li><li><p>Why does ternary work just as well as FP16? Why not binary?</p><ul><li><p>Since higher precision activations seem to be important, does information move from the weights into other places like biases or other parameters? 
We could test this by measuring the entropy in the weights, biases and activations in an unquantized vs. quantized weight training regime.</p></li><li><p>For ternary vs. binary, can we correlate errors to the inability to mask features?</p></li></ul></li><li><p>Why does ternary work <em>better</em> than FP16?</p><ul><li>Is this a kind of regularization?</li></ul></li><li><p>Does using a ternary-weight model have a lower information capacity than models with more bits per weight? The authors state in the supplemental notes, regarding using all 4 values of the 2-bit representation: "While this is a reasonable approach due to the availability of shift operations for -2 and 2 to eliminate multiplication, ternary quantization already achieves comparable performance to full-precision models. Following Occam's Razor principle, we refrained from introducing more values beyond {-1, 0, 1}." Perhaps, given the results from <a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a>, this decision is premature?</p></li></ul><p>They have hypotheses, but don't explore them, and so this reads as a (very interesting) technical report, and not a research paper. That's fine for a pre-print, but I would love to know more about why, and not just what.</p><p>In conclusion, what a great result! 
I look forward to hearing more from the authors, hopefully some follow-up on the things I mentioned above, and of course, their models and code!</p><h1 id="heading-resources">Resources</h1><ul><li><p><a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> - the paper itself</p><ul><li><p><a target="_blank" href="https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf">Supplemental notes</a> - details about how to replicate their results, and answers to some of the questions they got as a result of the pre-print</p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a> - the preceding paper. You probably need to read this paper as well if you want to understand this one</p></li><li><p><a target="_blank" href="https://www.semanticscholar.org/reader/97fb4e3d45bb098e27e0071448b6152217bd35a5">Layer Normalization</a> - the source of their LayerNorm stage</p></li><li><p><a target="_blank" href="https://huggingface.co/papers/2402.17764">The paper page on HuggingFace with some great discussion about the results and models</a></p></li></ul></li><li><p>The <a target="_blank" href="https://huggingface.co/1bitLLM">replicated models</a> on HuggingFace, including some replicated evals</p><ul><li><p>The <a target="_blank" href="https://huggingface.co/kousw/bitnet_b1_58-3B_quantized">quantized 3B model</a> for faster inference</p></li><li><p>My <a target="_blank" href="https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9?usp=sharing">notebook</a> to fiddle around with the quantized model</p></li><li><p><a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T Technical Report</a> - the presumptive source of the 
2T token comparison that they make to argue that they continue to scale as the number of training tokens increases</p></li></ul></li><li><p>Scaling laws</p><ul><li><p><a target="_blank" href="https://arxiv.org/pdf/2203.15556.pdf">Chinchilla paper</a></p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2401.00448.pdf">Beyond Chinchilla-Optimal: Accounting for Inference in Language Model Scaling Laws</a></p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a> - a recent analysis of the knowledge capacity of transformer-based models</p></li></ul></li><li><p>Meta stuff</p><ul><li><p><a target="_blank" href="https://www.swyx.io/learn-in-public">learning in public</a> - a great post by <a target="_blank" href="https://twitter.com/swyx">@swyx</a> about how and why to share your learning journey with others</p></li><li><p><a target="_blank" href="https://www.swyx.io/learn-in-public">learning exhaust</a> - a great post by <a target="_blank" href="https://twitter.com/swyx">@swyx</a> about how to improve your learning by taking advantage of the artifacts you produce while learning</p></li></ul></li></ul>]]><![CDATA[<h2 id="heading-introduction">Introduction</h2><p>This post is my <a target="_blank" href="https://www.swyx.io/learn-in-public">learning exhaust</a> from reading an exciting pre-print paper titled <a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> about very efficient representations of high-performing LLMs. I am trying to come up to speed on deep learning after having been heads down in telecommunications systems and data at a startup since leaving research 13 years ago (!!!). 
This was a really fun way to take a deep dive, so I'm <a target="_blank" href="https://www.swyx.io/learn-in-public">learning in public</a>; this post is the guide to the paper I wish I had before I read it :)</p><details><summary>1.58 bits?</summary><div data-type="detailsContent">1.58 bits is the number of bits needed to represent a base 3 (e.g. -1, 0, 1) digit. We can compute this with <code>log2(3) ~= 1.58</code> (sorry for the bad notation). It's weird to have a fractional number of bits. Intuitively we could fit <code>floor(16/1.58) = 10</code> properly encoded ternary digits in a 16-bit encoding, but in practice it's probably easier to just use a 2-bit encoding, waste 1 value, and take advantage of the highly optimized GPU hardware that is good at whole numbers of bits.</div></details><p>I dug into the details of this paper, and I'm hoping to help someone who is curious about the paper understand some of its nuances without digging through the references and reading up on all of the sub-topics.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">In order to understand the paper, you really need to also read the authors' previous paper, <a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a>, since the authors reference that paper both explicitly and implicitly. Even some of the high level conclusions, such as discussion of scaling laws, are only briefly mentioned, and not even explicitly referenced, although the BitNet paper treats this topic in much greater depth. This paper is a short pre-print, and doesn't yet stand on its own.</div></div><p>In this post, I've offered some questions and criticisms of the paper, but please keep in mind that this is a pre-print about a great result, and so we should be reading it that way. 
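<p>As a quick sanity check on the arithmetic in the aside above (plain Python, my own illustration):</p>

```python
import math

bits_per_trit = math.log2(3)                        # ~1.585 bits per ternary digit
trits_per_16_bits = math.floor(16 / bits_per_trit)  # 10 trits fit in 16 bits

# Equivalent check without logarithms: the 3^10 distinct ternary strings
# must fit into the 2^16 values a 16-bit word can represent...
assert 3 ** trits_per_16_bits <= 2 ** 16            # 59049 <= 65536
# ...and an 11th trit would overflow the word:
assert 3 ** (trits_per_16_bits + 1) > 2 ** 16
```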
There will, of course, be issues with a pre-print like this, and I'm overall very excited about the work the authors have done!</p><h3 id="heading-the-punchline">The punchline</h3><p>At a very high level, the proposed architecture uses ternary weights during inference, which allows a dramatic increase in performance. Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary (-1, 0, 1) as opposed to full 16-bit FP weights. Here are plots from the paper:</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712198470217/84fb50ed-5338-4423-af34-428a71d367d9.png" alt="Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary as opposed to full 16-bit FP weights" class="image--center mx-auto" /></p><p>The punchline, however, is that the paper shows evidence that if we make a few tweaks to how we train ternary LLMs, <em>they have the same performance as LLMs trained with 16-bit weights</em>. The implication of this (if it is accurate) is that we can dramatically improve the performance of inference (but not training) without sacrificing quality.</p><p>This result is surprising (to me). Intuitively, by changing from a floating point representation to a ternary representation for the weights, we are dramatically reducing the amount of internal state that is available for the network to learn a mapping from input to output. On the other hand, researchers have spent a lot of effort investigating sparsity, with <a target="_blank" href="https://arxiv.org/pdf/2102.00554.pdf">lots of results</a> indicating that the active part of a neural network is often much smaller than the entire weight space (e.g. a lot of the weights are inactive, even in latent space). Maybe we pay for using fewer bits with denser networks? 
Or maybe higher precision in the weights is just unimportant because we add up a bunch of high precision activations, so we don't really need precise coefficients to blend the features from the previous layers?</p><p>Before we discuss the why, though, let's talk about the what and the how.</p><h2 id="heading-the-key-algorithmic-contributions">The key algorithmic contributions</h2><p>The concept of a binary weight neural network is not new at all. I'm not an expert, so I won't point you to a particular survey, but <a target="_blank" href="https://www.google.com/search?q=binary+neuaral+network">papers abound</a>. Ternary weights have also been studied a lot because they allow (as the authors note in their introduction) the weights to mask features from previous layers by using the 0 values for some weights.</p><p>Regardless, the authors made some important decisions in the design that seem to have contributed to the model's effectiveness. Let's take a close look at how it works, particularly how it differs from regular, floating-point-weight models.</p><p>Most of the details about the model itself actually come from their previous paper, called <a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a>. The authors follow the same architecture in both papers, with the primary difference being the use of ternary vs. binary weights.</p><p>Here is a block diagram from the BitNet paper that explains the modified Transformer architecture at a high level.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712199049619/a30e81b0-1fda-425a-87bf-8af4fdf71212.png" alt class="image--center mx-auto" /></p><p>In other words, they have replaced the standard linear blocks with the BitLinear block from the paper. 
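<p>Jumping ahead a little, the net effect of such a block at inference time can be sketched in numpy. This is a simplification under my own assumptions: the absmean weight rule and absmax 8-bit activation rule from the BitNet papers, a bare-bones LayerNorm without affine parameters, and my own epsilon handling:</p>

```python
import numpy as np

def bitlinear_forward(x, w, qb=127, eps=1e-5):
    # Simplified LayerNorm on the input (no learned scale/shift).
    x = (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)
    # Absmax 8-bit activation quantization: gamma_x = max |x|.
    gamma_x = np.abs(x).max()
    x_q = np.clip(np.round(x * qb / (gamma_x + eps)), -qb, qb)
    # Absmean ternary weight quantization: gamma_w = mean |W|.
    gamma_w = np.abs(w).mean()
    w_q = np.clip(np.round(w / (gamma_w + eps)), -1, 1)
    # The matmul only needs additions/subtractions of 8-bit activations;
    # the scale factors re-express the result in the input's units.
    return (x_q @ w_q.T) * (gamma_w * gamma_x / qb)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))
w = rng.normal(size=(4, 8))
y = bitlinear_forward(x, w)   # shape (2, 4)
```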
Let's take a look at the way that BitLinear works in the context of the forward and backward passes.</p><h3 id="heading-the-forward-pass">The forward pass</h3><p><img src="https://lh7-us.googleusercontent.com/0M5WyGxzsAEVC61bOHeDp1uwt-SXcU3zWYLpOupin07Py0D6lXc2hDsEObD8OaU4facqfHrdY693TTWu5qAYngqtLtSFVH9vJcTAlFWUxiqg4dyRZseg2sivsQkcw1ooL8mxva-qltvbPpKP4a2zatw" alt /></p><p>I reproduced this diagram from the BitNet paper. It shows the BitLinear block (see the previous diagram for context). I added the various formulas from the BitNet paper and some of the references from <em>that</em> paper. I also unified the notational differences between the various papers.</p><p>The first important architectural decision the authors made is to only use ternary (or binary) weights; i.e. don't use binary/ternary activations. The consequence of this decision is that information propagates through the network at higher fidelity (actually at 8-bit precision), and <em>only the MLP transform layers (including attention layers) are ternary</em>.</p><p>In other words, we can add, subtract or ignore features from the previous layers, but not change their "strength". For example, the \(i\)th activation at layer \(L\) might look like this</p><p>$$a^{(L)}_i= a^{(L-1)}_1-a^{(L-1)}_2 +a^{(L-1)}_4 + \cdots$$</p><p>instead of something like</p><p>$$a^{(L)}_i= w^{(L)}_{1,i} a^{(L-1)}_{1} + w^{(L)}_{2,i} a^{(L-1)}_{2} + \cdots$$</p><p>where \(w_{j,i}\) is the weight between \(a_j^{(L-1)}\) and \(a_i^{(L)}\).</p><p>Here is a diagram of attention weights from an <a target="_blank" href="https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1">article</a> that describes the <a target="_blank" href="https://github.com/jessevig/bertviz?tab=readme-ov-file#model-view">bertviz</a> tool. 
You can play with great interactive visualizations in the <a target="_blank" href="https://colab.research.google.com/drive/1hXIQ77A4TYS4y3UthWF-Ci7V7vVUoxmQ?usp=sharing#scrollTo=TG-dQt3NOlub">notebook</a> they link. (Note: these are not the weights from the BitNet models. I would love to follow up with a visualization like this of the weights from the actual models)</p><p><img src="https://miro.medium.com/v2/resize:fit:1400/1*Ak1_htrg0jctCVEqeMgaxQ.png" alt /></p><p>Notice that many of the plots are somewhat sparse, and that the intensities tend to be either on or off. This is probably partly due to the way the weights are visualized, but what stands out for me is that <em>some are on, some are off, with not a lot in between</em> for many weights. Perhaps this visual intuition goes some way to explaining why ternary weights are good enough? Maybe we just need to activate or not activate upstream features, and not be too concerned with <em>how much</em> we activate them (although the sign of the activation seems to matter too)?</p><p>A recent paper called <a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a> provides some more evidence about this. The paper shows that common Transformer architectures have a total information capacity of about 2 bits per parameter. They also show that this holds even in modified models that don't have any MLP (linear) layers, which suggests that information can move into other parts of the model during training. <strong>This raises the question of whether a ternary-weight model will reach its information capacity sooner than a model with 2 or more bits per weight, since 1.58 < 2 (to state the obvious).</strong> Thanks to <a target="_blank" href="https://twitter.com/picocreator">@picocreator</a> for pointing me to this paper!</p><h4 id="heading-other-details">Other details</h4><p>The LayerNorm layer is designed to preserve variance from the input in the output. 
The BitNet paper refers to <a target="_blank" href="https://www.semanticscholar.org/reader/97fb4e3d45bb098e27e0071448b6152217bd35a5">Layer Normalization</a>, which explores how this normalization improves model performance during training.</p><p>The quantization itself works basically as you would expect, and the formulas are in the diagram above, again from the BitNet paper. The one important thing to note is that during quantization, the weights and activations get scaled to the total range of the quantization (\(\gamma_w\) and \(\gamma_x\) respectively). We use these values to re-scale the outputs to the same scale as the inputs. This reportedly improves stability during training.</p><p>Additionally, because the quantization requires calculating \(\gamma_x\) and LayerNorm requires calculating the \(\mathbb{E}[x]\) and \(\sigma_x\), the naive implementation would require extra inter-GPU communication in order to split the model across GPUs using model parallelism. The BitNet paper introduced "group quantization," which basically just uses the approximate \(\gamma_x\), \(\mathbb{E}[x]\) and \(\sigma_x\) values calculated over the activations local to the GPU. They found that it has minimal impact on performance.</p><h3 id="heading-the-backwards-pass">The backwards pass</h3><p>One of the most interesting things about the BitNet b1.58 models is that they actually use full FP16 precision weights during the backwards pass in order to ensure that we have a reasonable gradient for our optimizer. If we used ternary weights in our optimizer, most optimization steps would not have any gradient for most weights, which would make it difficult to figure out how to change our weights in order to improve our loss.</p><p>The consequence of this decision is that we are keeping <em>both</em> a full precision and ternary copy of weights during training. This makes training <em>less</em> memory efficient than in a standard Transformer. 
Since large models are bottlenecked on memory bandwidth, this has a significant impact on training performance.</p><p>I would have liked to see the authors explore this tradeoff more. We'll get to that in the Scaling Laws section below.</p><p>The authors also claim in their <a target="_blank" href="https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf">supplemental notes</a> that "1.58 bit models can save a substantial amount memory and communication bandwidth for distributed training," but don't elaborate much on why. I don't understand their argument here, since they presumably would need to transfer the full precision weights to other nodes.</p><p>Finally, the authors mention briefly that they use a "straight-through" estimator of the weight gradients, which means that they just ignore the quantization when calculating the gradient. This is because the quantization makes the function non-differentiable. Previous research they cite has found that the straight-through estimator does about as well as more complex methods, and is trivial to implement.</p><h3 id="heading-post-training-quantization">Post training quantization</h3><p>In the original BitNet paper, the authors compared 1-bit quantization using BitNet to various methods of post-training quantization. The key difference between BitNet / BitNet b1.58 and post-training quantization is that the forward pass is done using the quantized binary / ternary weights, so that the weights and activations are trained to be accurate in the quantized state-space.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712770676255/111ac265-65f8-41dc-ad18-f08df45c96e7.png" alt class="image--center mx-auto" /></p><p>The blue line shows the accuracy of a FP16 transformer trained using the same training regime as the 1-bit BitNet across various sizes. Three post-training quantizations on the FP16 transformer are also shown. 
The authors do not explain exactly how they measure accuracy, but presumably it is the average of the benchmarks they ran on these models.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712770645739/898329b1-5315-495e-9596-6ae0ce06035d.png" alt class="image--center mx-auto" /></p><p>This table shows perplexity and accuracy on other data sets:</p><ul><li><p>Winograd - a benchmark of questions that measure the ability to reason about complex statements</p></li><li><p>Winogrande - a larger and more difficult version of Winograd</p></li><li><p>Storycloze - choose the correct ending of a four sentence story</p></li><li><p>Hellaswag - choose the correct ending of a story, curated by adversarial filtering to identify examples that are hard for NLP algorithms</p></li></ul><p>Similar to the accuracies in the plot above, BitNet performs slightly worse than the FP16 transformer, whereas the post-quantization methods degrade significantly as we remove bits. Note that at 1 bit, Absmax and SmoothQuant both perform about the same as random chance.</p><p>I think the authors really should have repeated this comparison more explicitly for the BitNet b1.58 paper. The sizes of models in the papers are not comparable except for the 3B model, so it's hard to compare the results.</p><h1 id="heading-model-inference-quality">Model inference quality</h1><p>As noted in the introduction, the model performance is the surprising part of this paper. The authors trained a full FP16 LLaMa-architecture LLM (i.e. the same architecture without the quantized weights and activations) to compete with their BitNet b1.58 version at three sizes: 700M, 1.3B, and 3B parameters. They trained for 100B tokens using the RedPajama data set. They also trained a BitNet b1.58 model at 3.9B parameters. 
Presumably they chose this size in order to illustrate that a small delta in parameter count allows them to beat the 3B model on all of the benchmarks they ran while requiring significantly less memory and lower inference latency.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712446241880/be73b182-31ff-44db-8f85-aa53ac8931ae.png" alt class="image--center mx-auto" /></p><p>This table comes from the BitNet b1.58 paper. The perplexity for the BitNet b1.58 3B and 3.9B models is slightly lower than for the 3B LLaMA model. This is certainly interesting and surprising because it is not the case for smaller models. Since perplexity is a measure of the confidence of the model, if it were smaller for all of the BitNet b1.58 models then we might guess that it is simply overconfident due to some artifact of the quantization and/or training process.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712448211545/96739f0c-8d85-40da-ac77-381e0e716058.png" alt class="image--center mx-auto" /></p><p>These plots come from the supplemental notes. The drop in loss at 50B tokens is due to a change in the learning rate and weight decay schedules at that time. As the caption notes, the gap between the training losses of the models gets smaller as the models get larger. This suggests that the lower perplexity is not due to overconfidence. On the other hand, the authors did not show the validation loss, so perhaps the lower loss is due to overfitting?</p><p><img src="https://lh7-us.googleusercontent.com/YWznfcK5wcanQ3Z2twEKT6lvd5syO5-ZREYEqgMeoKquitI6PI_JTJvcLnarRfPZW3mRngpkoW6Q8f9tG-Psjg7b51fl00RI8VdSF6d3DSOzdOqAKr2AiU6CmUO_AtN7DobglMta4DRziS7okYjJu-8" alt /></p><p>This table, from the BitNet b1.58 paper, shows model performance across a range of inference quality benchmarks. The average performance of the 3B model is higher than the FP16 model, although this is probably in the noise. 
In fact, the reproduced results (shown below) show the average results being very slightly worse for the 3B model. I don't think this impacts the result: the inference quality of the model is very close to the full precision model, especially considering that LLM benchmarks are notorious for not reflecting real inference quality anyway.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">I have highlighted a few entries in the table because they do not follow the standard convention that the best comparable result in a column is bold. Instead, the authors seem to have simply highlighted the results from their own models. This is confusing, and distracts from the result. I hope the authors fix this as they flesh out their final publication.</div></div><h4 id="heading-other-details-1">Other details</h4><p>As we would expect, the memory usage and latency are better with the BitNet b1.58 models. Note that the paper says "The 2-bit kernel from Ladder is also integrated for BitNet b1.58," which references a poster at a conference. I was unable to find any other information on Ladder, so we have to presume that they are using a kernel that can represent ternary integers as 2 bits and efficiently load and operate on them. I hope the authors publish the Ladder 2-bit kernel in their final publication.</p><h3 id="heading-reproduced-results">Reproduced results</h3><p>These results sort of <a target="_blank" href="https://huggingface.co/papers/2402.17764#65df84e81aaeb2f4aca3a587">seem too good to be true</a>. How can we get the same performance as a full precision model with such low precision weights? Maybe the researchers just made a mistake like leaking validation data?</p><p>Fortunately, someone (<a target="_blank" href="https://nousresearch.com/">Nous Research</a>?) 
was able to replicate the core results and published the <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">models and a summary of their findings</a>.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712460746109/051720c1-bca1-4fa8-9826-0e3d62e57330.png" alt class="image--center mx-auto" /></p><p>The replicated results are very close to the reported results in almost every case, which is very promising. This rules out basic mistakes such as data leaks. It does not necessarily rule out methodological or fundamental epistemological errors, but the people who replicated the results were presumably thinking about these as well, and presumably would have reported any flaws they discovered.</p><h3 id="heading-2t-token-model">2T token model</h3><p>The authors also trained a 2T token model in order to validate scalability, by which I think they mean to validate that a model with ternary weights does not stop improving earlier than an FP16-weight model. They followed the same data recipe as <a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T</a>, which basically repeated a 1T token data set 4x. They seem to have used the data from the continuous validation dynamics section of the technical report to compare the halfway point in the StableLM-3B-4E1T training (i.e., at 2T tokens) to their own results.</p><p>While this seems like a reasonable comparison, it is not at all clear in the text of the paper, and should be made explicit. It does raise some minor concerns, since the mid-training results lack a cooldown period, and the "validation dynamics across our hold-out set" of the StableLM-3B-4E1T training run are unclear.
There may be other problems with this methodology that someone more experienced would notice as well.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712464077249/16b78f5a-226b-4acc-9432-25492e424d4b.png" alt class="image--center mx-auto" /></p><p>In any event, the above table shows that BitNet b1.58 3B performs comparably to StableLM-3B after 2T tokens of training. I don't think we can presume that BitNet b1.58 3B would actually have performed better than StableLM-3B if the Stability AI team had planned to stop training after 2T tokens, but the result does seem to suggest that BitNet b1.58 3B does indeed continue to improve in a similar way to an FP16-weight model.</p><p>Again, the authors should be more explicit about their methodology, since the ambiguity raises uncertainty about the validity of their results.</p><h2 id="heading-inference-performance">Inference performance</h2><p>The validation of on-par performance of the ternary weight models goes hand in hand with the dramatic improvement in inference latency and memory consumption. These two results together, if correct, represent a significant advance in the cost/accuracy Pareto front.
Particularly interesting is the observation that the improvement over an FP16 model increases as the size of the model increases (see the plot of inference latency and memory consumption vs model size above), since the Linear layers increasingly dominate the inference cost.</p><p>The paper also discusses energy consumption, since this is a better metric of computational cost than FLOPs for this study, considering that the BitNet b1.58 models get some of their advantage by eliminating the need for most of the FLOPs in the weight multiplications.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712468527640/b936069d-66cd-483c-b1f0-5897c22015d8.png" alt class="image--center mx-auto" /></p><p>While these numbers are impressive, they mostly just reiterate, in combination with the latency numbers above, that the bottleneck for inference is not compute.</p><h3 id="heading-new-hardware">New hardware?</h3><p>The abstract and conclusion mention new hardware. As we discussed above, it seems likely that the most significant improvements would come from improvements in memory utilization by using special circuitry to pack, unpack and operate on ternary weights. Since an efficient encoding would enable 10 weights per 16 bits, whereas a 2-bit representation would only allow 8 weights per 16 bits, the improvement would be limited to 25%.</p><p>Inasmuch as memory rather than compute is the bottleneck for LLM inference, this does not seem worth the enormous expense of specialized ternary circuitry, especially considering that it would limit the chip's usefulness for non-LLM applications.</p><h1 id="heading-scaling-laws">Scaling laws</h1><p>Following from the discussion in the inference performance section above, the improvement in inference cost will be driven by memory utilization, or perhaps total latency.
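</p><p>As an aside, the 10-weights-per-16-bits figure from the hardware discussion above follows from \(3^{10} = 59049 < 2^{16} = 65536\): ten ternary values fit in one 16-bit word under base-3 encoding. A toy sketch of the idea (this is not the unpublished Ladder kernel, just an illustration of the counting argument):</p>

```python
def pack_trits(trits):
    """Pack 10 ternary weights {-1, 0, 1} into one 16-bit word using
    base-3 encoding (16/10 = 1.6 bits per weight, vs. 2 bits per
    weight for a naive 2-bit encoding)."""
    assert len(trits) == 10
    word = 0
    for t in trits:
        word = word * 3 + (t + 1)  # map {-1, 0, 1} -> {0, 1, 2}
    return word

def unpack_trits(word):
    """Invert pack_trits: recover the 10 ternary weights."""
    trits = []
    for _ in range(10):
        trits.append(word % 3 - 1)
        word //= 3
    return trits[::-1]  # digits come out least-significant first

w = [-1, 0, 1, 1, 0, -1, -1, 0, 1, 0]
assert unpack_trits(pack_trits(w)) == w
assert pack_trits([1] * 10) == 3**10 - 1  # max value fits in 16 bits
```

<p>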
Since previous scaling law papers have used FLOPs as an approximate proxy for cost, we can do a back-of-the-envelope approximation of the impact of BitNet b1.58 on cost by just using the latency or memory deltas to adjust the number of inference "FLOPs" in the scaling law.</p><p>For the unfamiliar, the <a target="_blank" href="https://arxiv.org/pdf/2203.15556.pdf">Chinchilla paper</a> found that the cost-optimal training regime, without consideration for inference costs, is about 20 training tokens per model parameter.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712516518183/980fc5dd-cc36-4ae4-85e8-72f6b0c212bd.png" alt class="image--center mx-auto" /></p><p>The above table, reproduced from the Chinchilla paper, shows the number of training tokens for various model sizes needed to optimize training costs.</p><p>Since the release of the Chinchilla paper, attention has shifted towards inference cost as companies have put models into production.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712516801300/c05b8fb7-5906-4a51-b115-68c13658f92a.png" alt class="image--center mx-auto" /></p><p>The above equation, from <a target="_blank" href="https://arxiv.org/pdf/2401.00448.pdf">Beyond Chinchilla-Optimal: Accounting for Inference in Language Model Scaling Laws</a>, uses the common rule of thumb that Transformer-architecture LLMs need \(6N\) FLOPs per training token and \(2N\) FLOPs per inference token. In this equation, the \(N\) values represent the number of model parameters, and the \(D\) values represent the number of tokens.</p><p>Note that if we don't do any inference with our model, then this equation reduces to the Chinchilla-optimal cost.
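</p><p>The cost accounting is simple enough to write down directly. A sketch, using the 6N/2N rule of thumb from the papers above (the 70B configuration is illustrative, not from the BitNet paper):</p>

```python
def total_flops(n_params, train_tokens, inference_tokens):
    """Approximate lifetime compute: ~6N FLOPs per training token
    plus ~2N FLOPs per inference token."""
    return 6 * n_params * train_tokens + 2 * n_params * inference_tokens

# A Chinchilla-style 70B model (20 tokens per parameter), training only:
n = 70e9
print(f"{total_flops(n, 20 * n, 0):.2e}")  # 5.88e+23 FLOPs
```

<p>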
So, intuitively, we are just scaling the second term down as our inference becomes more efficient, and getting closer to the Chinchilla-optimal token/parameter ratio.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712198470217/84fb50ed-5338-4423-af34-428a71d367d9.png" alt="Latency and memory consumption are (not surprisingly) dramatically better for the authors' models with ternary as opposed to full 16-bit FP weights" class="image--center mx-auto" /></p><p>So for our back-of-the-envelope calculation, we can adjust the second term in the equation above by our memory or latency factor from the very first plot in this post, reproduced above for clarity.</p><p>There's a lot to keep in our heads here, so as an example: assuming a 70B model, we would divide the inference tokens term in the equation above by 4.1 if we assume that cost scales with latency, or by 7.16 if we assume that cost scales with memory utilization.</p><p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1712773527558/984b29cf-d9da-44dc-a0d3-f88f5a372433.jpeg" alt class="image--center mx-auto" /></p><p>The above plot, also from the inference scaling law paper, shows the adjustment factor to apply to the Chinchilla-optimal model configuration for a given target training loss and a given number of inference tokens.</p><p>Again, this is a lot to keep in our heads, so let's continue with our example. Since training loss is highly dependent on all sorts of factors, we'll treat it as "funny money" and just pick something that seems reasonable, like 2. This seems reasonable based on the loss curves from the <a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T Technical Report</a>.</p><p>Now, let's take a look at the contour plot (the ratio of cost-optimal pre-training tokens to Chinchilla-optimal tokens).
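</p><p>The whole adjustment fits in a few lines. A sketch, where the 4.1x latency and 7.16x memory factors are read off the 70B row of the paper's latency/memory table and everything else is illustrative:</p>

```python
import math

LATENCY_FACTOR = 4.1   # 70B BitNet b1.58 vs. FP16, from the paper
MEMORY_FACTOR = 7.16

def adjusted_cost_flops(n_params, train_tokens, inference_tokens,
                        efficiency_factor):
    """Scale only the inference term of the cost equation down by
    the measured efficiency improvement, keeping FLOPs as the
    proxy for cost."""
    return (6 * n_params * train_tokens
            + 2 * n_params * inference_tokens / efficiency_factor)

# Dividing the expected inference tokens by 4.1 is a shift of
# log10(4.1) ~ 0.6 decades on the contour plot's log-scale axis.
print(round(math.log10(LATENCY_FACTOR), 2))  # 0.61
```

<p>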
If we look at the 2.0 training loss line, we can look up some total number of inference tokens that we expect to pay for. Let's call it \(10^{13}\), or 10T. Now if we adjust the number of inference tokens by the latency reduction factor of 4.10, then we need to look \(log_{10}(4.10)\approx0.6\) ticks on the graph to the left of the \(10^{13}\) line (green arrow). This puts us right around the 2x contour. So this tells us that we should use about 2x the Chinchilla-optimal number of training tokens, which is about 1400B x 2 = 2800B training tokens.</p><div data-node-type="callout"><div data-node-type="callout-emoji">💡</div><div data-node-type="callout-text">In the original post, I missed a zero, and said 140B x 2 = 280B. Oops!</div></div><p>We can do a similar exercise if we have a given number of training tokens and need to know how to pick a cost-optimal model size at a given loss (see the paper for details).</p><p>This is pretty awesome! In essence, we need fewer training tokens to get to cost optimal for a known number of inference tokens. Of course, the number of inference tokens is hard to project for a proprietary model, let alone an open source model. This scaling law will also run out of steam in the limits of model sizes and token counts, but it's not clear from the paper where those limits are. But it is certainly useful to develop an intuition for the impact of the BitNet b1.58 models on scaling.</p><h1 id="heading-tying-it-up">Tying it up</h1><p>Someone trained a <a target="_blank" href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">set of models</a> using the recipe from the authors. These models have the full FP16 weights, which is great because we can use them to study the relationship between the full precision weights and the quantized ones.
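</p><p>For reference, the quantization step itself is the paper's "absmean" scheme: scale the weight matrix by its mean absolute value, then round to the nearest integer and clip to {-1, 0, 1}. A rough pure-Python sketch over a flat list of weights (my reading of the paper, not the replicators' actual code):</p>

```python
def absmean_ternary(w, eps=1e-8):
    """Quantize full-precision weights to {-1, 0, 1}: divide by the
    mean absolute weight (gamma), round, and clip to [-1, 1]."""
    gamma = sum(abs(x) for x in w) / len(w)

    def round_clip(x):
        return max(-1, min(1, round(x / (gamma + eps))))

    return [round_clip(x) for x in w], gamma

q, gamma = absmean_ternary([0.9, -0.05, 0.4, -1.2, 0.02, 0.6])
print(q)  # [1, 0, 1, -1, 0, 1]
```

<p>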
It does mean, however, that if we want to get the performance benefits of the quantization, we have to actually quantize the FP16 weights down to ternary and update the inference code to take advantage of it.</p><p>Fortunately, HuggingFace user <a target="_blank" href="https://huggingface.co/kousw">kousw</a> has <a target="_blank" href="https://huggingface.co/kousw/bitnet_b1_58-3B_quantized">quantized the 3B model</a> and written the code to do inference on it! I have a demo notebook here: <a target="_blank" href="https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9?usp=sharing">https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9</a></p><p>In any event, this paper does a great job of telling us about an exciting result, but leaves a lot to the reader to dig up for themselves. This is fine for a pre-print, I just hope they will flesh out the paper a bit more before publishing it. I would have liked to see</p><ul><li><p>a more direct comparison of the quantization results</p></li><li><p>more explicit exposition of the scaling law claims</p></li><li><p>a better explanation of their arguments about hardware</p></li></ul><p>Perhaps more importantly, we got a lot of what, a bit of how, and very little why.</p><ul><li><p>Why does quantization-aware training work better than post-training quantization? They have a hypothesis in their supplemental notes ("<em>This is because PTQ can only quantize model weights around the full-precision values, which may not necessarily be where the global minima of the low-precision landscape lie</em>.") The authors should design an experiment to support their hypothesis, or cite research on it more clearly.</p></li><li><p>Why does ternary work just as well as FP16? Why not binary?</p><ul><li><p>Since higher precision activations seem to be important, does information move from the weights into other places like biases or other parameters?
We could test this by measuring the entropy in the weights, biases, and activations in an unquantized vs. quantized weight training regime.</p></li><li><p>For ternary vs. binary, can we correlate errors to the inability to mask features?</p></li></ul></li><li><p>Why does ternary work <em>better</em> than FP16?</p><ul><li>Is this a kind of regularization?</li></ul></li><li><p>Does using a ternary-weighted model have a lower information capacity than models with more bits per weight? The authors state in the supplemental notes, regarding using all 4 values of the 2-bit representation: "While this is a reasonable approach due to the availability of shift operations for -2 and 2 to eliminate multiplication, ternary quantization already achieves comparable performance to full-precision models. Following Occams Razor principle, we refrained from introducing more values beyond {-1, 0, 1}." Perhaps, given the results from <a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a>, this decision is premature?</p></li></ul><p>They have hypotheses, but don't explore them, and so this reads as a (very interesting) technical report, and not a research paper. That's fine for a pre-print, but I would love to know more about why, and not just what.</p><p>In conclusion, what a great result!
I look forward to hearing more from the authors, hopefully some followup on the things I mentioned above, and of course, their models and code!</p><h1 id="heading-resources">Resources</h1><ul><li><p><a target="_blank" href="https://arxiv.org/abs/2402.17764">The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits</a> - the paper itself</p><ul><li><p><a target="_blank" href="https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf">Supplemental notes</a> - details about how to replicate their results, and answering some of the questions they got as a result of the pre-print</p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2310.11453.pdf">BitNet: Scaling 1-bit Transformers for Large Language Models</a> - the preceding paper. You probably need to read this paper as well if you want to understand this one</p></li><li><p><a target="_blank" href="https://www.semanticscholar.org/reader/97fb4e3d45bb098e27e0071448b6152217bd35a5">Layer Normalization</a> - the source of their LayerNorm stage</p></li><li><p><a target="_blank" href="https://huggingface.co/papers/2402.17764">The paper page on HuggingFace with some great discussion about the results and models</a></p></li></ul></li><li><p>The <a target="_blank" href="https://huggingface.co/1bitLLM">replicated models</a> on HuggingFace, including some replicated evals</p><ul><li><p>The <a target="_blank" href="https://huggingface.co/kousw/bitnet_b1_58-3B_quantized">quantized 3B model</a> for faster inference</p></li><li><p>My <a target="_blank" href="https://colab.research.google.com/drive/1KDQBle0hByR9oB1b9MVx9nmaDiHTr8_9?usp=sharing">notebook</a> to fiddle around with the quantized model</p></li><li><p><a target="_blank" href="https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo">StableLM-3B-4E1T Technical Report</a> - the presumptive source of the 
2T token comparison that they make to argue that they continue to scale as the number of training tokens increases</p></li></ul></li><li><p>Scaling laws</p><ul><li><p><a target="_blank" href="https://arxiv.org/pdf/2203.15556.pdf">Chinchilla paper</a></p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2401.00448.pdf">Beyond Chinchilla-Optimal: Accounting for Inference in Language Model Scaling Laws</a></p></li><li><p><a target="_blank" href="https://arxiv.org/pdf/2404.05405.pdf">Physics of Language Models: Part 3.3, Knowledge Capacity Scaling Laws</a> - a recent analysis of the knowledge capacity of transformer-based models</p></li></ul></li><li><p>Meta stuff</p><ul><li><p><a target="_blank" href="https://www.swyx.io/learn-in-public">learning in public</a> - a great post by <a target="_blank" href="https://twitter.com/swyx">@swyx</a> about how and why to share your learning journey with others</p></li><li><p><a target="_blank" href="https://www.swyx.io/learn-in-public">learning exhaust</a> - a great post by <a target="_blank" href="https://twitter.com/swyx">@swyx</a> about how to improve your learning by taking advantage of the artifacts you produce while learning</p></li></ul></li></ul>]]>https://cdn.hashnode.com/res/hashnode/image/upload/v1712944770607/569f6bdd-4e70-407a-a723-d131bf8e02d6.png