<?xml version="1.0" encoding="UTF-8"?><rss xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom" version="2.0" xmlns:cc="http://cyber.law.harvard.edu/rss/creativeCommonsRssModule.html">
    <channel>
        <title><![CDATA[Stories by Ryan Pégoud on Medium]]></title>
        <description><![CDATA[Stories by Ryan Pégoud on Medium]]></description>
        <link>https://medium.com/@ryanpegoud?source=rss-27fba63b402e------2</link>
        <image>
            <url>https://cdn-images-1.medium.com/fit/c/150/150/1*ep3ebfAfE1csq42qzviNgg.jpeg</url>
            <title>Stories by Ryan Pégoud on Medium</title>
            <link>https://medium.com/@ryanpegoud?source=rss-27fba63b402e------2</link>
        </image>
        <generator>Medium</generator>
        <lastBuildDate>Mon, 06 Apr 2026 08:41:44 GMT</lastBuildDate>
        <atom:link href="https://medium.com/@ryanpegoud/feed" rel="self" type="application/rss+xml"/>
        <webMaster><![CDATA[yourfriends@medium.com]]></webMaster>
        <atom:link href="http://medium.superfeedr.com" rel="hub"/>
        <item>
            <title><![CDATA[Cutting LLM Memory by 84%, A Deep Dive into Fused Kernels]]></title>
            <link>https://medium.com/data-science-collective/cutting-llm-memory-by-84-a-deep-dive-into-fused-kernels-7028ca28bb75?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/7028ca28bb75</guid>
            <category><![CDATA[gpu-kernel]]></category>
            <category><![CDATA[llm]]></category>
            <category><![CDATA[triton]]></category>
            <category><![CDATA[deep-learning]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Thu, 19 Feb 2026 17:51:34 GMT</pubDate>
            <atom:updated>2026-02-19T17:51:34.405Z</atom:updated>
            <content:encoded><![CDATA[<h4>Why your final LLM layer is OOMing and how to fix it with a custom Triton kernel.</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*JpmyxyAZkCq56-pB" /><figcaption>Photo by <a href="https://unsplash.com/@zelebb?utm_source=medium&amp;utm_medium=referral">Andrey Matveev</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>If you’ve ever trained or fine-tuned an LLM, you’ve likely hit a wall at the very last step: the <strong>Cross-Entropy Loss</strong>.</p><p>The culprit is the <strong>logit bottleneck</strong>. To predict the next token, we project a hidden state into a massive vocabulary space. For Llama 3 (128,256 tokens), the weight matrix alone is over <strong>525 million parameters</strong>. While that’s only ~1GB in bfloat16, the intermediate logit tensor is the real issue: for large batches, it can easily exceed <strong>80GB</strong> of VRAM just to compute a single scalar loss.</p><p>Optimising this layer is how libraries like Unsloth and Liger-Kernel achieve such massive memory reductions. In this article, we’ll build a fused Linear + Cross Entropy kernel from scratch in Triton. We will derive the math and implement a tiled forward and backward pass that slashes peak memory usage by <strong>84%</strong>.</p><blockquote><strong>Note on Performance:</strong> This implementation is primarily <strong>educational</strong>. We prioritise mathematical clarity and readable Triton code by using global atomic operations. While it solves the memory bottleneck, matching production-grade speeds would require significantly more complex implementations which are out of scope for this article.</blockquote><blockquote>This post is part of my Triton series. 
We’ll be using concepts like <a href="https://contributor.insightmediagroup.io/learning-triton-one-kernel-at-a-time-matrix-multiplication/">tiling</a> and <a href="https://contributor.insightmediagroup.io/learning-triton-one-kernel-at-a-time-softmax/">online softmax</a> that we’ve covered previously. If those sound unfamiliar, I recommend catching up there first!</blockquote><ul><li><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-softmax-78e8ba73734d">Learning Triton One Kernel at a Time: Softmax</a></li><li><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd">Learning Triton One Kernel at a Time: Matrix Multiplication</a></li><li><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-vector-addition-5f57e9d2f3e1">Learning Triton One Kernel At a Time: Vector Addition</a></li></ul><h3>The Logit Bottleneck</h3><p>To get us started, let’s put some more numbers on the logit bottleneck. We consider an input matrix X with shape [NxD], a weight matrix W with shape [DxV] and a logit matrix Y=X@W with shape [NxV]. In the context of an LLM, N would be the sequence length multiplied by the batch size, D the size of the hidden state and V the vocabulary size. <br>For a Llama3 8B model, we would have a context window of 8192 tokens, a hidden state with 4096 dimensions and a vocabulary size of 128,256 tokens. Using a modest batch size of 8, we get N = 8192x8 = 65,536.<br>This results in the Y matrix having shape [NxV]=[65,536x128,256], or roughly <strong>8.4 billion</strong> elements. In bfloat16, this would take up <strong>16.8GB</strong> of memory. 
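</p><p>These figures are easy to sanity-check with a back-of-envelope calculation in plain Python (the shapes are the Llama3 8B numbers quoted above; the constant is simply the 2-byte width of bfloat16):</p>

```python
# Back-of-envelope footprint of the logit matrix Y = X @ W for Llama3 8B.
BATCH, SEQ_LEN, VOCAB = 8, 8192, 128_256
BYTES_BF16 = 2

n_rows = BATCH * SEQ_LEN                  # N = 65,536
elements = n_rows * VOCAB                 # ~8.4 billion logits
gigabytes = elements * BYTES_BF16 / 1e9   # footprint in bfloat16

print(f"{elements / 1e9:.1f}B elements -> {gigabytes:.1f} GB in bfloat16")
```

<p>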
However, if we follow best practices and use float32 for the loss calculation to ensure numerical stability, the requirements double to <strong>33.6GB</strong>.<br>To put this number in perspective, we would also need around 16GB of memory to hold the weights of Llama3 8B in bfloat16. On most GPUs, this leaves no space for the massive overhead of the optimiser states (e.g. Adam’s moments) and other activations, resulting in the infamous PyTorch OOM error.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*JvJQa8OlshjjkapnNkEmnw.png" /><figcaption>Representation of the input, weight and logit matrices along with their memory footprint. (All illustrations and animations in this article were made by the author unless specified otherwise)</figcaption></figure><p>Generally, this problem is dealt with by using:</p><ul><li><strong>Gradient accumulation:</strong> Use a smaller batch size and accumulate gradients over multiple batches between each optimiser step, emulating a larger batch size while holding less data in memory.</li><li><strong>Activation checkpointing:</strong> PyTorch stores all intermediate activations for reuse in the backward pass; checkpointing clears these activations and recomputes them on-the-fly during the backward pass. This leads to large memory savings but increases training time since the number of required forward passes is doubled.</li><li><strong>Micro-batching the loss:</strong> Instead of computing the loss over the N dimension at once, we can slice it and accumulate the loss over smaller chunks with size n &lt; N. 
Now, we only hold a slice of size [n, V] in memory at a time.</li><li><strong>Mixed precision training:</strong> Using half precision during training provides 2x memory reduction and significant speedups on Tensor Cores.</li></ul><p>While these solutions seem attractive, they all have significant drawbacks: gradient accumulation and activation checkpointing slow down training, mixed precision can be unstable, and micro-batching requires (slow) PyTorch-level iteration; even though n is chosen to be smaller than N, the vocabulary size remains huge in comparison.<br>More importantly, these solutions do not address the problem we have dealt with repeatedly throughout this series: <strong>data movement</strong>. Indeed, we are still wasting time by writing billions of logits to VRAM only to read them back milliseconds later.</p><h3>The Kernel Solution</h3><p>As we’ll see in a minute, the forward and backward passes of the cross-entropy loss involve dot products, matrix multiplication and a softmax. As we learned in this series, these are all operations that can be tiled efficiently. In other words, we can perform them iteratively while only holding a small piece of the inputs in memory at any time. <br>Furthermore, cross-entropy is generally preceded by a matrix multiplication: the linear projection from the hidden state into the vocabulary space. 
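</p><p>For reference, the memory-hungry baseline that all of these workarounds try to tame is just a couple of lines of PyTorch (toy shapes here; the problem only bites once N and V reach the tens of thousands):</p>

```python
import torch
import torch.nn.functional as F

N, D, V = 512, 256, 1024                   # toy sizes; Llama3 8B would be 65_536, 4096, 128_256
x = torch.randn(N, D, requires_grad=True)  # hidden states
w = torch.randn(D, V, requires_grad=True)  # vocabulary projection
y = torch.randint(0, V, (N,))              # target token ids

logits = x @ w                             # the full [N, V] tensor we want to avoid materialising
loss = F.cross_entropy(logits, y)
loss.backward()                            # autograd keeps logits alive until this step
```

<p>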
This is a great opportunity for <strong>operator fusion</strong>: fusing multiple operations within a single kernel, resulting in large speedups and potential memory gains.<br>In the following sections, we’ll take a look at how to derive and efficiently fuse the forward and backward passes through a kernel combining a linear layer with cross-entropy.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Xoa90UGcS06PWGchvYTlHw.png" /><figcaption>Illustration of the Llama3 architecture; the operations handled by the kernel are highlighted in purple.</figcaption></figure><p>As mentioned in the last article, Triton kernels do not natively register in PyTorch’s autograd. Therefore, we need to derive the gradient ourselves, a wonderful occasion to brush up on some calculus ;)</p><h3>The math behind Fused Linear Cross-Entropy</h3><h4>Definition and Forward Pass</h4><p>In this section, we derive the mathematical expression for our Fused Linear Cross-Entropy layer to see how it naturally lends itself to tiling.</p><p>For two discrete probability distributions p and q, cross-entropy is defined as:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*62uuCPhwHP70Dd-Bs-MkEg.png" /></figure><p>In our context, p is the <strong>one-hot vector</strong> representing the target token, while q is the <strong>model’s distribution</strong> over the vocabulary. We obtain q by applying a softmax to the logits l, themselves the outputs of the preceding linear layer.</p><p>Since p is positive for a single target token y, the summation collapses. We can then substitute the numerically stable softmax (as discussed in the <a href="https://contributor.insightmediagroup.io/learning-triton-one-kernel-at-a-time-softmax/">last article</a>) to derive the final expression:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*TIdSq20rxdqcbyEk-KnTFQ.png" /></figure><p>By substituting the logits l with the linear layer x . 
w, we see that the forward pass boils down to three primary quantities:</p><ol><li>The target logit x . w_y.</li><li>The log-sum-exp (LSE) of all dot products.</li><li>The global maximum logit used for numerical stability.</li></ol><p>Thanks to the online softmax algorithm, we can compute these quantities without ever materialising the full vocabulary in memory. Instead of an O(V) memory bottleneck, we iterate over the hidden dimension D and the vocabulary V in small tiles (D_block and V_block). This transforms the calculation into an O(1) register problem.</p><p>To parallelise this effectively, we launch one GPU program per row of the input matrix. Each program independently executes the following steps:</p><ol><li><strong>Pre-compute the target logit:</strong> Perform a tiled dot product between the current row of X and the column of W associated with token Y.</li><li><strong>Online reduction:</strong> Iterate through the hidden and vocabulary blocks to:<br> 1. Track the running maximum (m)<br> 2. Update the running sum of exponentials (d) using the online softmax formula:</li></ol><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*W9Uyu5OE7tjWKKSK57LCKA.png" /></figure><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*V7KDla0jVXhrPSC_XRINmw.png" /><figcaption>An example of tiled matrix multiplication for a single GPU program processing a row of <strong>X</strong>. The coloured squares represent elements loaded in memory and the coloured outline represent the complete tile that is iterated on. Tiling trades off speed for massive memory gains.</figcaption></figure><p>Now that we have a better understanding of the forward pass, let’s take a look at the derivation of the backward pass.</p><h3>Backward Pass</h3><h4>Notation</h4><p>To derive our gradients efficiently, we’ll use <strong>Einstein notation</strong> and the <strong>Kronecker delta</strong>.<br>In Einstein notation, repeated indices are implicitly summed over. 
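</p><p>As a side note, this summation convention maps directly onto <code>einsum</code>; a quick NumPy check (toy shapes of our choosing) is a handy way to verify the index manipulations that follow:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))    # [N, D]
W = rng.normal(size=(8, 16))   # [D, V]

# Y_ij = X_ik W_kj: the repeated index k is summed over implicitly.
Y = np.einsum("ik,kj->ij", X, W)
assert np.allclose(Y, X @ W)
```

<p>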
For example, a standard matrix multiplication Y = X@W simplifies from a verbose summation to a clean index pairing:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*aQLLEF9L7RjvaIzMtymm4w.png" /></figure><p>The <strong>Kronecker delta</strong> (δ_ij) is used alongside this notation to handle identity logic. It is equal to 1 if i=j and 0 otherwise. As we’ll see, this is particularly useful for collapsing indices during differentiation.</p><h4>Matrix Multiplication</h4><p>In this section, we derive the back-propagated gradients for matrix multiplication. We assume the existence of an upstream gradient <strong>ℓ</strong>.</p><p>To determine how it back-propagates through matrix multiplication, we apply the chain rule to the inputs x and the weight matrix w. Here y represents the multiplication’s outputs:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*SeY97JFDqjJQq66MG3Fn6Q.png" /></figure><p>We start by deriving the partial derivatives of y with respect to x, following these steps:</p><ol><li>Express y in terms of x and w.</li><li>Notice that w is constant with respect to x, so we can pull it out of the derivative.</li><li>Express the fact that the partial derivative of x_ik with respect to x_mn is 1 only when i=m and k=n using the Kronecker delta.</li><li>Notice that ẟ_kn enforces k=n; therefore, w_kj * ẟ_kn reduces to w_nj.</li></ol><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*DhnD4S_t0KYkcYBdipovTA.png" /></figure><p>Then, we consider the full expression and obtain the gradient. 
We derive the last step by noticing once again that 1/y_ij * ẟ_im reduces to 1/y_mj.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*oBvJIlKxuYWqjxBIP1kCMQ.png" /></figure><p>However, matrix notation is conceptually closer to our Triton kernel; therefore, we rewrite this expression as a matrix multiplication by using the identity X_ij = [X^T]_ji:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*xJ9kdRBxWIqs_Vix8LZxow.png" /></figure><p>We follow the exact same steps to derive the gradient with respect to W:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ndxRutXIkQ2IFBm_OGmocA.png" /></figure><p>Then, the back-propagated gradient follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*fHVqr78r9tu2-7baHK_udg.png" /></figure><p>Which is equivalent to the matrix notation:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*F1sd2qM_0zz-dujMVpq6qQ.png" /></figure><h4>Cross-Entropy</h4><p>In this section, we’ll focus on cross-entropy applied to discrete probability distributions. Considering a tensor of j logits, with a label y, the cross-entropy is computed as follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*lEnTObrAEXOspl7BcHfpAg.png" /></figure><p>Where x_y corresponds to the logit associated with the label.<br>Once again, we are interested in the partial derivative of any output i with respect to any input k. 
Because of the normalising factor, every element i affects the value of every other element; therefore, the partial derivative is obtained by defining the function piecewise depending on the value of i:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*CmcW6cd1W513p4uuz4Zzyw.png" /></figure><p>Summing both cases, we obtain the gradient:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*iRDjwR2uxZMmo8llPQWbzw.png" /></figure><p>And in matrix notation:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*sORnvMDQzXDEktW438eOhg.png" /></figure><p>Where y_{one hot} is a vector of zeros with the entry corresponding to the label set to one. This result tells us that the gradient is simply <strong>the difference between the prediction and the ground truth</strong>.</p><h4>Fused Linear Cross-Entropy</h4><p>Combining the linear projection with cross-entropy in a single expression, we get:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Om6hbzQWTGKPN379SI1SVA.png" /></figure><p>Thanks to the chain rule, deriving the gradient of this expression boils down to multiplying the gradients we computed previously:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*XBjBJsM3tUQwKDqWKc--fw.png" /></figure><p>Where x and y refer to the inputs and outputs of the linear layer, respectively, and w to the associated weight matrix.</p><blockquote>Note: in a batched setting, we’ll need to reduce the W gradients over the batch dimension. Generally, we use a sum or mean reduction.</blockquote><h3>Kernel Implementation</h3><p>With the theory established, we can implement the fused kernel in Triton. Since cross-entropy is typically the final layer in a language model, we can combine the forward and backward passes into a <em>single kernel</em>. 
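</p><p>Before writing any Triton, it is worth sanity-checking the gradient expressions above numerically. A minimal NumPy sketch (toy shapes, mean-reduced loss; the helper and variable names are ours) compares the analytic gradients against finite differences:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, V = 4, 8, 16                               # toy shapes
X = rng.normal(size=(N, D))
W = rng.normal(size=(D, V))
y = rng.integers(0, V, size=N)                   # target token ids

def loss(X, W):
    """Mean cross-entropy of the fused linear layer via a stable log-sum-exp."""
    logits = X @ W
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return (lse - logits[np.arange(N), y]).mean()

# Analytic gradients from the derivation: dL/dlogits = (softmax(logits) - one_hot(y)) / N
logits = X @ W
P = np.exp(logits - logits.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
P[np.arange(N), y] -= 1.0
G = P / N
dX, dW = G @ W.T, X.T @ G

# Spot-check both against forward finite differences
eps = 1e-6
Wp = W.copy(); Wp[3, 5] += eps
assert abs((loss(X, Wp) - loss(X, W)) / eps - dW[3, 5]) < 1e-4
Xp = X.copy(); Xp[1, 2] += eps
assert abs((loss(Xp, W) - loss(X, W)) / eps - dX[1, 2]) < 1e-4
```

<p>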
This <strong>fusion</strong> offers two advantages: it minimises the overhead of multiple kernel launches and significantly improves data locality by keeping intermediate values on-chip.</p><p>We will analyse the kernel step-by-step from the perspective of a <strong>single program instance</strong>, which, in our parallelisation strategy, handles one specific row of the input matrix.</p><h4>1. Setup and Target Logit Pre-computation</h4><p>The initial phase involves standard Triton setup:</p><ul><li><strong>Program Identification:</strong> We use tl.program_id to determine which row of the input matrix the current program is responsible for.</li><li><strong>Parameter Initialisation:</strong> We define tiles using D_BLOCK and V_BLOCK and initialise the running maximum (m) and sum (d) required for the online softmax algorithm.</li><li><strong>Pointer Arithmetic:</strong> We calculate the base memory addresses for our tensors. Pointers for X (input) and dX (gradient) are offset using the <strong>row stride</strong> so each program accesses its unique token vector. Conversely, the W (weight) pointer remains at the base address because every program must eventually iterate through the entire vocabulary space.</li><li><strong>Masking and Early Exit:</strong> We define an ignore_index (defaulting to -100). If a program encounters this label (e.g. for padding tokens), it terminates early with a loss of 0 to save cycles.</li></ul><h4>2. Computing the Target Logit</h4><p>Before the main loop, we must isolate the <strong>target logit</strong> x . w_y. We iterate over the hidden dimension D in D_BLOCK chunks, performing a dot product between the input row X and the specific column of W corresponding to the ground-truth label Y.<br>Because W is a 2D matrix, calculating the pointers for these specific column tiles requires precise stride manipulation. 
The illustration below helps visualise how we “jump” through memory to extract only the necessary weights for the target token.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*RxiOsFSjCu5eaCONxY05Pg.png" /><figcaption>Representation of the pointer arithmetic executed to compute the target logit <strong>Y</strong>. Here, we consider that the label is <strong>4</strong>, meaning that the target logit is <strong>X</strong>’s dot product with <strong>W</strong>’s 5th column. Vectors of different colours represent different steps of the iteration along <strong>D</strong> (i.e. different values of <strong>d_idx</strong>). Numbers refer to the memory address of each element assuming a row-major layout.</figcaption></figure><p>Once the tiles are loaded, we cast them to float32 to ensure numerical stability and add their dot product to an accumulator variable before moving to the next iteration.</p><p>Here’s the code so far:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/1551a2685e31109843e60f87081c7533/href">https://medium.com/media/1551a2685e31109843e60f87081c7533/href</a></iframe><p>Next, we execute the forward pass, which processes the vocabulary space in two nested stages:</p><ol><li><strong>Tiled Logit Computation:</strong> We compute the logits for a V_BLOCK at a time. This is achieved by iterating over the vocabulary dimension V (outer loop) and the hidden dimension D (inner loop). Within the inner loop, we load a tile of X and a block of W, accumulating their partial dot products into a high-precision register.</li><li><strong>Online Softmax Update:</strong> Once the full dot product for a logit tile is finalised, we don’t store it to VRAM. Instead, we immediately update our running statistics: the maximum value m and the running sum of exponentials d using the online softmax formula. 
By doing this “on the fly”, we ensure that we only ever hold a small V_BLOCK of logits in the GPU’s registers at any given moment.</li></ol><p>Following these iterations, the final values of m and d are used to reconstruct the LSE. The final scalar loss for the row is then computed by subtracting the target logit (x . w_y) from this LSE value.</p><p>Here’s a visual representation of the forward pass:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*vX5pSTqezf8R-3UeDPl_0A.gif" /><figcaption>Visual representation of the tiled matrix multiplication with running statistics updates. At each step, we load elements coloured in green or dark blue and compute the dot products of vectors highlighted in green. Elements of <strong>Y</strong> are accumulated by iterating over the <strong>D</strong> dimension, when this is done (i.e. the cells are green), we update <strong>m</strong> and <strong>d</strong> based on the freshly computed tile.</figcaption></figure><p>Here’s the code for the forward pass:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/2d9c22ed7e470e2e8f9f1dc32078890d/href">https://medium.com/media/2d9c22ed7e470e2e8f9f1dc32078890d/href</a></iframe><p>We are now down to the last part of the kernel: the backward pass. Our goal is to compute the gradients with respect to X and W using the expression we derived earlier:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*M-jNpc6xu8L5des7YRM3Gw.png" /></figure><p>To remain memory-efficient, we once again process the vocabulary in tiles using a two-staged approach:</p><ol><li><strong>Recomputing Normalised Probabilities (</strong><strong>P):</strong> Because we didn’t store the full logit matrix during the forward pass, we must recompute the activations for each tile. By reusing the <strong>Log-Sum-Exp</strong> calculated in the forward pass, we can normalise these activations on-the-fly. 
Subtracting the one-hot ground-truth label Y from the normalised probabilities within this tile gives us a local chunk of the logit gradient, P.</li><li><strong>Gradient Accumulation:</strong> With a tile of P in hand, we calculate the partial gradients. For dX, we perform a dot product with blocks of W^T; for dW, we multiply by tiles of X^T. To safely aggregate these values across the entire batch, we use Triton’s <strong>tl.atomic_add</strong>.<br>This operation acts as a thread-safe +=, ensuring that different programs updating the same weight gradient do not overwrite one another.</li></ol><p>Here are some additional details on the implementation:</p><ul><li><strong>The Stride Swap:</strong> When computing P . W_T, we don’t actually need to physically transpose the massive W matrix in memory. Instead, we invert the shapes and strides in W’s block pointer to read the rows of W as columns of W^T. This results in a “free” transpose that saves both time and VRAM.</li><li><strong>Numerical Precision: </strong>It is worth noting that while X and W might be in bfloat16, the accumulation of dW and dX via atomic_add is usually performed in <strong>float32</strong> to prevent the accumulation of tiny rounding errors across thousands of rows.</li><li><strong>Contention Note:</strong> While atomic_add is necessary for dW (because every program updates the same weights), dX is private to each program, meaning there is zero contention between program IDs for that specific tensor.</li><li><strong>Atomic Add Masking:</strong> atomic_add doesn’t support block pointers. Therefore, we implement the pointer and mask logic for dW explicitly.</li></ul><p>The following figure is a representation of the backward pass for one iteration of the outer loop (i.e. 
one block along V and all blocks along D):</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*-Rs5JjNXG0iNTCtbSyVa_w.png" /><figcaption>Representation of the backward pass for a single step along the <strong>V</strong> dimension and a full iteration along the <strong>D</strong> dimension. In stage 4, we highlight how <strong>dX</strong> is accumulated over <strong><em>iterations</em></strong> (every program updates its private row once per step along <strong>V</strong>) whereas <strong>dW</strong> is accumulated over <strong>programs</strong> (<strong>N</strong> programs update the values of a single block in <strong>dW</strong> at every step along <strong>V</strong>).</figcaption></figure><p>Here’s the full code for the backward pass:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/2a6fc3bffebe0825d734234611aad19c/href">https://medium.com/media/2a6fc3bffebe0825d734234611aad19c/href</a></iframe><p>This concludes the implementation of our kernel! The full code including the kernel and benchmark script is available <a href="https://gist.github.com/RPegoud/76a2e9042b929889a158d7d17c81c9f7">here</a>.</p><h4>Memory Benchmark</h4><p>Finally, we compare our kernel with the PyTorch baseline using hyperparameters inspired by Llama3 and an A100 GPU. Specifically, we consider a sequence length of S=16,384, a batch size of B=1 and an embedding dimension of D=4096; the vocabulary size is set to V=128,256.<br>As expected, the PyTorch baseline allocates a massive intermediate tensor to store the activations, resulting in a peak memory usage of <strong>36.02GB</strong>. 
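</p><p>Peak figures like these can be read off with PyTorch’s CUDA memory statistics; a minimal probe (the <code>peak_gb</code> helper is ours; it assumes a CUDA device and skips the measurement otherwise) might look like this:</p>

```python
import torch
import torch.nn.functional as F

def peak_gb(fn):
    """Run fn and report the peak CUDA memory allocated, in GB (None on CPU-only hosts)."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9

def baseline(S=16_384, D=4096, V=128_256):
    # Benchmark shapes from the text; the full run needs roughly 36GB of VRAM.
    x = torch.randn(S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(D, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = torch.randint(0, V, (S,), device="cuda")
    F.cross_entropy((x @ w).float(), y).backward()

# A shorter sequence keeps the probe affordable on smaller GPUs.
print(peak_gb(lambda: baseline(S=1024)))
```

<p>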
In comparison, our Triton kernel reduces the peak memory usage by <strong>84%</strong>, allocating only <strong>5.04GB</strong> using D_BLOCK=64 and V_BLOCK=64!<br>Using even smaller block sizes would allow for further memory gains at the cost of efficiency.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*duVTP6fMbXquEZRFjjLkPQ.png" /></figure><h4>Atomic Limitations and Production Scaling</h4><p>In this article, we focused on the technical and mathematical intuition behind fused Linear Cross-Entropy kernels. We used atomic operations like tl.atomic_add to keep the code minimal and readable. However, while our kernel successfully slashed memory usage by a staggering <strong>84%</strong>, the Triton kernel is significantly slower than native PyTorch.<br>Unfortunately, the same atomic operations which make this kernel easier to write and comprehend come at the cost of a massive traffic jam since thousands of threads try to modify the same memory address at once. Generally, tl.atomic_add is performant when <em>contention is low</em>. In our current implementation, we have:</p><ol><li><strong>High Contention:</strong> For the weight gradient, every single program in the batch (up to 16,384 in our test) is trying to update the same memory tiles simultaneously. The hardware must serialise these updates, forcing thousands of threads to wait in line.</li><li><strong>Numerical Non-associativity:</strong> In computers, floating-point addition is <strong>non-associative</strong>. Rounding errors can accumulate differently depending on the order of operations, which is why correctness tests might pass on a T4 but fail on an A100: the latter has more streaming multiprocessors (SMs) performing more concurrent, non-deterministic additions.</li></ol><blockquote><strong>Note on Precision:</strong> On Ampere and newer architectures, the <strong>TF32</strong> format can further contribute to these discrepancies. 
For strict numerical parity, one should set allow_tf32=False or use higher precision types during the accumulation steps.</blockquote><h4>Path to Production</h4><p>To move beyond this educational implementation and toward a production-ready kernel (I recommend looking at the <a href="https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py">Liger-Kernel implementation</a>), one could implement several optimisations:</p><ul><li><strong>Replacing </strong><strong>dX Atomics:</strong> Since each program “owns” its row of X, we can use simple register accumulation followed by a tl.store, eliminating atomics for the input gradients entirely.</li><li><strong>A dedicated </strong><strong>dW Kernel:</strong> To optimise the computation of dW, production kernels generally use a different grid strategy where each program handles a block of W and iterates through the batch dimension, accumulating gradients locally before a single global write.</li><li><strong>Micro-batching:</strong> Advanced implementations, such as those in the <a href="https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py"><strong>Liger-Kernel</strong></a> library, process the sequence by blocks along the N dimension, making the memory scaling constant in the sequence length rather than linear. This enables the use of much larger batch sizes at a reduced memory cost.</li></ul><h3>Conclusion</h3><p>This concludes our deep dive into fused linear cross-entropy kernels. Thanks for reading all the way through; I hope this article gave you both the intuition and the practical understanding needed to build on these ideas and explore them further.<br>If you found this useful, consider sharing the article; it genuinely helps support the time and effort that goes into producing this work. 
And as always, feel free to <a href="https://www.linkedin.com/in/ryan-pegoud">contact me</a> if you have questions, thoughts, or ideas for follow-ups.</p><p>Until next time! 👋</p><h3>Sources</h3><ol><li><a href="https://ai.meta.com/blog/meta-llama-3/#:~:text=Llama%203%20uses%20a%20tokenizer,to%20substantially%20improved%20model%20performance">Introducing Meta Llama 3: The most capable openly available LLM to date</a></li><li><a href="https://www.youtube.com/watch?v=gWble4FreV4">LigerKernel (lecture)</a></li><li><a href="https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py">LigerKernel Linear Cross-Entropy Implementation</a></li><li><a href="https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/cross_entropy_loss.py">Unsloth Implementation (cross-entropy only)</a></li></ol><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=7028ca28bb75" width="1" height="1" alt=""><hr><p><a href="https://medium.com/data-science-collective/cutting-llm-memory-by-84-a-deep-dive-into-fused-kernels-7028ca28bb75">Cutting LLM Memory by 84%, A Deep Dive into Fused Kernels</a> was originally published in <a href="https://medium.com/data-science-collective">Data Science Collective</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[AlpamayoR1: Large Causal Reasoning Models for Autonomous Driving]]></title>
            <link>https://medium.com/data-science-collective/alpamayor1-large-causal-reasoning-models-for-autonomous-driving-5b287216634c?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/5b287216634c</guid>
            <category><![CDATA[nvidia]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[autonomous-cars]]></category>
            <category><![CDATA[autonomous-driving]]></category>
            <category><![CDATA[autonomous-vehicles]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Thu, 19 Feb 2026 13:30:33 GMT</pubDate>
            <atom:updated>2026-04-05T11:13:34.814Z</atom:updated>
            <content:encoded><![CDATA[<h4>All you need to know about Chain of Causation reasoning and the current state of Autonomous Driving!</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*pM926G8reEMLhjp0.jpg" /><figcaption><a href="https://unsplash.com/photos/a-black-car-on-a-white-platform--FANRALn9WI">Photo</a> by Pramod Tiwari on Unsplash</figcaption></figure><p>Recently, Nvidia took the world of autonomous driving by storm with their new AlpamayoR1 architecture, integrating a large Vision-Language Model as a causally-grounded reasoning backbone. This release, accompanied by a new large-scale dataset and a photo-realistic driving simulator, already positions the company as one of the main players in the field in 2026.</p><p>In this article, we’ll break down the AlpamayoR1 architecture, chain of causation reasoning, as well as the elaborate training procedure used to train the model.</p><h3>The Current State of Autonomous Driving</h3><p>The release of AlpamayoR1 (AR1) finds context in the current paradigm of End-to-End (E2E) architectures. E2E models aim to map raw sensory inputs (cameras, LiDAR, radar, …) to trajectories in a fully differentiable architecture optimising a unified objective.</p><p>An emerging trend in E2E involves leveraging the extensive world knowledge of large Vision-Language Models (VLMs) to tackle complex driving situations. This generally involves using VLMs as reasoning backbones to inform future trajectories or as expert teachers to provide supervisory signal to smaller student models.</p><h3>The AR1 Architecture</h3><p>AR1 is a prime example of the reasoning-VLM-as-a-backbone approach. Despite its massive size, the architecture is optimised for real-world deployment and runs at a latency of <strong>99ms</strong> (i.e. <strong>10Hz</strong>) on a single Blackwell GPU, which is considered to be a general target for safety reasons.
In this section, we’ll break down the architecture and its numerous innovations.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*NRn8NA4bkFmbxVMk.png" /><figcaption>High-level overview of the AR1 architecture, source: [1]</figcaption></figure><h3>Vision Encoder</h3><p>AR1 uses both visual and textual inputs in the form of tokenised camera feeds and natural language instructions. For performance, it is crucial for the vision encoder to produce as few tokens as possible.</p><p>To this end, the authors used a Vision Transformer (ViT)[2] for single-image tokenisation. ViTs partition images into a sequence of patch tokens, which are then encoded by a regular transformer. Note that the integration of more efficient algorithms like Flex [3] for multi-video tokenisation is left for future work.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*uOufTuMslf_3ZdSA.png" /><figcaption>Vision Transformer architecture, source: [2]</figcaption></figure><h3>Reasoning Backbone</h3><p>The AR1 architecture is built around Cosmos-Reason, one of Nvidia’s VLMs trained specifically for embodied reasoning in Physical AI use cases. Its usual training set includes 3.7M general Visual Question-Answering (VQA) samples to improve the model’s physical common sense, complemented by 24.7K driving samples. These include video VQA annotated with DeepSeek-R1 reasoning traces to predict the next action.</p><p>Cosmos-Reason processes visual and text tokens along with the recent ego-history (past x-y positions and angle of the ego-vehicle) to output <strong>chain of causation</strong> reasoning traces to inform future trajectories.</p><h3>Chain of Causation</h3><p>A crucial limitation of language models lies in the inherent ambiguity of text labels in visual datasets. This includes vague descriptions lacking a causal structure.
Models trained on such data exhibit a low correlation between their reasoning traces and predicted actions as well as causal confusion.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*D_I34O4owPS3Gfxc.png" /><figcaption>Driving datasets tend to include vague annotations with weak causal grounding, source: [1]</figcaption></figure><p>For an embodied agent like an autonomous car, strong causal reasoning abilities are essential. To circumvent those problems, the Nvidia team deployed significant efforts to create a driving dataset with causally consistent annotations.</p><p>Specifically, the dataset contains 20-second clips extracted from real-world driving recordings in various environments and countries. Each clip contains 2 seconds of context leading to a driving decision (e.g. overtaking, yielding, passing an intersection, …) and its consequences. The causal structure of these scenarios is exposed by consistent textual annotations following a strict template.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*7R1Yxp_-vY2fAdTs.png" /><figcaption>Annotation pipeline for the Chain of Causation dataset, source: [1]</figcaption></figure><p>The first 10% of the dataset are annotated by humans, while the remainder are annotated by state-of-the-art VLMs like GPT5 to scale the labeling process. Once again, significant efforts are deployed to ensure the consistency, quality and correctness of these human and AI annotations.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*f3GZ_IkSLKSulkpH.png" /><figcaption>Examples of chain of causation reasoning produced by AR1, source: [1]</figcaption></figure><h3>Trajectory Decoder</h3><p>The last step of the forward pass consists in decoding the reasoning traces into a 64 point trajectory. While trajectories are usually decoded as a sequence of waypoints (x-y coordinates), the Nvidia team found that using unicycle dynamics (i.e. 
generating a sequence of acceleration values and steering angles) produced more consistent results. In particular, it facilitates the learning task by preventing the model from predicting physically impossible trajectories (e.g. point t being too far from point t+1).</p><p>Interestingly, the authors adopt a dual representation of the trajectory where the model auto-regressively generates discrete tokens during training and uses flow-matching to generate a continuous trajectory at inference time. The main reasons behind this design are as follows:</p><ol><li><strong>Joint Action-Reasoning Token Space:</strong> Using discrete action tokens allows for a tighter coupling between reasoning traces and actions. When the model generates a reasoning trace, the next tokens in the sequence (accelerations and curvatures) are mathematically linked to that explanation, preventing hallucinations.</li><li><strong>Facilitating RL Optimisation:</strong> Restricting the set of possible action tokens to a discrete set makes RL optimisation significantly easier. Indeed, sampling the correct token from a discrete vocabulary (e.g. ACCEL_NEG_2) is significantly easier than providing a gradient for a continuous value like -2.145 m/s^2. As we&#39;ll see in the next section, this enables RL post-training, which is crucial to improve the model&#39;s safety and consistency.</li><li><strong>Stronger Supervisory Signal: </strong>Using a cross-entropy loss on discrete tokens acts like a classification task and better captures the <em>multi-modality</em> (e.g. the distinct probability of turning left or right) than an MSE loss on coordinates.</li><li><strong>Flow Matching for Inference: </strong>While discrete tokens are great for learning, they typically result in jerky trajectories. Moreover, generating a sequence of 128 tokens auto-regressively is too slow for real-time inference.
To address those limitations, the authors introduce an action expert: a smaller variant of the main architecture using the KV cache (which contains visual tokens, historical motions and reasoning traces) to decode a continuous trajectory in one pass using flow-matching diffusion. This is one of the main reasons why AR1 can run at such low latency.</li></ol><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*k5925Ql5Gwqx8Nyx.png" /><figcaption>Latency benchmark for several AR1 variants, generating trajectories via flow-matching saves close to 200ms at inference time. Source: [1]</figcaption></figure><h3>Supervised Fine-Tuning and RL Post-Training</h3><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*OIg8S7kimizzUDdo.png" /><figcaption>Multi-stage training pipeline for the Cosmos-Reason backbone and the AR1 architecture, source: [1]</figcaption></figure><p>In order to transform the VLM backbone into a performant driving policy, it undergoes supervised fine-tuning (SFT) on the chain of causation dataset. Specifically, it learns to reproduce the reasoning traces and associated ground-truth actions by maximising the log-likelihood of the action-reasoning sequence:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*GmPhSn5mSArnzGYcsTd8hA.png" /><figcaption>Supervised Fine-Tuning loss, made by the author</figcaption></figure><p>However, SFT on its own is not enough. VLMs notoriously suffer from discrepancies between their reasoning and predicted actions. The static nature of open-loop datasets allows the model to mimic reasoning traces, but the lack of environmental feedback prevents it from truly internalising causal reactions.</p><p>Fortunately, RL post-training helps alleviate those limitations by providing inference feedback on the model’s rollouts. In this paper, the authors use RL for three main purposes:</p><ol><li><strong>Improving reasoning quality:</strong> a large reasoning model (e.g.
DeepSeek-R1) evaluates AR1’s reasoning traces to ensure there are no inconsistencies or hallucinations and assigns a discrete reward on a scale of 0 to 5 accordingly. While DeepSeek is not expected to be able to generate high-quality reasoning traces for driving, it is significantly easier for it to evaluate AR1’s reasoning; this is known as the <em>generation-verification gap.</em></li><li><strong>Enforcing reasoning-action consistency:</strong> the authors extract <em>meta-actions </em>(accelerate, steer, go straight, …) from the CoC dataset using rule-based systems. If those meta-actions correspond to those mentioned in the reasoning traces, the model receives an additional reward of 1, otherwise 0.</li><li><strong>Trajectory Quality:</strong> a trajectory reward measures the L2 distance between the predicted and expert trajectory and penalises trajectories leading to collisions and high-magnitude jerks.</li></ol><p>During post-training, AR1 generates multiple parallel rollouts and collects rewards <strong>r_i</strong> based on the three reward signals above. These rewards are then used to compute the GRPO loss [4]. GRPO computes the advantage of each rollout relative to the group average. This critic-free approach (as opposed to other RL algorithms like PPO) stabilises training by rewarding reasoning paths that outperform their counterparts for the same input, rather than relying on an arbitrary absolute score.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*RearKXfrUPvYGmNgY1ThFg.png" /><figcaption>GRPO loss, made by the author</figcaption></figure><p>All you need to understand about this objective is that it aims to maximise the probability of trajectories (the log term) with a high advantage (the softmax term) relative to others.
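To make the group-relative advantage concrete, here is a minimal sketch (my own illustration, not the paper’s code) of how an advantage can be computed from one group of rollout rewards:

```python
import numpy as np

def group_relative_advantage(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Advantage of each rollout relative to its group (GRPO-style).

    `rewards` holds one scalar reward per parallel rollout of the same input.
    """
    # Centre on the group mean: rollouts better than average get a positive advantage.
    centred = rewards - rewards.mean()
    # Normalise by the group standard deviation for scale invariance.
    return centred / (rewards.std() + eps)

# Example: 4 parallel rollouts scored by the reward signals above.
adv = group_relative_advantage(np.array([3.0, 5.0, 1.0, 3.0]))
```

Rollouts scoring above the group mean receive a positive advantage and are up-weighted by the objective; those below the mean are pushed down, without any learned value function providing an absolute baseline.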
To avoid losing vision-language priors from the VLM and the driving knowledge obtained during SFT, the objective is regularised by a KL divergence between the current policy and the reference (the policy obtained at the end of SFT).</p><h3>Evaluation</h3><p>The evaluation protocol includes 4 sections: Open-loop trajectory prediction, closed-loop simulation, ablation studies and on-vehicle road tests. While the fact that AR1 was deployed in real-world scenarios is impressive, the open and closed-loop results are somewhat opaque <em>in my opinion</em>; the main reason being that they were obtained on Nvidia datasets (open-loop: PhysicalAI-AV dataset; closed-loop: AlpaSim) released at the same time as the model. This implies a lack of baselines to contextualise AR1’s performances.</p><p>For instance, the closed-loop results only feature AR1 and a non-reasoning baseline on 75 scenarios. While AR1 outperforms the baseline on all measured metrics, it often does so by a single percent on average and with a much larger variance than the baseline.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*PS5eCTcxGOhxl0v2.png" /><figcaption>Closed-loop results for AR1 and a non-reasoning baseline, source: [1]</figcaption></figure><p>For this reason, I would advise taking these results with a grain of salt before other frontier architectures are evaluated in AlpaSim.</p><h3>Conclusion</h3><p>Despite the lack of contextualised results, AR1 and the accompanying datasets remain an impressive engineering achievement and a good indication of where autonomous driving is headed: end-to-end models inheriting world knowledge from massive VLMs trained on embodied tasks.</p><p>However, the collection of causally-grounded datasets required to enable chain of causation reasoning requires significant investments and labeling efforts, which limits reproducibility <em>until these datasets are made public.
</em>In my next article, I’ll contrast the AR1 approach with another state-of-the-art model which entirely dispenses with textual labels and instead trains VLMs to act and reason in a latent space.</p><h3>Thank you for reading this far!</h3><p>If you found this article useful, please consider <strong>sharing it</strong>; it genuinely helps support the time and effort that goes into producing this work. As always, feel free to reach out if you have questions, thoughts, or ideas for follow-ups. If you’d like to support my independent research and writing, feel free to <a href="https://buymeacoffee.com/ryanpegoud"><strong>buy me a coffee</strong></a> 😉</p><p>Until next time! 👋</p><h3>Sources</h3><ul><li>[1] <a href="http://arxiv.org/abs/2511.00088">Alpamayo-R1: Bridging Reasoning and Action Prediction for Generalizable Autonomous Driving in the Long Tail</a></li><li>[2] <a href="https://arxiv.org/pdf/2010.11929">(Vision Transformer) An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale</a></li><li>[3] <a href="https://arxiv.org/pdf/2512.10947">(Flex) Towards Efficient and Effective Multi-Camera Encoding for End-to-End Driving</a></li><li>[4] <a href="https://arxiv.org/pdf/2402.03300">(GRPO loss) DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models</a></li></ul><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=5b287216634c" width="1" height="1" alt=""><hr><p><a href="https://medium.com/data-science-collective/alpamayor1-large-causal-reasoning-models-for-autonomous-driving-5b287216634c">AlpamayoR1: Large Causal Reasoning Models for Autonomous Driving</a> was originally published in <a href="https://medium.com/data-science-collective">Data Science Collective</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Learning Triton One Kernel at a Time: Softmax]]></title>
            <link>https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-softmax-78e8ba73734d?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/78e8ba73734d</guid>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[pytorch]]></category>
            <category><![CDATA[gpu]]></category>
            <category><![CDATA[triton]]></category>
            <category><![CDATA[kernel]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Sat, 27 Dec 2025 09:42:06 GMT</pubDate>
            <atom:updated>2026-01-14T12:57:51.186Z</atom:updated>
            <content:encoded><![CDATA[<h4>All you need to know to write a fast, readable and PyTorch-ready softmax kernel!</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*CLgcyvQSMqdXVbGD" /><figcaption>Photo by <a href="https://unsplash.com/@nanadua96?utm_source=medium&amp;utm_medium=referral">Nana Dua</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>In the <a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd">previous article</a> of this series, we covered a ubiquitous operation in all fields of computer science: matrix multiplication. It is heavily used in neural networks to compute the activation of linear layers. However, activations on their own are difficult to interpret, since their values and statistics (mean, variance, min-max amplitude) can vary wildly from layer to layer. This is one of the reasons why we use activation functions, for example the logistic function (a.k.a. sigmoid) which projects any real number into the [0, 1] range.</p><p>The softmax function, also known as the normalised exponential function, is a multi-dimensional generalisation of the sigmoid. It converts a vector of raw scores (logits) into a <strong>probability distribution</strong> over <strong>M</strong> classes. We can interpret it as a <strong>weighted average</strong> that behaves as a <strong>smooth function</strong> and can be conveniently <strong>differentiated</strong>. It is a crucial component of dot-product attention, language modeling and multinomial logistic regression.</p><p>In this article, we’ll cover:<br>1. Implementing an efficient softmax kernel in Triton.<br>2. Implementing the backward pass (autograd).<br>3.
Optimisation: cache modifiers and auto-tuning.</p><p>If you aren’t familiar with Triton yet, refer to the previous articles!</p><ul><li><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-vector-addition-5f57e9d2f3e1">Learning Triton One Kernel At a Time: Vector Addition</a></li><li><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd">Learning Triton One Kernel at a Time: Matrix Multiplication</a></li></ul><p><em>Disclaimer: all the illustrations and animations are made by the author unless specified otherwise.</em></p><h3>Definition</h3><p>The softmax is defined as follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Zo1Sz5GQlSri1VDl6eFbzA.png" /></figure><p>The normalisation ensures that the vector sums to <strong>1</strong>, so that it can be interpreted as a valid probability distribution.</p><p>Note that this formulation of the softmax is highly sensitive to <strong>numerical overflow</strong>. Recall that the maximum value a standard <strong>float16</strong> can represent is <strong>65,504</strong>, which is roughly <strong>exp(11)</strong>. This means that any input value greater than ~11 will result in exp(z_i) exceeding the representable range, leading to <strong>overflow</strong>.</p><p>A common trick to mitigate this issue is to subtract the maximum value of the input vector from every element, such that the new maximum is <strong>0</strong> before exponentiation and <strong>1</strong> after.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*g8Fcxuur00vKJ0uuBNK2uQ.png" /></figure><h3>Naive Implementation</h3><p>As you can see, computing the softmax involves <strong>two reduction operations</strong>, a <strong>max</strong> and a <strong>sum</strong>. A naive algorithm requires three separate passes over the input vector:
First to compute the maximum, then the sum, and finally the normalised outputs.</p><p>Here’s what a naive Numpy implementation looks like:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/dd6af75416fdde06105c1e6af4eda67e/href">https://medium.com/media/dd6af75416fdde06105c1e6af4eda67e/href</a></iframe><p>A recurrent theme in this Triton series is minimising high-latency <strong>global memory access</strong>. Our current Numpy implementation requires three separate memory reads of the full input vector, which is highly inefficient.</p><h3>Online Softmax</h3><p>Fortunately, we can use a clever trick, known as the <strong>online softmax</strong>, to fuse the max and sum steps, reducing the number of memory reads to <strong>2</strong>. <br>First, we define the sum of exponentials recursively. In the following set of equalities, m_i refers to the maximum over x up to the <strong><em>i</em></strong>-th index.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*RnIQU8TqusIWdE6SokP5gw.png" /></figure><p>This equality allows us to compute the sum of exponentials <strong>iteratively</strong> using the maximum value <strong>so far</strong>. We can leverage it to fuse the first and second loops in the naive implementation and compute the maximum and sum of exponentials iteratively.</p><p>Our algorithm becomes:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*qilMIPTUuRTW-V90PPLHvw.png" /></figure><p>This is easily translated to Numpy:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/545c7470df17f08a4a50df16aeae48c1/href">https://medium.com/media/545c7470df17f08a4a50df16aeae48c1/href</a></iframe><p>Now that we understand the main principles behind the softmax, we’ll implement it in Triton, starting with the simple, single-block version and building up to the online, multi-block formulation.
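For reference, the online algorithm we are about to port can be sketched in NumPy (my own sketch, mirroring the update rule derived above):

```python
import numpy as np

def online_softmax(x: np.ndarray) -> np.ndarray:
    """Online softmax over a 1D vector: fuses the max and sum reductions."""
    m = -np.inf  # running maximum m_i
    d = 0.0      # running sum of exponentials, rescaled whenever m changes
    for xi in x:
        m_new = max(m, xi)
        # Rescale the previous sum by exp(m - m_new) before adding the new term.
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    # Second (and final) pass: normalise with the global max and fused sum.
    return np.exp(x - m) / d
```

The fused loop reads the input once for the max-and-sum pass and once for normalisation: two memory passes instead of three, which is exactly the saving the multi-block Triton kernel will exploit tile by tile.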
In the end, we want our kernel to behave like a PyTorch module and be compatible with autograd.</p><p>Unfortunately, from PyTorch’s point of view, Triton kernels behave like black boxes: the operations they perform are not traced by autograd. This requires us to implement the backward pass ourselves and explicitly specify how gradients should be computed. Let’s brush up on our beloved chain rule and derive the softmax gradient.</p><h3>Gradient</h3><p>Since the outputs of the softmax are strictly positive, we can use the <strong>logarithmic derivative</strong> to make the derivation of the gradient easier. Here, we take the derivative of the <strong>log</strong> of the output and apply the chain rule:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*PScWEwQ25s873DfP3ueBuw.png" /></figure><p>From there, we rearrange the terms and follow these steps:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*hQROp_vwcm_Ia-Ovr_fq3A.png" /></figure><p>Now assume that we have some upstream gradient, for example generated by a loss function <strong><em>L</em></strong> (e.g. a cross-entropy loss). We get the following expression of the gradient:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*3OjKQ_Y_pNG5AuQ6Cxj08g.png" /></figure><p>The simplification of the left term in <strong>(9)</strong> is due to the fact that δ_ij will only be equal to <strong>1</strong> for the <strong><em>i</em></strong>-th element, collapsing the sum over <strong>j</strong> to a single term.</p><h3>Triton Implementation</h3><h4>Single Block Softmax</h4><p>Now that we worked through the derivation of the gradient, we can write the forward and backward softmax kernels. First, let’s focus on the PyTorch wrapper to understand how the single block implementation works at a high level. Given a 2D input tensor, the forward and backward kernels are going to process all rows in parallel. 
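Before diving into the Triton code, here is a plain-PyTorch sketch of this wrapper pattern (my own illustration, with the kernel launches replaced by equivalent torch ops); it also doubles as a CPU check of the gradient formula derived in equation (10):

```python
import torch

class SoftmaxSketch(torch.autograd.Function):
    """Plain-PyTorch stand-in for the Triton-backed wrapper (illustration only)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Numerically stable row-wise softmax (max-subtraction trick).
        z = x - x.max(dim=1, keepdim=True).values
        e = torch.exp(z)
        y = e / e.sum(dim=1, keepdim=True)
        ctx.save_for_backward(y)  # cache the outputs for the backward pass
        return y

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (y,) = ctx.saved_tensors
        # Equation (10): dL/dz_i = y_i * (g_i - sum_j g_j * y_j), row-wise.
        return y * (grad_out - (grad_out * y).sum(dim=1, keepdim=True))
```

Calling SoftmaxSketch.apply(x) on a (rows, cols) tensor matches torch.softmax(x, dim=1) in both the forward and backward pass; the Triton version keeps the same structure but swaps the torch ops for kernel launches.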
<br>For simplicity, we’ll define the BLOCK_SIZE to be large enough to handle all columns at once. Specifically, we’ll set it as the next power of 2 greater than or equal to the number of columns, as required by Triton. <br>Then, we’ll define our `grid` to be the number of rows (it could potentially also handle a batch dimension).</p><p>The PyTorch wrapper for our SoftmaxSingleBlock is a class inheriting from torch.autograd.Function that implements forward and backward. Both methods take a ctx argument, which we’ll use to cache the softmax outputs during the forward pass and reuse them during the backward pass.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/2bdeb4343db215d190ab46912a739ddf/href">https://medium.com/media/2bdeb4343db215d190ab46912a739ddf/href</a></iframe><p>Both kernels are pretty straightforward: we start by loading the row inputs using the same syntax as in my previous <a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-vector-addition-5f57e9d2f3e1"><strong>vector addition</strong></a><strong> </strong>article. Notice that BLOCK_SIZE and num_warps are computed using a <a href="https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43">calculate_settings</a> function.
This function comes from the <a href="https://unsloth.ai/"><strong>Unsloth</strong></a> library and was reused in other kernel libraries such as <a href="https://github.com/linkedin/Liger-Kernel"><strong>LigerKernel</strong></a> (which the kernels in this article are loosely based on); it provides heuristics to tune both variables:</p><pre>def calculate_settings(n: int) -&gt; tuple[int, int]:<br>    MAX_FUSED_SIZE = 65536 # maximum grid dimension on Nvidia GPUs<br>    BLOCK_SIZE = next_power_of_2(n)<br>    if BLOCK_SIZE &gt; MAX_FUSED_SIZE:<br>        # we remove this assertion in this article<br>        raise RuntimeError(<br>            f&quot;Cannot launch Triton kernel since n = {n} exceeds &quot;<br>            f&quot;the maximum CUDA blocksize = {MAX_FUSED_SIZE}.&quot;<br>        )<br>    num_warps = 4<br>    if BLOCK_SIZE &gt;= 32768:<br>        num_warps = 32<br>    elif BLOCK_SIZE &gt;= 8192:<br>        num_warps = 16<br>    elif BLOCK_SIZE &gt;= 2048:<br>        num_warps = 8<br>    return BLOCK_SIZE, num_warps</pre><p>Then, we implement the regular softmax for the forward pass and equation <strong>(10)</strong> for the backward pass. The only novelty here compared to previous articles is the use of cache modifiers, which tell the compiler how to cache and evict data. For now, we’ll only focus on three cache modifiers:</p><ul><li><strong>.ca</strong> (<strong>Cache at all levels</strong>): Tells the compiler to load the data in both L1 and L2 cache, suggesting that it might be reused soon.
This modifier should be used when the data is small enough to fit into L1 (~128–192KB per SM on an A100) and will likely be accessed repeatedly.</li><li><strong>.cs</strong> (<strong>Streaming</strong>): Treats the data as <strong>streaming</strong>: it will be used once and then discarded to free up space in L1.</li><li><strong>.wb</strong> (<strong>Write-back</strong>): Normal cached write; the data will remain in the cache hierarchy, which is useful if the output may be reused.</li></ul><p>In the following kernels, we’ll use the .ca modifier for loads since we perform multiple operations on the loaded data. For storing, we’ll use .cs in the forward pass, since the outputs won’t be immediately reused, and .wb in the backward pass, since in the context of autograd (i.e. the chain rule) gradient outputs will be consumed by downstream kernels.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/c1c5a1b9386d7b398a941def96172140/href">https://medium.com/media/c1c5a1b9386d7b398a941def96172140/href</a></iframe><h4>Multi-Block Softmax</h4><p>Now, let’s take a look at the online formulation of the softmax. In this section, we implement a multi-block variant of the previous kernel. This version will use BLOCK_SIZE &lt; n_cols; in other words, we’ll only load a tile with BLOCK_SIZE elements at a time, similar to how we handled tiled GEMM in the <a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd">last tutorial</a>. Now you might ask <strong>“how do we select the block size?”</strong>.</p><p>This is a great occasion to introduce Triton’s autotune utility. Provided with a list of configurations, autotune will perform a grid-search to determine and cache the best configuration for a specific input shape. This process is repeated every time a new input shape is passed to the kernel.
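As a sketch of how the decorator is wired up (a hypothetical configuration of mine, with the kernel body elided; the actual kernels follow below):

```python
import triton
import triton.language as tl

# Hypothetical illustration: autotune benchmarks each (BLOCK_SIZE, num_warps)
# pair on first launch, then caches the winner per distinct value of n_cols.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": bs}, num_warps=nw)
        for bs in (256, 1024, 4096)
        for nw in (4, 8)
    ],
    key=["n_cols"],  # re-tune whenever the number of columns changes
)
@triton.jit
def softmax_fwd_mb(x_ptr, y_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    ...  # tiled online-softmax body goes here
```

Note that BLOCK_SIZE is not passed at the call site: autotune injects it from the winning config, which is why it appears as a constexpr parameter.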
<br>Here, we perform a grid search over the block size and number of warps using the following utility function:</p><pre>import triton<br>from itertools import product<br><br># --- Multi Block Tuning ---<br>BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]<br>NUM_WARPS = [2, 4, 8, 16]<br><br><br>def get_autotune_config(<br>    block_sizes: list[int], num_warps: list[int]<br>) -&gt; list[triton.Config]:<br>    return [<br>        triton.Config(kwargs={&quot;BLOCK_SIZE&quot;: bs}, num_warps=nw)<br>        for (bs, nw) in list(product(block_sizes, num_warps))<br>    ]</pre><p>We can now decorate our multi-block kernels with autotune and pass the list of configs; key=”n_cols” indicates that the optimal config is dependent on the number of columns of the input.<br>The implementation of these kernels is conceptually very close to the online softmax we covered before; the main difference is that we iterate over tiles (not over single elements like in Numpy), which requires some adjustments. For instance, we add a sum over the tile in the d update and the backward kernel now requires two iterations as well.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/afbb63ac1deb7ad9b28a617d941abcfd/href">https://medium.com/media/afbb63ac1deb7ad9b28a617d941abcfd/href</a></iframe><h3>Testing and Benchmarking</h3><p>We can now execute a forward and backward pass with both kernels and ensure they match the PyTorch baselines:</p><pre>import torch<br>from copy import deepcopy<br><br>def validate_kernel(kernel_fn: callable) -&gt; None:<br>    device = &quot;cuda:0&quot; if torch.cuda.is_available() else &quot;cpu&quot;<br>    torch.random.manual_seed(0)<br><br>    # Generate inputs<br>    x = torch.randn((256, 512), device=device) # triton input<br>    x.requires_grad = True<br>    xt = deepcopy(x) # torch input<br><br>    triton_output = kernel_fn(x)<br>    torch_output = torch.softmax(xt, dim=1)<br>    torch.testing.assert_close(triton_output, torch_output) # test fwd kernel<br><br>    # Setup fake 
labels<br>    y = torch.zeros_like(x)<br>    inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))<br>    y[inds] = 1<br><br>    # Define loss and run backward pass<br>    loss_fn = torch.nn.CrossEntropyLoss()<br>    loss = loss_fn(torch_output, y)<br>    loss.backward()<br><br>    # Save gradient tensor for later<br>    torch_xgrad = xt.grad.detach().clone()<br>    triton_loss = loss_fn(triton_output, y)<br>    triton_loss.backward()<br>    torch.testing.assert_close(x.grad, torch_xgrad) # test grad outputs<br><br>validate_kernel(softmax_sb)<br>validate_kernel(softmax_mb)</pre><p>Finally, we benchmark our implementation against the PyTorch baseline using the following snippet:</p><pre># --- Source: Triton softmax tutorial ---<br>@triton.testing.perf_report(<br>    triton.testing.Benchmark(<br>        x_names=[&quot;N&quot;],  # argument names to use as an x-axis for the plot<br>        x_vals=[<br>            128 * i for i in range(2, 100)<br>        ],  # different possible values for `x_name`<br>        line_arg=&quot;provider&quot;,  # argument name whose value corresponds to a different line in the plot<br>        line_vals=[<br>            &quot;triton_single_block&quot;,<br>            &quot;triton_multi_block&quot;,<br>            &quot;torch&quot;,<br>        ],  # possible values for `line_arg``<br>        line_names=[<br>            &quot;Triton_single_block&quot;,<br>            &quot;Triton_multi_block&quot;,<br>            &quot;Torch&quot;,<br>        ],  # label name for the lines<br>        styles=[(&quot;blue&quot;, &quot;-&quot;), (&quot;green&quot;, &quot;-&quot;), (&quot;red&quot;, &quot;-&quot;)],<br>        ylabel=&quot;GB/s&quot;,  # label name for the y-axis<br>        plot_name=&quot;softmax-performance&quot;,  # name for the plot. 
Used also as a file name for saving the plot.<br>        args={&quot;M&quot;: 4096},  # values for function arguments not in `x_names` and `y_name`<br>    )<br>)<br>def benchmark(M, N, provider):<br>    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)<br>    stream = getattr(torch, DEVICE.type).Stream()<br>    getattr(torch, DEVICE.type).set_stream(stream)<br>    if provider == &quot;torch&quot;:<br>        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))<br>    if provider == &quot;triton_single_block&quot;:<br>        torch.cuda.synchronize()<br>        ms = triton.testing.do_bench(lambda: softmax_sb(x))<br>        torch.cuda.synchronize()<br>    if provider == &quot;triton_multi_block&quot;:<br>        torch.cuda.synchronize()<br>        ms = triton.testing.do_bench(lambda: softmax_mb(x))<br>        torch.cuda.synchronize()<br>    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)<br>    return gbps(ms)<br><br><br>benchmark.run(show_plots=True, print_data=True)</pre><p>Good news! Our single-block kernel consistently outperforms the PyTorch baseline while the multi-block variant falls off for inputs with more than 6k columns:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*M0kgF6TB7CUJY_Pcwy8JFw.png" /></figure><p>Considering larger inputs, we can make several observations:</p><ol><li>The multi-block kernel eventually stabilises around 900GB/s of throughput, surpassing the PyTorch baseline for inputs with more than 30k columns.</li><li>Interestingly, it seems like the multi-block variant will dominate for inputs with more than 60k columns.</li><li>Even though we exceed the maximum block size with the single-block variant, the kernel still runs smoothly. This is because Triton automatically manages the block size under the hood. <br>When n_cols is larger than the hardware limit, Triton will break down the input and iterate over it. 
However, this seems to be slower than the multi-block approach.</li></ol><p>To go further, we could combine both approaches behind a single wrapper that explicitly selects the optimal kernel based on the input size. This way, we would benefit from the high performance of the single-block kernel for small inputs and the higher throughput of the multi-block variant for inputs with more than 60k columns.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*lixMZTxFPwd_xZEubiSPJQ.png" /></figure><p>This concludes the third episode of this Triton series. Thanks again for your support!</p><p>In the next article, we’ll leverage the online softmax formulation in the context of <strong>Flash Attention</strong>.</p><p>Until next time! 👋</p><h3>Resources:</h3><ul><li><a href="https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/softmax.py"><strong>LigerKernel Softmax Implementation</strong></a></li><li><a href="https://medium.com/data-science/derivative-of-the-softmax-function-and-the-categorical-cross-entropy-loss-ffceefc081d1"><strong>Softmax Gradient derivation by Thomas Kurbiel</strong></a></li><li><a href="https://medium.com/@hugo.rosenkranz/gpu-kernel-optimization-softmax-part-2-43ce9f8019e8"><strong>GPU kernel optimization: Softmax — Part 2 by Hugo Rosenkranz-costa</strong></a> (Cuda &amp; Triton kernels with more emphasis on profiling and hardware optimisation)</li><li><a href="https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf"><strong>From online softmax to FlashAttention by Zihao Ye</strong></a></li></ul><hr><p><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-softmax-78e8ba73734d">Learning Triton One Kernel at a Time: Softmax</a> was originally published in <a href="https://medium.com/data-science-collective">Data Science Collective</a> on 
Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Learning Triton One Kernel at a Time: Matrix Multiplication]]></title>
            <link>https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/44851b4146dd</guid>
            <category><![CDATA[pytorch]]></category>
            <category><![CDATA[matrix-multiplication]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[triton]]></category>
            <category><![CDATA[gpu]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Fri, 14 Nov 2025 11:24:36 GMT</pubDate>
            <atom:updated>2025-12-22T11:59:55.563Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*Abp-FrfqgqehMRUx" /><figcaption>Photo by <a href="https://unsplash.com/@lucaskphoto?utm_source=medium&amp;utm_medium=referral">Lucas Kepner</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>Matrix multiplication is undoubtedly the most common operation performed by GPUs. It is the fundamental building block of linear algebra and shows up across a wide spectrum of fields such as graphics, physics simulations and scientific computing while being ubiquitous in machine learning.</p><p>In today’s article, we’ll break down the conceptual implementation of general matrix-matrix multiplication (GEMM) while introducing several optimisation concepts such as tiling and memory coalescing. Finally, we’ll implement GEMM in Triton!</p><p><em>This article is the second of a series on Triton and GPU kernels. If you are not familiar with Triton or need a refresher on GPU basics, check out the previous article! All the code showcased in this article is available on </em><a href="https://github.com/RPegoud/Triton-Kernels"><em>GitHub</em></a><em>.</em></p><p><a href="https://towardsdatascience.com/learning-triton-one-kernel-at-a-time-vector-addition/">Learning Triton One Kernel At a Time: Vector Addition | Towards Data Science</a></p><p><em>Disclaimer: all the following figures and animations were made by the author unless stated otherwise.</em></p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Oho4e5iI5h_p7K6lOQjHSw.gif" /><figcaption>Parallel Tiled GEMM, as we’ll implement in this article!</figcaption></figure><h3>Naive GEMM</h3><p>Let’s start simple: we want to multiply two matrices X and Y with shapes (M,N) and (N,K) respectively. 
The output matrix Z=X@Y will therefore have shape (M,K).</p><p>This operation involves computing the dot products of all pairs of rows and columns in X and Y respectively. A straightforward NumPy implementation might look something like this:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/e77f6c75d006a84ded3aa31b3362bf02/href">https://medium.com/media/e77f6c75d006a84ded3aa31b3362bf02/href</a></iframe><p>While easy to write, read and understand, this implementation is highly inefficient in terms of memory access and caching. As mentioned in the first article of this series, a fundamental aspect of GPU optimisation is <strong>minimising data transfers</strong>.</p><p>However, our current implementation starts by loading a row from X, iteratively loads all K columns of Y, computes their dot product and repeats the process for every row in X. This results in a total of M(K+1) loading operations.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/800/1*G0i-4NVQFPNi6C7QVmSgiQ.gif" /><figcaption>Naive Matrix Multiplication, purple and blue tiles represent the vectors involved in dot products at every time step and green cells the computed output values.</figcaption></figure><p>As seen in the animation, the memory access pattern is wasteful, as every column of Y is loaded M times. As an analogy: this is like running to the grocery store (global memory) every time you need a new ingredient for a dish instead of preparing all the ingredients on your kitchen counter (shared memory). Ideally, we would like to minimise the number of times each chunk of data is loaded and maximise its reusability once loaded. 
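</p><p>To make the naive access pattern concrete, here is a minimal NumPy sketch of the triple loop described above (an illustrative stand-in for the embedded gist, not necessarily the exact code):</p>

```python
import numpy as np

def naive_matmul(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Naive GEMM: Z = X @ Y for X of shape (M, N) and Y of shape (N, K)."""
    M, N = X.shape
    N2, K = Y.shape
    assert N == N2, "inner dimensions must match"
    Z = np.zeros((M, K), dtype=X.dtype)
    for m in range(M):      # one row of X per outer iteration...
        for k in range(K):  # ...but all K columns of Y are reloaded for every row
            Z[m, k] = np.dot(X[m, :], Y[:, k])
    return Z
```

<p>Counting the loads makes the inefficiency explicit: each of the M rows triggers one row load plus K column loads, which is exactly the M(K+1) figure above.</p><p>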
This leaves us with two main axes of optimisation:</p><ol><li>How can we improve the access pattern to minimise redundant loads?</li><li>How much data can we load at once, and where should it be stored on the GPU?</li></ol><h3><strong>Tiled GEMM</strong></h3><p>As mentioned previously, the naive approach to GEMM results in many redundant loads, which induces unnecessary overhead. Ideally, we’d like to load each segment of data only once and perform all the operations in which they are used before dropping them from memory.</p><p>An elegant approach to this problem is <strong>tiling</strong>, which involves dividing large matrices into smaller <em>“tiles”</em> or sub-matrices. Consider two matrices X and Y with shapes (4,6) and (6,4) respectively; X@Y results in a matrix Z with shape (4,4).</p><p>In order to compute the first element of Z, Z[0,0], we need to compute the dot product between the first row of X and the first column of Y: Z[0,0] = dot(X[0, :], Y[:, 0]). We can also break down the dot product into smaller chunks, for instance in groups of 3 elements: Z[0,0] = dot(X[0,0:3], Y[0:3, 0]) + dot(X[0,3:6], Y[3:6, 0]).</p><p>Alternatively, we can expand this approach to two dimensions and compute an entire (2,2) block of Z at a time: Z[0:2, 0:2] = dot(X[0:2, 0:2], Y[0:2, 0:2]) + dot(X[0:2, 2:4], Y[2:4, 0:2]) + dot(X[0:2, 4:6], Y[4:6, 0:2]).</p><p>Here’s a visual representation of tiled matrix multiplication:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/800/0*4ZCyn04i9iKSnfB6.gif" /><figcaption>Tiled Matrix Multiplication. The computation is split into several “tiles” of X and Y (highlighted in pale blue and purple), each containing several blocks (dark blue and purple). In each block, we compute dot products (green cells in X and Y). 
These dot products are accumulated across the blocks of a tile to compute the output values in Z (the accumulation is represented by colors from orange to green).</figcaption></figure><p>The above animation illustrates how data is reused in tiled GEMM. For each 2x2 block in X and Y, we compute 4 dot products, which results in a (2,2) output matrix in Z. Since each tile contains 3 blocks, we need to accumulate 3 of these matrices to compute the final (2,2) output in Z. This accumulation is represented by colored cells in Z.</p><p>In the kitchen analogy, this is like fetching ingredients from the store and preparing them on the kitchen counter (i.e. small shared memory), reusing them several times before going back to the store.</p><p>Importantly, reusing loaded data over multiple steps allows this approach to drastically reduce the number of load operations. For (2,2) blocks, each X row and Y column is used in two dot products. Therefore, we’re performing <strong>twice as many operations </strong>with each block of loaded data, roughly <strong>halving</strong> the number of load operations! Note that this generalises to larger blocks as well, using a (32,32) block would reduce the number of loads by a factor of around 32.</p><p>Now you’re probably wondering “how large can these blocks be”? To answer this question, let’s recall how memory is managed in modern GPUs.</p><h3>GPU Memory Hierarchy</h3><p>We distinguish four main types of memory in Nvidia GPUs. Here, we take the example of an A100:</p><ul><li><strong>Registers: </strong>The fastest and smallest type of memory on the GPU, residing directly within each Streaming Multiprocessor (SM). On the A100, each SM provides <strong>256 KB of register file</strong> space (65,536 × 32-bit registers), distributed among its threads. Each thread gets its own private 32-bit registers for storing temporary variables and intermediate results, avoiding memory traffic altogether. 
However, register usage per thread directly affects occupancy, as using too many registers per thread limits how many threads can run concurrently.</li><li><strong>L1/Shared Memory</strong>: On an A100, each SM has 192KB of SRAM that can be <a href="https://docs.nvidia.com/cuda/ampere-tuning-guide/index.html?utm_source=chatgpt.com#unified-shared-memory-l1-texture-cache">flexibly configured</a> as either a hardware-managed <strong>L1 cache</strong> or a programmer-managed <strong>shared memory</strong>. For performance-critical kernels like matrix multiplication, we explicitly use this space as shared memory to stage data tiles close to the compute units, bypassing the L1 cache entirely. This gives us fine-grained control over data reuse.</li><li><strong>L2 cache</strong>: This cache is slower than L1 but much larger, with around <strong>40 MB shared across all SMs</strong> on the A100. It serves as a global cache for both data and instructions, reducing the number of accesses to high-latency HBM memory. The L2 cache is <strong>coherent across SMs</strong>, meaning that updates from one SM are visible to others, enabling synchronisation between thread blocks. Its bandwidth can reach several terabytes per second, acting as a buffer between the fast on-chip SRAM and the slower HBM.</li><li><strong>High Bandwidth Memory (HBM)</strong>: This is the device memory, it has a capacity of either 40GB or 80GB depending on the A100 model. It provides <strong>extremely high bandwidth</strong> (up to <strong>2 TB/s on the 80 GB variant) </strong>but with <strong>much</strong> <strong>higher latency</strong> than on-chip caches. HBM is where large tensors, model weights, and datasets reside during execution. Since accessing HBM is expensive, efficient kernels aim to <strong>minimise data movement</strong> and <strong>maximise on-chip data reuse</strong> via registers and shared memory.</li></ul><p>As you can see, the memory hierarchy generally trades off capacity with latency. 
Therefore, maximising performance boils down to loading data from HBM into shared memory efficiently and reusing it as much as possible.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*EGGxz6dEolx6F5IM_hbYkQ.png" /><figcaption>GPU Memory Hierarchy, from fastest/smallest (top) to slowest/largest (bottom).</figcaption></figure><p>Choosing our block size is critical. We want blocks to be large enough to create a lot of parallel work, but small enough that their data fits in the SM’s shared memory and registers. A BLOCK_SIZE of <strong>64</strong> is a common starting point because it&#39;s a multiple of the <strong>warp size</strong> (32 threads), ensuring full hardware utilisation.</p><h3>Parallel Tiled GEMM</h3><p>With these considerations in mind, a natural follow-up to our tiled GEMM is to parallelise the computation of each pair of tiles over several thread blocks, as depicted in the following animation.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/800/1*qqui6dSE8jWNbXR_OUIWFA.gif" /><figcaption>Parallel Tiled Matrix Multiplication. The iteration over tiles is replaced by a parallel operation over multiple thread blocks.</figcaption></figure><h3>Memory Coalescing</h3><p>Before writing tiled GEMM in Triton, we need to consider one last detail: <strong>memory coalescing</strong>, a technique that allows optimal use of global memory bandwidth. Memory coalescing is achieved when <strong>consecutive threads in a warp access consecutive memory addresses</strong>. Imagine a librarian needing to fetch books for a client: if all books are side-by-side on a shelf, they can grab them all at once. In contrast, if all books are lying on different shelves, they’ll have to grab them one by one, which takes significantly longer.</p><p>To understand how this applies to our case, note that matrices are stored linearly in memory; in other words, a (2,2) matrix is stored as a sequence of 4 consecutive elements. 
Frameworks like PyTorch adopt a <strong>row-major</strong> layout, meaning that elements of a matrix are <strong>per-row contiguous in memory</strong>. For instance, elements of our (2,2) matrix would be stored as follows: [(0,0), (0,1), (1,0), (1,1)]. Notice that elements of the same row are <em>contiguous </em>(touching) while consecutive elements of the same column are separated by a <em>stride</em> of 2 (one element lies between them).</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ed2RfXNQMZ5_VbdC4_6GPA.png" /><figcaption>PyTorch stores matrices in row-major layout. Elements of a row are contiguous in memory while elements of a column are strided.</figcaption></figure><p>This implies that we can load rows using <strong>coalesced loads</strong>, but columns do <strong>not</strong> satisfy this condition. However, we need to access columns of Y to compute dot products. In order to maximise performance, a good practice is to transpose Y so that we iterate on its rows rather than its columns.</p><p>However, transposing Y isn’t enough to modify its layout in memory. As mentioned previously, PyTorch stores matrices in a flat array. Each matrix dimension is associated with a stride attribute, denoting the jump necessary to go from one element to the next one along this dimension. For instance, a (10,10) matrix would have strides=(10,1). Indeed, starting from element [0,0], element [1,0] is 10 memory slots (i.e. one row) away, whereas element [0,1] is adjacent.</p><p>When transposing a tensor, PyTorch doesn’t modify the layout in memory but simply recomputes the strides. 
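</p><p>We can check this behaviour directly in PyTorch; the following sketch (with an arbitrary (10,6) matrix) shows that .T only swaps the strides, while .contiguous() actually rewrites the data:</p>

```python
import torch

Y = torch.randn(10, 6)
print(Y.stride())          # (6, 1): row-major, rows are contiguous
Yt = Y.T                   # transpose: no data is moved...
print(Yt.stride())         # (1, 6): ...only the strides are swapped
print(Yt.is_contiguous())  # False: the underlying layout is unchanged
Yc = Yt.contiguous()       # materialises the transposed layout in memory
print(Yc.stride())         # (10, 1): rows of Y.T are now contiguous
```

<p>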
In order to make the transpose effective from a memory standpoint, we need to call Y.T.contiguous().</p><p>These are the steps required to load columns of Y efficiently; however, we’ll need to transpose the loaded blocks within the kernel to perform the dot product properly: z_block = tl.dot(X_block, Y_block.T).</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*c-fP5mDPMOSWQdxLwjhXgw.png" /><figcaption>Representation of Y, Y.T and Y.T.contiguous() in their block representation and memory layout. The transpose operation changes the behaviour of the matrix but doesn’t modify its memory layout. This is why we need to add .contiguous() to enable coalesced reads on rows.</figcaption></figure><h3>Triton Implementation</h3><p>From here on, we first describe the kernel without memory coalescing to simplify the logic and pointer arithmetic before summarising the changes required to make the load operations coalesced on Y columns.</p><p>Let’s start by focusing on the PyTorch wrapper around the kernel. We need to read M, N, K from the input matrices and compute their strides since these constants will be useful later in the kernel. Then, we define the BLOCK_SIZE and declare the grid.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/2a750f12b2cc0a271e45fe730a684d25/href">https://medium.com/media/2a750f12b2cc0a271e45fe730a684d25/href</a></iframe><p>Now let’s dive into the actual kernel code. We’re going to make use of Triton’s make_block_ptr utility, which simplifies the pointer arithmetic. We create one block pointer per matrix and pass the matrix shape, its strides, and the size of the block as inputs. Additionally, we specify the offset, the coordinate of the top-left element in the current block. 
For X, this corresponds to (m_idx * BLOCK_SIZE, 0) where m_idx is the index of the current block along the M dimension.</p><p>From there, we define z_acc, a zero matrix that will receive the partial dot-products as we iterate through tiles. We now iterate through the shared dimension N, loading blocks of size (BLOCK_SIZE, BLOCK_SIZE), and accumulate their dot products in z_acc. We then move the block pointers along the shared dimension by using .advance.</p><p>You might have noticed that when loading data, we use boundary_check and padding_option instead of mask and other as in the previous article. These arguments are specific to the use of block pointers and specify which axes to check for out-of-bound operations (here (0,1) for x and y) and how to treat those invalid values. Here we set them to zero to be ignored in the dot product.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/15d95a3631676f7d2551882ad5d6ddb7/href">https://medium.com/media/15d95a3631676f7d2551882ad5d6ddb7/href</a></iframe><p>We can now take a look at the performance of this kernel by using the following function:</p><pre>def bench(fn: callable, x: torch.Tensor, y: torch.Tensor, repeat: int):<br>  flops = []<br>  med_latency = []<br><br>  for _ in tqdm(range(repeat), desc=f&quot;Benchmarking {fn.__name__}&quot;):<br>    latency_ms = triton.testing.do_bench(<br>      lambda: fn(x, y),<br>      quantiles=[0.5], # get the median latency<br>      return_mode=&quot;all&quot;,<br>      )<br>    n_flops = 2 * M * N * K # matmul roughly requires 2*M*N*K operations<br>    tflops = n_flops / (latency_ms / 1e3) / 1e12<br><br>    med_latency.append(latency_ms)<br>    flops.append(tflops)<br><br>  flops = np.array(flops)<br>  med_latency = np.array(med_latency)<br>  print(f&quot;Absolute Error: {torch.sum(torch.abs(X@Y - fn(x, y)))}&quot;)<br>  print(f&quot;Median Latency: {med_latency.mean():.4f} ± {med_latency.std():.3f} ms&quot;)<br>  
print(f&quot;Throughput: {flops.mean():.4f} ± {flops.std():.3f} TeraFLOPS&quot;)<br><br><br>M = 8192<br>N = 6144<br>K = 4096<br><br>X = torch.randn((M, N), device=&quot;cuda&quot;, dtype=torch.float32)<br>Y = torch.randn((N, K), device=&quot;cuda&quot;, dtype=torch.float32)<br><br>bench(block_matmul, X, Y, repeat=10)</pre><p>We get the following outputs (using a T4 GPU on Colab):</p><pre>Absolute Error: 0.0 # the kernel outputs the correct result!<br>Median Latency: 130.7831 ± 1.794 ms<br>Throughput: 3.1533 ± 0.043 TeraFLOPS</pre><p>Now let’s review the changes required for coalesced loads on Y: we mainly need to flip the shape, strides and offsets when defining the block pointer for Y. Additionally, we update the block pointer to move along the column dimension (previously row dimension). The full code for this implementation is available on <a href="https://github.com/RPegoud/Triton-Kernels">GitHub</a>.</p><pre>@triton.jit<br>def coalesced_block_matmul_kernel(<br>    X_ptr, X_m_stride, X_n_stride,<br>    Y_ptr, Y_k_stride, Y_n_stride,<br>    Z_ptr, Z_m_stride, Z_k_stride,<br>    M, N, K,<br>    BLOCK_SIZE: tl.constexpr,<br>):<br>    ... <br>    y_block_ptr = tl.make_block_ptr(<br>        base=Y_ptr,<br>        # flip the shape, strides and offsets to match Y.T<br>        shape=(K, N),<br>        strides=(Y_k_stride, Y_n_stride), <br>        offsets=(k_idx * BLOCK_SIZE, 0),<br>        block_shape=(BLOCK_SIZE, BLOCK_SIZE),<br>        order=(0, 1),<br>    )<br>    ...<br><br>    for _ in range(0, N, BLOCK_SIZE):<br>        ... 
# loads<br>        z_acc += tl.dot(x, y.T)  # transpose Y back for dot product<br>        x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE))<br>        # advance the block pointer along columns of Y.T (i.e rows of Y)<br>        y_block_ptr = tl.advance(y_block_ptr, offsets=(0, BLOCK_SIZE))<br><br>    tl.store(pointer=z_block_ptr, value=z_acc, boundary_check=(0, 1))<br><br>def coalesced_block_matmul(X, Y):<br>    Y = Y.T.contiguous()  # Y is now (K,N)<br>    M, N = X.shape<br>    K, _ = Y.shape<br>    Z = torch.empty((M, K), device=&quot;cuda&quot;)<br><br>    x_stride_m, x_stride_n = X.stride()<br>    y_stride_k, y_stride_n = Y.stride()<br>    z_stride_m, z_stride_k = Z.stride()<br><br>    ...  # define BLOCK_SIZE and grid<br><br>    coalesced_block_matmul_kernel[grid](<br>        X, x_stride_m, x_stride_n,<br>        Y, y_stride_k, y_stride_n,<br>        Z, z_stride_m, z_stride_k,<br>        M, N, K,<br>        BLOCK_SIZE,<br>    )<br><br>    return Z</pre><p>Here are the results of our benchmark for the kernel with coalesced loads for Y:</p><pre>Absolute Error: 0.0 # Again, the kernel is correct!<br>Median Latency: 261.9420 ± 0.858 ms<br>Throughput: 1.5741 ± 0.005 TeraFLOPS</pre><p>Surprisingly, the throughput of this second kernel is only half of what we obtained with the first one, despite improving the efficiency of load operations 🤔</p><p>A quick inspection using nsight (Nvidia’s kernel profiler, more on that in a future article) reveals that the transpose operation within the kernel creates a “traffic jam”. Specifically, the transpose creates <strong>bank conflicts</strong>, causing threads to remain idle most of the time. Notably, the warp scheduler has no eligible warp to dispatch 87.6% of the time as they are waiting for the bank conflict to resolve. 
Additionally, the report reads:</p><pre>----------------------- ----------- --------------<br>Metric Name             Metric Unit   Metric Value<br>----------------------- ----------- --------------<br>...<br>DRAM Throughput                   %           8.20<br>Compute (SM) Throughput           %          21.14<br>...</pre><p>This indicates that the kernel is <strong>latency bound</strong> (i.e. neither memory nor compute bound, refer to the previous article for more details). In contrast, the first kernel is <strong>compute bound </strong>(i.e. increasing compute will improve performance) since the compute throughput is high compared to the DRAM throughput.</p><pre>----------------------- ----------- --------------<br>Metric Name             Metric Unit   Metric Value<br>----------------------- ----------- --------------<br>...<br>DRAM Throughput                   %          29.35<br>Compute (SM) Throughput           %          74.39<br>...</pre><h3>Conclusion</h3><p>This experiment highlights the importance of profiling and empirical validation. Even well-intentioned optimisations like coalescing memory accesses can introduce new bottlenecks if not evaluated carefully. The first kernel, though simpler, was compute-bound and better matched the hardware characteristics.</p><p>In the next articles of this series, we’ll implement a softmax kernel, paying particular attention to integrating Triton with PyTorch&#39;s autograd and profiling kernels using Nsight.</p><p>Until next time! 
👋</p><h4>Useful Resources</h4><ul><li><a href="https://github.com/RPegoud/Triton-Kernels">Complete implementation</a></li><li><a href="https://www.cs.sfu.ca/~ashriram/Courses/CS7ARCH/hw/hw4.html">Introduction to GEMM and Assignment</a></li><li><a href="https://en.wikipedia.org/wiki/Ampere_(microarchitecture)#cite_note-15">Nvidia Ampere Architecture (A100 specs)</a></li></ul><hr><p><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-matrix-multiplication-44851b4146dd">Learning Triton One Kernel at a Time: Matrix Multiplication</a> was originally published in <a href="https://medium.com/data-science-collective">Data Science Collective</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Learning Triton One Kernel At a Time: Vector Addition]]></title>
            <link>https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-vector-addition-5f57e9d2f3e1?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/5f57e9d2f3e1</guid>
            <category><![CDATA[pytorch]]></category>
            <category><![CDATA[triton]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[kernel]]></category>
            <category><![CDATA[gpu]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Wed, 29 Oct 2025 14:20:54 GMT</pubDate>
            <atom:updated>2025-10-29T15:58:17.877Z</atom:updated>
            <content:encoded><![CDATA[<p>The basics of GPU programming, optimisation, and your first Triton kernel!</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*WKrqisTspGEK5NHY" /><figcaption>Photo by <a href="https://unsplash.com/@omilaev?utm_source=medium&amp;utm_medium=referral">Igor Omilaev</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>In the era of billion-parameter models, a little optimisation goes a long way. Models like <strong>GPT-4</strong> cost <strong>more than $100 million to train</strong>, which makes a <strong>1% efficiency gain</strong> <strong>worth<em> over a million dollars</em></strong>. A powerful way to optimise the efficiency of machine learning models is by writing some of their components <strong>directly on the GPU</strong>. Now if you’re anything like me, the simple mention of CUDA kernels is enough to send chills down your spine, as they are notoriously complex to write and debug. <br>Fortunately, <strong>OpenAI</strong> released <strong>Triton</strong> in 2021, a new language and compiler abstracting away much of CUDA’s complexity and allowing less experienced practitioners to write performant kernels. 
A notable example is <strong>Unsloth</strong>, an LLM-training service that promises <strong>30x faster training</strong> with <strong>60% less memory usage</strong>, all thanks to <strong>replacing layers written in PyTorch with Triton kernels</strong>.<br>In this tutorial series, we’ll learn the basics of GPU architecture and how to implement high-performance Triton kernels!</p><h3>GPU Architecture Basics</h3><p>In this section, we’ll go through the very basics of (<em>Nvidia</em>) GPUs to get us started and write our first Triton kernel by the end of this article.<br>Starting from the smallest software unit, we can describe the hierarchy of execution units as follows:</p><ul><li><strong>Threads</strong>: The smallest <strong>unit of work</strong>, they run the user-defined kernel code.</li><li><strong>Warps</strong>: The smallest <strong>scheduling unit</strong>, they are always composed of 32 parallel threads, each with their own instruction address counter and register state. Threads in a warp <strong>start together</strong> but are <strong>free to branch</strong> and <strong>execute independently</strong>.</li><li><strong>Thread Blocks</strong>: Group of warps, where all threads can <strong>cooperate via shared memory</strong> and sync barriers. It is required that thread blocks can execute <strong>independently</strong> and in any order, in parallel or sequentially. This independence allows thread blocks to be <strong>scheduled in any order across any number of cores</strong>, so that GPU programs scale efficiently with the number of cores. We can synchronise the threads within a block at specific points in the kernel if needed, for example to synchronise memory access.</li><li><strong>Streaming Multiprocessor (SM)</strong>: A unit in charge of <strong>executing many warps in parallel</strong>, it owns shared memory and an L1 cache (holds the most recent global-memory lines that the SM has accessed). 
An SM has a dedicated <strong>warp scheduler</strong> that pulls warps from the thread blocks that are ready to run.</li></ul><p>On the hardware side, the smallest unit of work is a <strong>CUDA core</strong>, the physical <strong>Arithmetic Logic Unit</strong> (ALU) which performs <strong>arithmetic operations for a thread</strong> (or parts of it).</p><p>To summarise this section with an analogy, we could see <strong>CUDA cores</strong> as <strong>individual workers</strong>, while a <strong>warp</strong> is a <strong>squad of 32 workers</strong> given the same instruction at once. They may or may not execute this task the same way (branching) and can potentially complete it at a different point in time (independence). A <strong>thread block</strong> is composed of <strong>several squads sharing a common workspace</strong> (i.e. they have shared memory); workers from all squads in the workspace can wait for each other to get lunch at the same time. A <strong>streaming multiprocessor </strong>is a <strong>factory floor with many squads working together and sharing tools and storage</strong>. Finally, the <strong>GPU</strong> is a <strong>whole plant</strong>, with many floors.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*kl-G_MdVLicCPZan12ciAA.png" /><figcaption>Hierarchy of an Nvidia GPU architecture.</figcaption></figure><h3>Optimisation Basics</h3><p>When optimising deep learning models, we are juggling with three main components:</p><ol><li><strong>Compute</strong>: Time spent by the GPU computing floating point operations (FLOPS).</li><li><strong>Memory</strong>: Time spent transferring tensors within a GPU.</li><li><strong>Overhead</strong>: All other operations (Python interpreter, PyTorch dispatch, …).</li></ol><p>Keeping those components in mind helps figure out the right way to resolve a bottleneck. For instance, increasing compute (e.g. using a more powerful GPU) doesn’t help if most of the time is spent doing memory transfers. 
Ideally though, most of the time should be spent on compute, more precisely on matrix multiplications, the operation GPUs are optimised for. <br>This implies minimising the cost paid to move data around, either from the CPU to the GPU (“<strong>data transfer cost</strong>”), from one node to the other (“<strong>network cost</strong>”) or from CUDA global memory (<strong>DRAM</strong>, cheap but slow) to CUDA shared memory (<strong>SRAM</strong>, expensive but fastest on-device memory). The latter is called the <strong>bandwidth cost</strong> and is going to be our main focus for now. Common strategies to reduce bandwidth costs include:</p><ol><li><strong>Reusing</strong> data loaded in shared memory for multiple steps. A prime example of this is tiled matrix multiplication, which we’ll cover in a future post.</li><li><strong>Fusing</strong> multiple operations in a single kernel (since every kernel launch implies moving data from DRAM to SRAM); for instance, we can fuse a matrix multiplication with an activation function. Generally, <strong>operator fusion</strong> can provide massive performance increases since it prevents a lot of global memory reads/writes, and any two operators present an opportunity for fusion.</li></ol><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*LUnH4U23PfO-nBE9XNUCJA.png" /><figcaption>Matrix multiplication followed by a ReLU activation without operator fusion.</figcaption></figure><p>In this example, we perform a matrix multiplication x@W and store the result in an intermediate variable a. We then apply a ReLU to a and store the result in a variable y. This requires the GPU to read x and W from global memory, write the result to a, read a again and finally write to y.
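</p><p>To make the extra traffic concrete, here is a minimal NumPy sketch of the unfused version (shapes are illustrative assumptions; on a GPU, each of the two steps would be a separate kernel launch):</p>

```python
import numpy as np

def unfused_linear_relu(x, W):
    # Kernel 1: matmul -> the intermediate `a` is written to global memory.
    a = x @ W
    # Kernel 2: ReLU -> `a` is read back from global memory, `y` written out.
    y = np.maximum(a, 0.0)
    # Bytes of intermediate traffic that a fused kernel would avoid.
    return y, a.nbytes

x = np.random.randn(1024, 4096).astype(np.float32)
W = np.random.randn(4096, 4096).astype(np.float32)
y, intermediate_bytes = unfused_linear_relu(x, W)
print(intermediate_bytes)  # 16777216 bytes (~17 MB) written, then read again
```

<p>In this configuration, the intermediate tensor a alone costs roughly 17 MB of writes plus 17 MB of reads, only to be discarded afterwards.</p><p>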
Instead, operator fusion would allow us to halve the amount of reads and writes to global memory by performing the matrix multiplication and applying the ReLU in a single kernel.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ODwkR4DHvLH4rUNJPN1kog.png" /><figcaption>Fused matrix multiplication and ReLU activation.</figcaption></figure><h3>Triton</h3><p>We’ll now write our first Triton kernel, a simple vector addition. First, let’s walk through how this operation is broken down and executed on a GPU.</p><p>Suppose we want to sum the entries of two vectors X and Y, each with 7 elements (n_elements=7). <br>We’ll instruct the GPU to tackle this problem in chunks of 3 elements at a time (BLOCK_SIZE=3). Therefore, to cover all 7 elements of the input vectors, the GPU will launch 3 parallel “programs”, independent instances of our kernel, each with a unique program ID, pid:</p><ul><li>Program 0 is assigned elements 0, 1, 2.</li><li>Program 1 is assigned elements 3, 4, 5.</li><li>Program 2 is assigned element 6.</li></ul><p>These programs then write their results back to a vector Z stored in global memory.<br>An important detail is that a kernel doesn’t receive an entire vector X; instead, it receives a <strong>pointer to the memory address of the first element</strong>, X[0]. In order to access the actual values of X, we need to load them from global memory manually. <br>We can access the data for each block by using the program ID: block_start = pid * BLOCK_SIZE. From there, we can get the remaining element addresses for that block by computing offsets = block_start + range(0, BLOCK_SIZE) and load them into memory.<br>However, remember that program 2 is only assigned element 6, but its offsets are [6, 7, 8].
To avoid any indexing error, Triton lets us define a <strong>mask</strong> to identify valid target elements, here mask = offsets &lt; n_elements.<br>We can now safely load X and Y and add them together before writing the result back to an output variable Z in global memory in a similar way.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*PfCGXA_ZOkpTtjgkq5IGdg.png" /><figcaption>Per-block vector indexing. Slices of X, Y and Z are sent to independent thread blocks, each indexed by a unique ID.</figcaption></figure><p>Let’s take a closer look at the code, here’s the Triton kernel:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/fe9f4ffa10c444708170314dca784eff/href">https://medium.com/media/fe9f4ffa10c444708170314dca784eff/href</a></iframe><p>Let’s break down some of the Triton-specific syntax:</p><ul><li>First, a Triton kernel is always decorated by @triton.jit.</li><li>Second, some arguments need to be declared as static, meaning that they are known at compile-time. This is required for BLOCK_SIZE and is achieved by adding the tl.constexpr type annotation. The remaining arguments are ordinary runtime values and are left unannotated.</li><li>We use tl.program_id to access the ID of the current block; tl.arange behaves similarly to Numpy’s np.arange.</li><li>Loading and storing variables is achieved by calling tl.load and tl.store with arrays of pointers. Notice that there is no return statement; this role is delegated to tl.store.</li></ul><p>To use our kernel, we now need to write a <strong>PyTorch-level wrapper</strong> that provides memory pointers and defines a <strong>kernel grid</strong>. Generally, the kernel grid is a 1D, 2D or 3D tuple containing the <strong>number of thread blocks allocated to the kernel along each axis</strong>.
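</p><p>Since the embedded snippet may not render everywhere, here is a plain NumPy simulation of the same block/offset/mask logic (an illustration of the control flow only, not the actual Triton kernel; tl.program_id, tl.arange, tl.load and tl.store are replaced by array operations):</p>

```python
import numpy as np

def add_kernel_sim(x, y, z, pid, BLOCK_SIZE):
    # One "program": handles BLOCK_SIZE contiguous elements.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + np.arange(BLOCK_SIZE)  # analogue of tl.arange
    mask = offsets < x.shape[0]                    # guard out-of-bounds offsets
    idx = offsets[mask]
    z[idx] = x[idx] + y[idx]                       # analogue of tl.load/tl.store

n_elements, BLOCK_SIZE = 7, 3
x = np.arange(n_elements, dtype=np.float32)
y = np.ones(n_elements, dtype=np.float32)
z = np.empty_like(x)
num_programs = -(-n_elements // BLOCK_SIZE)        # ceil(7 / 3) = 3 programs
for pid in range(num_programs):
    add_kernel_sim(x, y, z, pid, BLOCK_SIZE)
print(z)  # [1. 2. 3. 4. 5. 6. 7.]
```

<p>The Python loop runs the three programs sequentially; on the GPU, they execute in parallel on independent thread blocks.</p><p>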
In our previous example, we used a 1D grid of 3 thread blocks: grid = (3, ).<br>To handle varying array sizes, we default to grid = (ceil(n_elements / BLOCK_SIZE), ).</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/08d85e2be936f48da9feba1f04b26123/href">https://medium.com/media/08d85e2be936f48da9feba1f04b26123/href</a></iframe><p>Here are two final notes about the wrapper:</p><ol><li>You might have noticed that grid is defined as a lambda function. This allows Triton to compute the number of thread blocks to launch <strong>at launch time</strong>. Therefore, we compute the grid size based on the block size which is stored in meta, a dictionary of compile-time constants that are exposed to the kernel.</li><li>When calling the kernel, the value of output will be modified in-place, so we don’t need to reassign output = add_kernel[…].</li></ol><p>We can conclude this tutorial by verifying that our kernel works properly:</p><pre>x, y = torch.randn((2, 2048), device=&quot;cuda&quot;)<br><br>print(add(x, y))<br>&gt;&gt; tensor([ 1.8022, 0.6780, 2.8261, ..., 1.5445, 0.2563, -0.1846], device=&#39;cuda:0&#39;)<br><br>abs_difference = torch.abs((x + y) - add(x, y))<br>print(f&quot;Max absolute difference: {torch.max(abs_difference)}&quot;)<br>&gt;&gt; Max absolute difference: 0.0</pre><p>That’s it for this introduction, in following posts we’ll learn to implement more interesting kernels such as tiled matrix multiplication and see how to integrate Triton kernels in PyTorch models using autograd.</p><p>Until next time! 
👋</p><h3>References and Useful Resources</h3><ul><li><a href="https://en.wikipedia.org/wiki/GPT-4#:~:text=Sam%20Altman%20stated%20that%20the,51">Cost of training</a></li><li><a href="https://unsloth.ai/introducing">Unsloth kernels</a></li><li><a href="https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py">Triton tutorial: Vector Addition</a></li><li><a href="https://horace.io/brrr_intro.html">Making Deep Learning Go Brrrr From First Principles</a></li></ul><hr><p><a href="https://medium.com/data-science-collective/learning-triton-one-kernel-at-a-time-vector-addition-5f57e9d2f3e1">Learning Triton One Kernel At a Time: Vector Addition</a> was originally published in <a href="https://medium.com/data-science-collective">Data Science Collective</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Rainbow: The Colorful Evolution of Deep Q-Networks]]></title>
            <link>https://medium.com/data-science/rainbow-the-colorful-evolution-of-deep-q-networks-37e662ab99b2?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/37e662ab99b2</guid>
            <category><![CDATA[reinforcement-learning]]></category>
            <category><![CDATA[dqn]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[deep-dives]]></category>
            <category><![CDATA[deep-learning]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Fri, 12 Jul 2024 21:22:16 GMT</pubDate>
            <atom:updated>2024-07-13T18:04:34.566Z</atom:updated>
            <content:encoded><![CDATA[<h4>Everything you need to assemble the DQN Megazord in JAX.</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*X6V_SKO4X1mk7rKH8WaRQw.png" /><figcaption>“The Rainbow Megazord”, Dall-E 3</figcaption></figure><p>In 2013, the introduction of Deep Q-Networks (DQN) by <em>Mnih et al.</em>[1]<em> </em>marked the first breakthrough in Deep Reinforcement Learning, surpassing expert human players in three Atari games. Over the years, several variants of DQN were published, each improving on specific weaknesses of the original algorithm.</p><p>In 2017, <em>Hessel et al.</em>[2]<em> </em>made the best out of the DQN palette by combining 6 of its powerful variants, crafting what could be called the DQN Megazord: Rainbow.</p><p>In this article, we’ll break down the individual components that make up Rainbow, while reviewing their JAX implementations in the <a href="https://github.com/EdanToledo/Stoix"><strong>Stoix library.</strong></a></p><h3>DQN</h3><p>The fundamental building block of Rainbow is DQN, an extension of Q-learning using a neural network with parameters <strong>θ</strong> to approximate the Q-function (i.e. action-value function). In particular, DQN uses convolutional layers to extract features from images and a linear layer to produce a scalar estimate of the Q-value.</p><p>During training, the network parameterized by <strong>θ</strong>, referred to as the <em>“online network”</em> is used to select actions while the <em>“target network”</em> parameterized by <strong>θ-</strong> is a delayed copy of the online network used to provide stable targets. 
This way, the targets are not dependent on the parameters being updated.<br>Additionally, DQN uses a replay buffer <strong><em>D</em></strong> to sample past transitions (observation, action, reward, next observation, and done flag tuples) to train on at fixed intervals.</p><p>At each iteration <strong><em>i</em></strong>, DQN samples a transition <strong><em>j </em></strong>and takes a gradient step on the following loss:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*zgDgFjtgvGsGUb1kYyaSzA.png" /><figcaption>DQN loss function, all images are made by the author, unless specified otherwise</figcaption></figure><p>This loss aims at minimizing the expectation of the squared temporal-difference (TD) error.</p><p>Note that DQN is an <strong>off-policy</strong> algorithm because it learns the optimal policy defined by the <strong>maximum Q-value</strong> term while following a different behavior policy, such as an epsilon-greedy policy.</p><p>Here’s the DQN algorithm in detail:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*8uRBmpYM16wcZEvqgZCQFg.png" /><figcaption>DQN algorithm</figcaption></figure><h4>DQN in practice</h4><p>As mentioned above, we’ll reference code snippets from the Stoix library to illustrate the core parts of DQN and Rainbow <em>(some of the code was slightly edited or commented for pedagogical purposes)</em>.</p><p>Let’s start with the neural network: Stoix lets us break down our model architecture into a pre-processor and a post-processor, referred to as <strong>torso</strong> and <strong>head</strong> respectively.
In the case of DQN, the torso would be a multi-layer perceptron (MLP) or convolutional neural network (CNN) and the head an epsilon greedy policy, both implemented as <a href="https://flax.readthedocs.io/en/latest/index.html"><strong>Flax </strong></a>modules:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/7b6c514c0bc0bfd45d845f5527f78421/href">https://medium.com/media/7b6c514c0bc0bfd45d845f5527f78421/href</a></iframe><p>Additionally, DQN uses the following loss (<em>note that Stoix follows the </em><a href="https://github.com/google-deepmind/rlax"><strong><em>Rlax</em></strong></a><strong><em> </em></strong><em>naming conventions, therefore tm1 is equivalent to timestep t in the above equations, while t refers to timestep t+1</em>):</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/1be20d13144ef4888865ff6d20fa80e8/href">https://medium.com/media/1be20d13144ef4888865ff6d20fa80e8/href</a></iframe><h4>The Rainbow blueprint</h4><p>Now that we have laid the foundations for DQN, we’ll review each part of the algorithm in more detail, while identifying potential weaknesses and how they are addressed by Rainbow.<br>In particular, we’ll cover:</p><ul><li>Double DQN and the overestimation bias</li><li>Dueling DQN and the state-value / advantage prediction</li><li>Distributional DQN and the return distribution</li><li>Multi-step learning</li><li>Noisy DQN and flexible exploration strategies</li><li>Prioritized Experience Replay and learning potential</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/742/1*nP0nGY7dtgM0zKbr5HT-HA.png" /><figcaption>The Rainbow Blueprint, Dall-E 3</figcaption></figure><h3>Double DQN</h3><ul><li><strong>Source:</strong> <a href="http://arxiv.org/abs/1509.06461"><em>Deep Reinforcement Learning with Double Q-learning</em></a><em> </em>[3]</li><li><strong>Improvement:</strong> Reduced overestimation bias</li></ul><h4>The 
overestimation bias</h4><p>One issue with the loss function used in vanilla DQN arises from the Q-target. Remember that we define the target as:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*s1dH__rS9nRQaypJ4061kg.png" /><figcaption>Objective in the DQN loss</figcaption></figure><p>This objective may lead to an <strong>overestimation bias</strong>. Indeed, as DQN uses bootstrapping (learning estimates from estimates), the max term may select overestimated values to update the Q-function, leading to overestimated Q-values.</p><p>As an example, consider the following figure:</p><ul><li>The Q-values predicted by the network are represented in blue.</li><li>The true Q-values are represented in purple.</li><li>The gap between the predictions and true values is represented by red arrows.</li></ul><p>In this case, action 0 has the highest predicted Q-value because of a large prediction error. This value will therefore be used to construct the target. <br>However, the action with the highest true value is action 2. 
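</p><p>We can reproduce this effect numerically: with unbiased noise on every estimate, taking a max over noisy Q-values is biased upward, whereas selecting the action with one independent set of estimates and evaluating it with another (the decoupling described next) is not. A small NumPy sketch, with purely illustrative values:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
true_q = np.ones(4)          # four equally good actions, true value 1.0
n = 100_000                  # number of simulated updates
# Two independent, unbiased noisy estimators of the same Q-values.
online = true_q + rng.normal(0.0, 0.5, size=(n, 4))
target = true_q + rng.normal(0.0, 0.5, size=(n, 4))

vanilla = target.max(axis=1).mean()           # max over noisy estimates
a_star = online.argmax(axis=1)                # select with one network...
double = target[np.arange(n), a_star].mean()  # ...evaluate with the other

print(vanilla)  # ~1.51: overestimates the true value of 1.0
print(double)   # ~1.00: independent noise cancels out on average
```

<p>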
This illustration shows how the max term in the target favors <strong>large positive estimation errors</strong>, inducing an overestimation bias.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ifcuiTXwna1NYAY6owLx8A.png" /><figcaption>Illustration of the overestimation bias.</figcaption></figure><h4>Decoupling action selection and evaluation</h4><p>To solve this problem, <em>Hasselt et al.</em> (2015)[3] propose a new target where the action is selected by the online network, while its value is estimated by the target network:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*bbFmTqWiHEB6GUJYl-cccA.png" /><figcaption>The Double DQN target</figcaption></figure><p>By decoupling action selection and evaluation, the estimation bias is significantly reduced, leading to better value estimates and improved performance.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*bXIdPN_xA4oz6unCyVsS0Q.png" /><figcaption>Double DQN provides stable and accurate value estimates, leading to improved performance. Source: Hasselt et al. 
(2015), Figure 3</figcaption></figure><h4>Double DQN in practice</h4><p>As expected, implementing Double DQN only requires us to modify the loss function:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/f214383ec1cd89468af2ec6fc067ca39/href">https://medium.com/media/f214383ec1cd89468af2ec6fc067ca39/href</a></iframe><h3>Dueling DQN</h3><ul><li><strong>Source:</strong> <a href="http://arxiv.org/abs/1511.06581"><em>Dueling Network Architectures for Deep Reinforcement Learning</em></a></li><li><strong>Improvement:</strong> Separation of the value and advantage computation</li></ul><h4>State value, Q-value, and advantage</h4><p>In RL, we use several functions to estimate the value of a given state, action, or sequence of actions from a given state:</p><ul><li><strong>State-value V(s): </strong>The state value corresponds to the expected return when starting in a given state <strong>s </strong>and following a policy <strong>π </strong>thereafter.</li><li><strong>Q-value Q(s, a): </strong>Similarly, the Q-value corresponds to the expected return when starting in a given state <strong>s</strong>, taking action<strong> a, </strong>and following a policy <strong>π </strong>thereafter.</li><li><strong>Advantage A(s, a): </strong>The advantage is defined as the difference between the Q-value and the state-value in a given state <strong>s </strong>for an action <strong>a</strong>. 
It represents the inherent value of action <strong>a </strong>in the current state.</li></ul><p>The following figure attempts to represent the differences between these value functions on a backup diagram (<em>note that the state value is weighted by the probability of taking each action under policy </em><strong>π</strong>).</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*eOAbQBykllUwckuo2FLa2g.png" /><figcaption>Visualization of the state value (in purple), state-action value (Q-function, in blue), and the advantage (in pink) on a backup diagram.</figcaption></figure><p>Usually, DQN estimates the Q-value directly, using a feed-forward neural network. This implies that DQN has to learn the Q-values for each action in each state independently.</p><h4>The dueling architecture</h4><p>Introduced by <em>Wang et al</em>.[4] in 2016, Dueling DQN uses a neural network with two separate streams of computation:</p><ul><li>The <strong>state value stream </strong>predicts the scalar value of a given state.</li><li>The <strong>advantage stream </strong>predicts the advantage of each action for a given state.</li></ul><p>This decoupling enables the <strong>independent estimation</strong> of the state value and advantages, which has several benefits. For instance, the network can learn state values without having to update the action values regularly. Additionally, it can better generalize to unseen actions in familiar states.<br>These improvements lead to more stable and faster convergence, especially in environments with many similar-valued actions.</p><p>In practice, a dueling network uses a <strong>common representation </strong>(i.e. a shared linear or convolutional layer) parameterized by parameters <strong>θ</strong> before splitting into two streams, consisting of linear layers with parameters <strong>α</strong> and <strong>β</strong> respectively.
The state value stream outputs a scalar value while the advantage stream returns a scalar value for each available action. <br>Adding the outputs of the two streams allows us to reconstruct the Q-value for each action as <strong>Q(s, a) = V(s) + A(s, a)</strong>.</p><p>An important detail is that the mean is usually subtracted from the advantages. Indeed, the advantages need to have<strong> zero mean</strong>; otherwise, the decomposition of Q into V and A would not be unique, making the problem ill-defined. With this constraint, <strong>V</strong> represents the value of the state while <strong>A</strong> represents how much better or worse each action is compared to the average action in that state.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*wqmeuL471NfDOQs3-BvaqQ.png" /><figcaption>Illustration of a dueling network</figcaption></figure><h4>Dueling Network in practice</h4><p>Here’s the Stoix implementation of a Q-network:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/91e7ff3ec86d5f10983adbe1061653ef/href">https://medium.com/media/91e7ff3ec86d5f10983adbe1061653ef/href</a></iframe><h3>Distributional DQN</h3><ul><li><strong>Source:</strong> <a href="http://arxiv.org/abs/1707.06887">A distributional perspective on Reinforcement Learning</a>[5]</li><li><strong>Improvement:</strong> Richer value estimates</li></ul><h4>The return distribution</h4><p>Most RL systems model the expectation of the return; however, a promising body of literature approaches RL from a distributional perspective.
In this setting, the goal becomes to model the <strong>return distribution</strong>, which allows us to consider other statistics than the mean.<br>In 2017, <em>Bellemare et al.</em>[5] published a distributional version of DQN called C51 predicting the return distribution for each action, reaching new state-of-the-art performances on the Atari benchmark.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/732/1*hequN6-JLnMm_mujmTQeBg.png" /><figcaption>Illustrated comparison between DQN and C51. Source [5&#39;]</figcaption></figure><p>Let’s take a step back and review the theory behind C51.<br>In traditional RL, we evaluate a policy using the <strong>Bellman Equation</strong>, which allows us to define the Q-function in a recursive form. Alternatively, we can use a distributional version of the Bellman equation, which accounts for randomness in the returns:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*8jPhPrsaWjGC6JO_V68E6A.png" /><figcaption>Standard and Distributional versions of the Bellman Equation</figcaption></figure><p>Here, <strong>ρ</strong> is the transition function.<br>The main difference between those functions is that <strong>Q</strong> <strong>is a numerical value</strong>, summing expectations over random variables. 
In contrast, <strong>Z is a random variable</strong>, summing the reward distribution and the discounted distribution of future returns.</p><p>The following illustration helps visualize how to derive <strong>Z </strong>from the distributional Bellman equation:</p><ul><li>Consider the distribution of returns <strong>Z</strong> at a given timestep and the transition operator <strong>Pπ.</strong> <strong>PπZ</strong> is the distribution of future returns <strong>Z(s’, a’)</strong>.</li><li>Multiplying this by the discount factor <strong>γ</strong> contracts the distribution towards 0 (as <strong>γ</strong> is less than 1).</li><li>Adding the reward distribution shifts the previous distribution by a set amount <em>(Note that the figure assumes a constant reward for simplicity. In practice, adding the reward distribution would shift but also modify the discounted return</em>).</li><li>Finally, the distribution is projected on a discrete support using an L2 projection operator <strong>Φ</strong>.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/519/1*wmgLoYfR6x28kLbFM6tc7g.png" /><figcaption>Illustration of the distributional Bellman equation. 
Source: [5]</figcaption></figure><p>This fixed support is a vector of <strong><em>N</em></strong> atoms separated by a constant gap within a set interval:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*-sJ8bVmI-rl4fEEew8vQPg.png" /><figcaption>Definition of the discrete support <strong>z</strong></figcaption></figure><p>At inference time, the Q-network returns an approximating distribution <strong>dt</strong> defined on this support with the probability mass <strong>pθ(st, at) </strong>on each atom <strong><em>i</em></strong> such that:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*lTM-SEKtX2Qzc5TN2y1FkQ.png" /><figcaption>Predicted return distribution</figcaption></figure><p>The goal is to update <strong>θ</strong> such that the distribution closely matches the true distribution of returns. To learn the probability masses, the target distribution is built using a <strong>distributional variant of Bellman’s optimality equation</strong>:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*4gLS1YpzdZpYieCR8r24DA.png" /><figcaption>Target return distribution</figcaption></figure><p>To be able to compare the distribution predicted by our neural network and the target distribution, we need to discretize the target distribution and project it on the same support <strong>z</strong>.</p><p>To this end, we use an L2 projection (<em>a projection onto </em><strong><em>z</em></strong><em> such that the difference between the original and projected distribution is minimized in terms of the L2 norm</em>):</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*mIMIRPJTZXgSsIlvbMUDqg.png" /><figcaption>L2 projection of the target distribution</figcaption></figure><p>Finally, we need to define a loss function that minimizes the difference between the two distributions. 
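</p><p>As an aside, the projection step <strong>Φ</strong> used to build that target can be sketched in a few lines of NumPy (a simplified illustration with a scalar reward and no terminal-state handling, not the Stoix implementation):</p>

```python
import numpy as np

def project_c51(reward, gamma, probs, atoms):
    # Project the shifted support (reward + gamma * z) back onto the fixed
    # atoms by splitting each probability mass between its two neighbours.
    v_min, v_max = atoms[0], atoms[-1]
    dz = atoms[1] - atoms[0]
    tz = np.clip(reward + gamma * atoms, v_min, v_max)
    b = (tz - v_min) / dz                         # fractional atom index
    lo, hi = np.floor(b).astype(int), np.ceil(b).astype(int)
    out = np.zeros_like(probs)
    np.add.at(out, lo, probs * (hi - b))          # mass to the lower neighbour
    np.add.at(out, hi, probs * (b - lo))          # mass to the upper neighbour
    same = lo == hi                               # b landed exactly on an atom
    np.add.at(out, lo[same], probs[same])
    return out

atoms = np.linspace(-10.0, 10.0, 51)              # the C51 support z
probs = np.full(51, 1 / 51)                       # a uniform predicted dist.
target = project_c51(reward=1.0, gamma=0.99, probs=probs, atoms=atoms)
print(target.sum())  # ~1.0: probability mass is conserved by the projection
```

<p>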
As we’re dealing with distributions, we can’t simply subtract the prediction from the target, as we did previously.</p><p>Instead, we minimize the Kullback-Leibler divergence between <strong>dt </strong>and <strong>d’t </strong>(in practice, this is implemented as a cross-entropy loss):</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*I9wouicVn0QBSf2l0o4Qlg.png" /><figcaption>KL divergence between the projected target and the predicted return distribution</figcaption></figure><p><em>For a more exhaustive description of Distributional DQN, you can refer to Massimiliano Tomassoli’s article[8] as well as Pascal Poupart’s video on the topic[11].</em></p><h4>C51 in practice</h4><p>The key components of C51 in Stoix are the Distributional head and the categorical loss, which uses double Q-learning by default as introduced previously. The choice of defining the C51 network as a head lets us use an MLP or a CNN torso interchangeably depending on the use case.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/743444ea903453114bdf41f2b6a6dfa3/href">https://medium.com/media/743444ea903453114bdf41f2b6a6dfa3/href</a></iframe><h3>Noisy DQN</h3><ul><li><strong>Source:</strong> <a href="http://arxiv.org/abs/1706.10295">Noisy Networks for Exploration</a>[6]</li><li><strong>Improvement:</strong> Learnable and state-dependent exploration mechanism</li></ul><h4>Noisy parameterization of Neural Networks</h4><p>Like many off-policy algorithms, DQN relies on an epsilon-greedy policy as its main exploration mechanism. Therefore, the algorithm will behave greedily with respect to the Q-values most of the time and select random actions with a predefined probability.</p><p><em>Fortunato et al.</em>[6] introduce NoisyNets as a more flexible alternative. NoisyNets are neural networks whose weights and biases are <strong>perturbed</strong> by a <strong>parametric function of Gaussian noise</strong>.
Similarly to an epsilon-greedy policy, such noise injects randomness in the agent’s action selection, thus encouraging exploration.</p><p>However, this noise is scaled and offset by <strong>learned parameters</strong>, allowing the level of noise to be adapted state-by-state. This way, the balance between exploration and exploitation is optimized <em>dynamically</em> during training. Eventually, the network may learn to ignore the noise, but will do so at <strong>different rates</strong> in <strong>different parts of the state space</strong>, leading to more flexible exploration.</p><p>A network parameterized by a vector of noisy parameters is defined as follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*t0NjLAh8LynCZEFBJsVp_A.png" /><figcaption>Neural Network parameterized by Noisy parameters</figcaption></figure><p>Therefore, a linear layer <strong>y = wx + b </strong>becomes:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*bv4upTlUK9X2Dke7UShW6w.png" /><figcaption>Noisy linear layer</figcaption></figure><p>For performance, the noise is generated at inference time using <strong>Factorized Gaussian Noise</strong>. For a linear layer with <strong>M </strong>inputs and <strong>N </strong>outputs, a noise matrix of shape (<strong>M x N</strong>) is generated as a combination of two noise vectors with size <strong>M</strong> and <strong>N</strong>. 
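</p><p>In code, the factorised scheme can be sketched as follows (a NumPy illustration; f(x) = sign(x)·√|x| is the scaling function used in the paper):</p>

```python
import numpy as np

def factorised_noise(m, n, rng):
    # f(x) = sign(x) * sqrt(|x|), applied element-wise to each factor vector.
    f = lambda x: np.sign(x) * np.sqrt(np.abs(x))
    eps_in = f(rng.normal(size=m))     # M samples for the input side
    eps_out = f(rng.normal(size=n))    # N samples for the output side
    # Outer product: an (M x N) weight-noise matrix from only M + N draws.
    # The output-side vector is reused as the bias noise.
    return np.outer(eps_in, eps_out), eps_out

rng = np.random.default_rng(0)
eps_w, eps_b = factorised_noise(512, 256, rng)
print(eps_w.shape, eps_b.shape)  # (512, 256) (256,)
```

<p>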
This method reduces the number of required random variables from <strong>M x N </strong>to <strong>M + N</strong>.<br>The noise matrix is defined as the outer product of the noise vectors, each scaled by a function <strong>f</strong>:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*zpNEw8WATncHMb9Iu_0CpA.png" /><figcaption>Noise generation using Factorised Gaussian Noise</figcaption></figure><h4>Improved exploration</h4><p>The improved exploration induced by noisy networks allows a wide range of algorithms, such as DQN, Dueling DQN and A3C, to benefit from improved performance with a reasonably small number of extra parameters.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*fd5h7tODJ4G6FUsJ6wQoGw.png" /><figcaption>NoisyNets improve the performance of several algorithms on the Atari benchmark. Source: [6]</figcaption></figure><h4>Noisy DQN in practice</h4><p>In Stoix, we implement a noisy layer as follows:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/abf0db362f8e8592d6bc71ab78ef1a61/href">https://medium.com/media/abf0db362f8e8592d6bc71ab78ef1a61/href</a></iframe><p><em>Note: All the linear layers in Rainbow are replaced with their noisy equivalent (see the </em><strong><em>“Assembling Rainbow”</em></strong><em> section for more details).</em></p><h3>Prioritized Experience Replay</h3><p><strong>Source:</strong> Prioritized Experience Replay[7]<br><strong>Improvement:</strong> Prioritization of experiences with higher learning potential</p><h4>Estimating the Learning Potential</h4><p>After taking an environment step, vanilla DQN uniformly samples a batch of experiences (also called <em>transitions</em>) from a replay buffer and performs a gradient descent step on this batch. Although this approach produces satisfying results, some specific experiences might be more valuable from a learning perspective than others.
Therefore, we could potentially speed up the training process by sampling such experiences more often.</p><p>This is precisely the idea explored in the Prioritized Experience Replay (PER) paper published by <em>Schaul et al.</em>[7] in 2016. However, the main question remains: how to approximate the <strong>expected learning potential</strong> of a transition?</p><blockquote>One idealized criterion would be the amount the RL agent can learn from a transition in its current state (expected learning progress). While this measure is not directly accessible, a reasonable proxy is the magnitude of a transition’s TD error δ, which indicates how ‘surprising’ or unexpected the transition is: specifically, how far the value is from its next-step bootstrap estimate (Andre et al., 1998).<br>Prioritized Experience Replay, Schaul et al. (2016)</blockquote><p>As a reminder, the TD error is defined as follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*w8IbPsTFpLgYfueR65Hnlw.png" /><figcaption>The temporal-difference error</figcaption></figure><p>This metric is a decent estimate of the learning potential of a specific transition, as a high TD error indicates a large difference between the predicted and actual outcomes, meaning that the agent would benefit from updating its beliefs.</p><p>However, it is worth noting that alternative prioritization metrics are still being studied. For instance, <em>Lahire et al.</em>[9] (2022) argue that the optimal sampling scheme is distributed according to the per-sample gradient norms:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*cymThHdJ312PnNgCBW8B4A.png" /><figcaption>Per-sample gradient norms</figcaption></figure><p>However, let’s continue with the TD error, as Rainbow uses this metric.</p><h4>Deriving Sampling Probabilities</h4><p>Once we have selected the prioritization criterion, we can derive the probabilities of sampling each transition from it. 
In Prioritized Experience Replay, two alternatives are showcased:</p><ul><li><strong>Proportional</strong>: Here, the priority of a transition is the absolute value of the associated TD error. A small positive constant is added to ensure that transitions are still revisited once their error reaches zero.</li><li><strong>Rank-based</strong>: In this mode, transitions are ranked in descending order according to their absolute TD error, and their priority is defined based on their rank. This option is expected to be more robust, as it is insensitive to outliers.</li></ul><p>The priorities are then raised to the power <strong>α</strong> and normalized to obtain sampling probabilities, with <strong>α</strong> a hyperparameter determining the degree of prioritization (<strong>α=0</strong> is the uniform case).</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*YfjpCvW_6DCwY9U8OK1Xhg.png" /><figcaption>Prioritization modes and probability normalization</figcaption></figure><h4>Importance sampling and bias annealing</h4><p>In RL, the estimation of the expected return relies on the assumption that updates are sampled from the same distribution as the expectation (i.e., the uniform distribution). However, PER introduces bias, as we now sample experiences according to their TD error.</p><p>To rectify this bias, we use <strong>importance sampling</strong>, a statistical method used to <em>estimate the properties of a distribution while sampling from a different distribution</em>.
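To make the prioritization step concrete, here is a minimal NumPy sketch of the proportional variant (the function name and the alpha/eps defaults are illustrative, not taken from the Stoix code):

```python
import numpy as np

def sampling_probs(td_errors, alpha=0.6, eps=1e-6):
    # Proportional variant: priority p_i = |delta_i| + eps,
    # where eps keeps zero-error transitions visitable
    priorities = np.abs(td_errors) + eps
    # Raise to the power alpha (alpha=0 recovers uniform sampling)
    scaled = priorities ** alpha
    # Normalize to obtain sampling probabilities
    return scaled / scaled.sum()
```

In the actual implementation, this bookkeeping is handled by Flashbax inside the prioritised buffer (see the "Prioritized Experience Replay in practice" section).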
Importance sampling re-weights samples so that the estimates remain unbiased and accurate.<br>Typically, the correcting weights are defined as the ratio of the two probabilities:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*dTDmKcMX0F6SfZ5ubQHQ3A.png" /><figcaption>Importance sampling ratio</figcaption></figure><p>In this case, the target distribution is the uniform distribution, where every transition has a probability of being sampled equal to 1/<strong>N</strong>, with <strong>N </strong>being the size of the replay buffer. <br>Therefore, the importance sampling coefficient in the context of PER is defined by:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*OLuzHB3-WVErREoxCtXTaw.png" /><figcaption>Importance sampling weight used in PER</figcaption></figure><p>With <strong>β</strong> a coefficient adjusting the amount of bias correction (the bias is fully corrected for <strong>β=1</strong>). Finally, the weights are normalized for stability:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*oLNMy56p3ZuLlpXBzxwk0Q.png" /><figcaption>Normalization of the importance sampling weights</figcaption></figure><p>To summarize, here’s the full algorithm for Prioritized Experience Replay (the update and training steps are identical to DQN):</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*JK7-l89NqTEamkpVp_IeYA.png" /><figcaption>The Prioritized Experience Replay algorithm</figcaption></figure><h4>Increased convergence speed with PER</h4><p>The following plots highlight the performance benefits of PER. Indeed, the proportional and rank-based prioritization mechanisms enable DQN to reach the same baseline performances roughly twice as fast on the Atari benchmark.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*NEcTZgrj6Z91q86c7InGCw.png" /><figcaption>Normalized maximum and average scores (in terms of Double DQN performance) on 57 Atari games. 
Source:[7]</figcaption></figure><h4>Prioritized Experience Replay in practice</h4><p>Stoix seamlessly integrates the <a href="https://github.com/instadeepai/flashbax">Flashbax</a> library which provides a variety of replay buffers. Here are the relevant code snippets used to instantiate the replay buffer, compute the sampling probabilities from the TD error, and update the buffer’s priorities:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/06179b5ff16f3056c2599ecc457e1d8d/href">https://medium.com/media/06179b5ff16f3056c2599ecc457e1d8d/href</a></iframe><h3>Multi-step Learning</h3><ul><li><strong>Source:</strong> <a href="https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf">Reinforcement Learning: an Introduction, chapter 7</a></li><li><strong>Improvement:</strong> Enhanced reward signal and sample efficiency, reduced variance</li></ul><p>Multi-step learning is an improvement on traditional one-step temporal difference learning which allows us to consider the return over <strong>n</strong> steps when building our targets. For instance, instead of considering the reward at the next timestep, we’ll consider the n-step truncated rewards (see the below equation). This process has several advantages, among which:</p><ul><li><strong>Immediate feedback:</strong> considering a larger time horizon allows the agent to learn the value of state-action pairs much faster, especially in environments where rewards are delayed and specific actions might not pay out immediately.</li><li><strong>Sample efficiency:</strong> Each update in multi-step learning incorporates information from multiple time steps, making each sample more informative. This improves sample efficiency, meaning the agent can learn more from fewer experiences.</li><li><strong>Balancing Bias and Variance: </strong>Multi-step methods offer a trade-off between bias and variance. 
One-step methods have high bias but low variance, while multi-step methods have lower bias but higher variance. By tuning the number of steps, one can find a balance that works best for the given environment.</li></ul><p>The multi-step distributional loss used in Rainbow is defined as:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ko2ZaPBNHqQnbyyE0glidQ.png" /><figcaption>Multi-step target return distribution</figcaption></figure><p>In practice, using n-step returns implies a few adjustments to our code:</p><ul><li>We now sample trajectories of <strong>n</strong> experiences, instead of individual experiences</li><li>The reward is replaced with the n-step discounted returns</li><li>The done flag is set to True if any of the <strong>n </strong>done flags is True</li><li>The next state <strong>s(t+1)</strong> is replaced by the last observation of the trajectory <strong>s(t+n)</strong></li></ul><h4>Multi-Step learning in practice</h4><p>Finally, we can reuse the categorical loss function used in C51 with these updated inputs:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/f667c9dc3fbefd3824f58ccb43677608/href">https://medium.com/media/f667c9dc3fbefd3824f58ccb43677608/href</a></iframe><h3>Assembling Rainbow</h3><p>Congratulations on making it this far! We now have a better understanding of all the moving pieces that constitute Rainbow. Here’s a summary of the Rainbow agent:</p><ul><li><strong>Neural Network Architecture:</strong><br> —<strong> Torso:</strong> A convolutional neural network (CNN) or multi-layer perceptron (MLP) base that creates embeddings for the head network.<br> — <strong>Head:</strong> Combines Dueling DQN and C51. The value stream outputs the state value distribution over atoms, while the advantage stream outputs the advantage distribution over actions and atoms.
These streams are aggregated, and Q-values are computed as the weighted sum of atom values and their respective probabilities. An action is selected using an epsilon-greedy policy.<br> —<strong> Noisy Layers: </strong>All linear layers are replaced with their noisy equivalents to aid in exploration.</li><li><strong>Loss Function:</strong> Uses a distributional loss modeling the n-step returns, where targets are computed using Double Q-learning.</li><li><strong>Replay Buffer: </strong>Employs a prioritization mechanism based on the TD error to improve learning efficiency.</li></ul><p>Here’s the network used for the Rainbow head:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/548655defb33bdb46b76dccff1ec4c25/href">https://medium.com/media/548655defb33bdb46b76dccff1ec4c25/href</a></iframe><h4>Performances and ablations</h4><p>To conclude this article, let’s take a closer look at Rainbow’s performances on the Atari benchmark, as well as the ablation study.<br>The following figure compares Rainbow with the other DQN baselines we studied. The measured metric is the median human-normalized score. In other words, the median human performance on Atari games is set to 100%, which enables us to quickly spot algorithms achieving a human level.</p><p>Three of the DQN baselines reach this level after 200 million frames:</p><ul><li><strong>Distributional DQN</strong></li><li><strong>Dueling DQN</strong></li><li><strong>Prioritized Double DQN</strong></li></ul><p>Interestingly, Rainbow reaches the same level after only 44 million frames, making it <strong>roughly 5 times more sample efficient</strong> than the best baselines. At the end of training, it exceeds <strong>200%</strong> of the median human-normalized score.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/958/1*Whxt4H20QUrV_RBDRt7Qcw.png" /><figcaption>Median human-normalized performance across 57 Atari games. Each line represents a DQN baseline. 
Source: [2]</figcaption></figure><p>This second figure shows the ablation study, which measures the performance of Rainbow when each of its components is removed in turn. These results allow us to make several observations:</p><ul><li>The three most crucial components of Rainbow are the distributional head, the use of multi-step learning, and the prioritization of the replay buffer.</li><li>Noisy layers contribute significantly to the overall performance. Using standard layers with an epsilon-greedy policy doesn’t allow the agent to reach the 200% score in 200 million frames.</li><li>Despite achieving strong performances on their own, the dueling structure and double Q-learning only provide marginal improvements in the context of Rainbow.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/955/1*OhtiBMXnBj8T9Fibuntmuw.png" /><figcaption>Median human-normalized performance across 57 Atari games. Each line represents an ablation of Rainbow. Source: [2]</figcaption></figure><p>Thank you very much for reading this article. I hope it provided you with a comprehensive introduction to Rainbow and its components. I highly advise reading through the <a href="https://github.com/EdanToledo/Stoix/blob/main/stoix/systems/q_learning/ff_rainbow.py"><strong>Stoix implementation of Rainbow</strong></a> for a more detailed description of the training process and the Rainbow architecture.</p><p>Until next time 👋</p><h3>Bibliography</h3><p>[1] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., &amp; Riedmiller, M. (2013). <a href="http://arxiv.org/abs/1312.5602"><strong><em>Playing Atari with Deep Reinforcement Learning</em></strong></a>, arXiv<br>[2] Hessel, M., Modayil, J., van Hasselt, H., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M., &amp; Silver, D. (2017).
<a href="http://arxiv.org/abs/1710.02298"><strong><em>Rainbow: Combining Improvements in Deep Reinforcement Learning</em></strong></a>, arXiv.<br>[3] van Hasselt, H., Guez, A., &amp; Silver, D. (2015). <a href="http://arxiv.org/abs/1509.06461"><strong><em>Deep Reinforcement Learning with Double Q-learning</em></strong></a>, arXiv.<br>[4] Wang, Z., Schaul, T., Hessel, M., van Hasselt, H., Lanctot, M., &amp; de Freitas, N. (2016). <a href="http://arxiv.org/abs/1511.06581"><strong><em>Dueling Network Architectures for Deep Reinforcement Learning</em></strong></a>, arXiv.<br>[5] Bellemare, M. G., Dabney, W., &amp; Munos, R. (2017). <a href="http://arxiv.org/abs/1707.06887"><strong><em>A Distributional Perspective on Reinforcement Learning</em></strong></a>, arXiv.<br>[5&#39;] Dabney, W., Ostrovski, G., Silver, D., &amp; Munos, R. (2018). <a href="http://arxiv.org/abs/1806.06923"><strong><em>Implicit Quantile Networks for Distributional Reinforcement Learning</em></strong></a>, arXiv.<br>[6] Fortunato, M., Azar, M. G., Piot, B., Menick, J., Osband, I., Graves, A., Mnih, V., Munos, R., Hassabis, D., Pietquin, O., Blundell, C., &amp; Legg, S. (2019). <a href="http://arxiv.org/abs/1706.10295"><strong><em>Noisy Networks for Exploration</em></strong></a>, arXiv.<br>[7] Schaul, T., Quan, J., Antonoglou, I., &amp; Silver, D. (2016). <a href="http://arxiv.org/abs/1511.05952"><strong><em>Prioritized Experience Replay</em></strong></a>, arXiv.</p><h4>Additional resources</h4><p>[8] Massimiliano Tomassoli, <a href="https://mtomassoli.github.io/2017/12/08/distributional_rl/"><strong><em>Distributional RL: An intuitive explanation of Distributional RL</em></strong></a><br>[9] Lahire, T., Geist, M., &amp; Rachelson, E. (2022). <a href="http://arxiv.org/abs/2110.01528"><strong><em>Large Batch Experience Replay</em></strong></a>, arXiv.<br>[10] Sutton, R.
S., &amp; Barto, A. G. (1998). <strong><em>Reinforcement Learning: An Introduction</em></strong>.<br>[11] Pascal Poupart, <a href="https://youtu.be/r-Yk6-jagDU?si=9lqQHHNaQz8Uiclw"><strong><em>CS885 Module 5: Distributional RL</em></strong></a><strong><em>, </em></strong>YouTube</p><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=37e662ab99b2" width="1" height="1" alt=""><hr><p><a href="https://medium.com/data-science/rainbow-the-colorful-evolution-of-deep-q-networks-37e662ab99b2">Rainbow: The Colorful Evolution of Deep Q-Networks 🌈</a> was originally published in <a href="https://medium.com/data-science">TDS Archive</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[A Practical Guide to Proximal Policy Optimization in JAX]]></title>
            <link>https://medium.com/data-science/breaking-down-state-of-the-art-ppo-implementations-in-jax-6f102c06c149?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/6f102c06c149</guid>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[programming]]></category>
            <category><![CDATA[reinforcement-learning]]></category>
            <category><![CDATA[tips-and-tricks]]></category>
            <category><![CDATA[implementation]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Wed, 01 May 2024 05:32:45 GMT</pubDate>
            <atom:updated>2024-07-09T20:08:58.497Z</atom:updated>
            <content:encoded><![CDATA[<h4>All the tricks and details you wish you knew about PPO</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*9iiDaMnE92OHLdVX" /><figcaption>Photo by <a href="https://unsplash.com/@lorenzoherrera?utm_source=medium&amp;utm_medium=referral">Lorenzo Herrera</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>Since its publication in a <a href="https://arxiv.org/pdf/1707.06347.pdf">2017 paper by OpenAI</a>, Proximal Policy Optimization (PPO) has been widely regarded as one of the state-of-the-art algorithms in Reinforcement Learning. Indeed, PPO has demonstrated remarkable performance across various tasks, from <a href="https://openai.com/research/openai-five">attaining superhuman performance in Dota 2</a> to solving a <a href="https://openai.com/research/solving-rubiks-cube">Rubik’s cube with a single robotic hand</a>, all while maintaining three main advantages: simplicity, stability, and sample efficiency.</p><p>However, implementing RL algorithms from scratch is notoriously difficult and error-prone, given the numerous error sources and implementation details to be aware of.</p><p>In this article, we’ll break down the clever tricks and programming concepts used in a popular implementation of PPO in JAX. Specifically, we’ll focus on the <a href="https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo.py">implementation featured in the PureJaxRL library</a>, developed by <a href="https://chrislu.page">Chris Lu</a>.</p><p><em>Disclaimer: Rather than diving too deep into theory, this article covers the practical implementation details and (numerous) tricks used in popular versions of PPO. Should you require any reminders about PPO’s theory, please refer to the “</em><strong><em>references</em></strong><em>” section at the end of this article. 
Additionally, all the code (minus the added comments) is copied directly from PureJaxRL for pedagogical purposes.</em></p><p><a href="https://github.com/luchris429/purejaxrl/tree/main">GitHub - luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations</a></p><h3><strong>Actor-Critic Architectures</strong></h3><p>Proximal Policy Optimization is categorized within the policy gradient family of algorithms, a subset of which includes actor-critic methods. The designation ‘actor-critic’ reflects the dual components of the model:</p><ul><li>The <strong>actor network</strong> creates a <strong>distribution over actions</strong> given the current state of the environment and returns an action sampled from this distribution. Here, the actor network comprises three dense layers separated by two activation layers (either ReLU or hyperbolic tangent) and a final categorical layer applying the <strong>softmax</strong> function to the computed distribution.</li><li>The <strong>critic network</strong> <strong>estimates the value function of the current state</strong>, in other words, how good it is to be in a given state. Its architecture is almost identical to the actor network, except for the final softmax layer. Indeed, the critic network doesn’t apply any activation function to the final dense layer outputs, as it performs a regression task.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*hzKJYLAu5ZeXaFE6YKU_Nw.jpeg" /><figcaption>Actor-critic architecture, as defined in PureJaxRL (illustration made by the author)</figcaption></figure><p>Additionally, this implementation pays particular attention to <strong>weight initialization</strong> in dense layers. Indeed, all dense layers are initialized by <strong>orthogonal matrices</strong> with specific coefficients. This initialization strategy has been shown to <strong>preserve the gradient norms</strong> (i.e.
scale) during forward passes and backpropagation, leading to <strong>smoother convergence</strong> and limiting the risks of vanishing or exploding gradients[1].</p><p>Orthogonal initialization is used in conjunction with specific scaling coefficients:</p><ul><li><strong>Square root of 2</strong>: Used for the first two dense layers of both networks, this factor aims to <strong>compensate for the variance reduction</strong> induced by ReLU activations (as inputs with negative values are set to 0). For the tanh activation, the Xavier initialization is a popular alternative[2].</li><li><strong>0.01: </strong>Used in the last dense layer of the actor network, this factor helps to <strong>minimize the initial differences in logit values</strong> before applying the softmax function. This will reduce the difference in action probabilities and thus <strong>encourage early exploration</strong>.</li><li><strong>1: </strong>As the critic network is performing a regression task, we do not scale the initial weights.</li></ul><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/6f187b7361f6450202f999af5aa17df0/href">https://medium.com/media/6f187b7361f6450202f999af5aa17df0/href</a></iframe><h3>Training Loop</h3><p>The training loop is divided into 3 main blocks that share similar coding patterns, taking advantage of Jax’s functionalities:</p><ol><li><strong>Trajectory collection:</strong> First, we’ll interact with the environment for a set number of steps and collect observations and rewards.</li><li><strong>Generalized Advantage Estimation (GAE):</strong> Then, we’ll approximate the expected return for each trajectory by computing the generalized advantage estimation.</li><li><strong>Update step: </strong>Finally, we’ll compute the gradient of the loss and update the network parameters via gradient descent.</li></ol><p>Before going through each block in detail, here’s a quick reminder about the jax.lax.scan<em> </em>function 
that will show up multiple times throughout the code:</p><h4>Jax.lax.scan</h4><p>A common programming pattern in JAX consists of defining a function that acts on a single sample and using jax.lax.scan<em> </em>to <strong>iteratively apply it to elements of a sequence</strong> or an array, while carrying along some state.<br>For instance, we’ll apply it to the step function to step our environment N consecutive times while carrying the new state of the environment through each iteration.</p><p>In pure Python, we could proceed as follows:</p><pre>trajectories = []<br><br>for step in range(n_steps):<br>  action = actor_network(obs)<br>  obs, state, reward, done, info = env.step(action, state)<br>  trajectories.append((obs, state, reward, done, info))</pre><p>However, we avoid writing such loops in JAX for performance reasons (as pure Python loops are incompatible with JIT compilation). The alternative is jax.lax.scan, which is equivalent to:</p><pre>def scan(f, init, xs, length=None):<br>  &quot;&quot;&quot;Example provided in the JAX documentation.&quot;&quot;&quot;<br>  if xs is None:<br>    xs = [None] * length<br><br>  carry = init<br>  ys = []<br>  for x in xs:<br>    # apply function f to current state<br>    # and element x<br>    carry, y = f(carry, x) <br>    ys.append(y)<br>  return carry, np.stack(ys)</pre><p>Using jax.lax.scan is more efficient than a Python loop because it allows the transformation to be optimized and executed as a single compiled operation rather than interpreting each loop iteration at runtime.</p><p>We can see that the scan function takes multiple arguments:</p><ul><li><strong>f:</strong> A function that is applied at each step.
It takes the current state and an element of xs (or a placeholder if xs is None) and returns the updated state and an output.</li><li><strong>init:</strong> The initial state that f will use in its first invocation.</li><li><strong>xs:</strong> A sequence of inputs that are iteratively processed by f. If xs is None, the function simulates a loop with length iterations using None as the input for each iteration.</li><li><strong>length:</strong> Specifies the number of iterations if xs is None, ensuring that the function can still operate without explicit inputs.</li></ul><p>Additionally, scan returns:</p><ul><li><strong>carry:</strong> The final state after all iterations.</li><li><strong>ys:</strong> An array of outputs corresponding to each step’s application of f, stacked for easy analysis or further processing.</li></ul><p>Finally, scan can be used in combination with vmap to scan a function over multiple dimensions in parallel. As we’ll see in the next section, this allows us to interact with several environments in parallel to collect trajectories rapidly.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*dUNkLv6GDm03xIM2J_HWPQ.jpeg" /><figcaption>Illustration of vmap, scan, and scan + vmap in the context of the step function (made by the author)</figcaption></figure><h3>1. Trajectory Collection</h3><p>As mentioned in the previous section, the trajectory collection block consists of a step function scanned across N iterations. This step function successively:</p><ul><li>Selects an action using the actor network</li><li>Steps the environment</li><li>Stores transition data in a transition tuple</li><li>Stores the model parameters, the environment state, the current observation, and rng keys in a runner_state tuple</li><li>Returns runner_state and transition</li></ul><p>Scanning this function returns the latest runner_state and traj_batch, an array of transition tuples. 
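To see the pattern end to end, here is a pure-Python miniature of trajectory collection. The scan helper mirrors the jax.lax.scan semantics shown earlier, while the one-dimensional environment and the hard-coded policy are hypothetical stand-ins for the actor network and env.step:

```python
import numpy as np

def scan(f, init, xs=None, length=None):
    # Pure-Python stand-in for jax.lax.scan
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def step_fn(runner_state, _):
    # Hypothetical 1-D environment: the agent nudges its state toward 0
    env_state, t = runner_state
    action = -1.0 if env_state > 0 else 1.0      # stand-in "policy"
    new_state = env_state + action               # stand-in env.step
    reward = -abs(new_state)
    transition = np.array([new_state, reward])   # (obs, reward) record
    return (new_state, t + 1), transition

# Collect a 4-step trajectory while threading the environment state
runner_state, traj_batch = scan(step_fn, init=(2.0, 0), length=4)
```

The real implementation additionally vectorizes the step over a batch of environments with vmap, as described in the surrounding text.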
In practice, transitions are collected from multiple environments in parallel for efficiency as indicated by the use of jax.vmap(env.step, …)(for more details about vectorized environments and vmap, refer to my <a href="https://medium.com/towards-data-science/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">previous article</a>).</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ee2875a6bdb941f399155c6c0904c4c0/href">https://medium.com/media/ee2875a6bdb941f399155c6c0904c4c0/href</a></iframe><h3>2. Generalized Advantage Estimation</h3><p>After collecting trajectories, we need to compute the <strong>advantage function, </strong>a crucial component of PPO’s loss function. The advantage function measures how much better a specific action is compared to the average action in a given state:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/218/0*8JKsD0mv7SRmWbW3.png" /></figure><p>Where <strong>Gt </strong>is the return at time <strong><em>t</em></strong><em> </em>and <strong>V(St) is </strong>the value of state <strong><em>s</em></strong><em> </em>at time <strong><em>t</em></strong><em>.</em></p><p>As the return is generally unknown, we have to approximate the advantage function. A popular solution is <strong>generalized advantage estimation</strong>[3], defined as follows:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/744/0*vtoc2sT_x8Vg-bTR.png" /></figure><p>With <strong>γ</strong> the discount factor, <strong>λ</strong> a parameter that controls the trade-off between bias and variance in the estimate, and<strong> <em>δt</em></strong><em> </em>the temporal difference error at time <strong><em>t</em></strong>:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/522/0*VYM1lRnvvZ-8dWEO.png" /></figure><p>As we can see, the value of the GAE at time <em>t </em>depends on the GAE at future timesteps. 
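Concretely, the backward recursion can be sketched in plain NumPy (a readable stand-in for the reversed jax.lax.scan used in the actual implementation; the done-masking term follows common practice):

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    # GAE_t   = delta_t + gamma * lam * (1 - done_t) * GAE_{t+1}
    advantages = np.zeros_like(rewards)
    gae, next_value = 0.0, last_value
    for t in reversed(range(len(rewards))):
        mask = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * mask - values[t]
        gae = delta + gamma * lam * mask * gae
        advantages[t] = gae
        next_value = values[t]
    # advantages + values recovers the return targets
    return advantages, advantages + values
```

Adding the advantages back to the values recovers the returns, matching the second output discussed in this section.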
Therefore, we compute it backward, starting from the end of a trajectory. For example, for a trajectory of 3 transitions, we would have:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/410/0*5DxYwJmbHIqtkIfy.png" /></figure><p>Which is equivalent to the following recursive form:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/423/0*IDa--MnSBjNWuwDR.png" /></figure><p>Once again, we use jax.lax.scan on the trajectory batch (this time in reverse order) to iteratively compute the GAE.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/36dd1edacd3ecf53a1d203f46999f828/href">https://medium.com/media/36dd1edacd3ecf53a1d203f46999f828/href</a></iframe><p>Note that the function returns advantages + traj_batch.value as a second output, which is equivalent to the return according to the first equation of this section.</p><h3>3. Update step</h3><p>The final block of the training loop defines the loss function, computes its gradient, and performs gradient descent on minibatches. 
Similarly to previous sections, the update step is an arrangement of several functions in a hierarchical order:</p><pre>def _update_epoch(update_state, unused):<br>  &quot;&quot;&quot;<br>  Scans update_minibatch over shuffled and permuted <br>  mini batches created from the trajectory batch.<br>  &quot;&quot;&quot;<br><br>  def _update_minbatch(train_state, batch_info):<br>    &quot;&quot;&quot;<br>    Wraps loss_fn and computes its gradient over the <br>    trajectory batch before updating the network parameters.<br>    &quot;&quot;&quot;<br>    ...<br>    <br>    def _loss_fn(params, traj_batch, gae, targets):<br>      &quot;&quot;&quot;<br>      Defines the PPO loss and computes its value.<br>      &quot;&quot;&quot;<br>      ...</pre><p>Let’s break them down one by one, starting from the innermost function of the update step.</p><h4>3.1 Loss function</h4><p>This function aims to define and compute the PPO loss, originally defined as:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*zYp6PTEfKoMjU2RT.png" /></figure><p>Where:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*qaiCVqnYPO2nue09.png" /></figure><p>However, the PureJaxRL implementation features some tricks and differences compared to the original PPO paper[4]:</p><ul><li>The paper defines the PPO loss in the context of gradient ascent whereas the implementation performs gradient descent. Therefore, the sign of each loss component is reversed.</li><li>The value function term is modified to include an additional clipped term. 
This could be seen as a way to make the value function updates more conservative (as for the clipped surrogate objective):</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*YQZP3hq6jiAxK_Ta.png" /></figure><ul><li>The GAE is standardized.</li></ul><p>Here’s the complete loss function:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/46f2d043c070808a7da5d97342afe905/href">https://medium.com/media/46f2d043c070808a7da5d97342afe905/href</a></iframe><h4>3.2 Update Minibatch</h4><p>The update_minibatch function is essentially a wrapper around loss_fn used to compute its gradient over the trajectory batch and update the model parameters stored in train_state.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/a6798898fe3ef8800e8354098b03aaa8/href">https://medium.com/media/a6798898fe3ef8800e8354098b03aaa8/href</a></iframe><h4>3.3 Update Epoch</h4><p>Finally, update_epoch wraps update_minibatch and applies it on minibatches. Once again, jax.lax.scan is used to apply the update function on all minibatches iteratively.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/725498766e43cfe26f21b1961bb49d01/href">https://medium.com/media/725498766e43cfe26f21b1961bb49d01/href</a></iframe><h3>Conclusion</h3><p>From there, we can wrap all of the previous functions in an update_step function and use scan one last time for N steps to complete the training loop.</p><p>A global view of the training loop would look like this:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/8408eb84bc2b05ecd9c2ae8ebebc8c4e/href">https://medium.com/media/8408eb84bc2b05ecd9c2ae8ebebc8c4e/href</a></iframe><p>We can now run a fully compiled training loop using jax.jit(train(rng)) or even train multiple agents in parallel using jax.vmap(train(rng)).</p><p>There we have it! 
We covered the essential building blocks of the PPO training loop as well as common programming patterns in JAX.</p><p>To go further, I highly recommend reading the <a href="https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo.py">full training script</a> in detail and running example notebooks on the PureJaxRL repository.</p><p><a href="https://github.com/luchris429/purejaxrl">GitHub - luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations</a></p><p>Thank you very much for your support, until next time 👋</p><h4>References:</h4><p><a href="https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo.py">Full training script</a>, PureJaxRL, Chris Lu, 2023</p><p>[1] <a href="https://smerity.com/articles/2016/orthogonal_init.html"><strong><em>Explaining and illustrating orthogonal initialization for recurrent neural networks</em></strong></a>, Smerity, 2016</p><p>[2] <a href="https://www.deeplearning.ai/ai-notes/initialization/index.html"><strong><em>Initializing neural networks</em></strong></a>, DeepLearning.ai</p><p>[3] <a href="https://towardsdatascience.com/generalized-advantage-estimation-in-reinforcement-learning-bf4a957f7975"><strong><em>Generalized Advantage Estimation in Reinforcement Learning</em></strong></a>, Siwei Causevic, Towards Data Science, 2023</p><p>[4] <a href="https://arxiv.org/pdf/1707.06347"><strong><em>Proximal Policy Optimization Algorithms</em></strong></a>, Schulman et Al., OpenAI, 2017</p><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=6f102c06c149" width="1" height="1" alt=""><hr><p><a href="https://medium.com/data-science/breaking-down-state-of-the-art-ppo-implementations-in-jax-6f102c06c149">A Practical Guide to Proximal Policy Optimization in JAX</a> was originally published in <a href="https://medium.com/data-science">TDS Archive</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[A Gentle Introduction to Deep Reinforcement Learning in JAX]]></title>
            <link>https://medium.com/data-science/a-gentle-introduction-to-deep-reinforcement-learning-in-jax-c1e45a179b92?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/c1e45a179b92</guid>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[reinforcement-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[getting-started]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Tue, 21 Nov 2023 17:51:55 GMT</pubDate>
            <atom:updated>2023-11-21T17:51:55.079Z</atom:updated>
            <content:encoded><![CDATA[<h4>Solving the CartPole environment with DQN in under a second</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*vGRKNV5tJU45e09A" /><figcaption>Photo by <a href="https://unsplash.com/@thomasdes?utm_source=medium&amp;utm_medium=referral">Thomas Despeyroux</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>Recent progress in Reinforcement Learning (RL), such as Waymo’s autonomous taxis or DeepMind’s superhuman chess-playing agents, complement <strong>classical RL</strong> with <strong>Deep Learning </strong>components such as <strong>Neural Networks</strong> and <strong>Gradient Optimization</strong> methods.</p><p>Building on the foundations and coding principles introduced in one of my previous stories, we’ll discover and learn to implement <strong>Deep Q-Networks</strong> (<strong>DQN</strong>) and <strong>replay buffers</strong> to solve OpenAI’s <strong>CartPole </strong>environment. 
All of that <strong>in under a second</strong> using JAX!</p><p>For an introduction to <strong>JAX</strong>, <strong>vectorized environments</strong>, and <strong>Q-learning</strong>, please refer to the content of this story:</p><p><a href="https://towardsdatascience.com/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">Vectorize and Parallelize RL Environments with JAX: Q-learning at the Speed of Light⚡</a></p><p>Our framework of choice for deep learning will be DeepMind’s <strong>Haiku </strong>library, which I recently introduced in the context of Transformers:</p><p><a href="https://towardsdatascience.com/implementing-a-transformer-encoder-from-scratch-with-jax-and-haiku-791d31b4f0dd">Implementing a Transformer Encoder from Scratch with JAX and Haiku 🤖</a></p><p>This article will cover the following sections:</p><ul><li><strong>Why </strong>do we need Deep RL?</li><li><strong>Deep Q-Networks, </strong><em>theory and practice</em></li><li><strong>Replay Buffers</strong></li><li>Translating the <strong>CartPole </strong>environment to <strong>JAX</strong></li><li>The <strong>JAX </strong>way to write <strong>efficient training loops</strong></li></ul><p><em>As always, all the code presented in this article is available on GitHub:</em></p><p><a href="https://github.com/RPegoud/jym">GitHub - RPegoud/jym: JAX implementation of RL algorithms and vectorized environments</a></p><h3><strong>Why </strong>do we need Deep RL?</h3><p>In previous articles, we introduced <a href="https://medium.com/towards-data-science/temporal-difference-learning-and-the-importance-of-exploration-an-illustrated-guide-5f9c3371413a">Temporal Difference Learning</a> algorithms and in particular <a href="https://towardsdatascience.com/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">Q-learning</a>.</p><p>Simply put, Q-learning is an <strong>off-policy</strong> algorithm <em>(the target policy is not 
the policy used for decision-making)</em> maintaining and updating a<strong> Q-table</strong>, an explicit <strong>mapping </strong>of <strong>states </strong>to corresponding <strong>action values</strong>.</p><p>While Q-learning is a practical solution for environments with discrete action spaces and restricted observation spaces, it struggles to scale well to more complex environments. Indeed, creating a Q-table requires <strong>defining </strong>the <strong>action</strong> and <strong>observation spaces</strong>.</p><p>Consider the example of <strong>autonomous driving</strong>: the <strong>observation space</strong> is composed of an <em>infinity of potential configurations</em> derived from camera feeds and other sensory inputs. On the other hand, the <strong>action space</strong> includes a <em>wide spectrum of steering wheel positions</em> and varying levels of force applied to the brake and accelerator.</p><p>Even though we could theoretically discretize the action space, the sheer volume of possible states and actions leads to an <strong>impractical Q-table</strong> in <strong>real-world applications</strong>.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*JhqyHUB3t7cZCNdS" /><figcaption>Photo by <a href="https://unsplash.com/@photophotostock?utm_source=medium&amp;utm_medium=referral">Kirill Tonkikh</a> on <a href="https://unsplash.com?utm_source=medium&amp;utm_medium=referral">Unsplash</a></figcaption></figure><p>Finding optimal actions in large and complex state-action spaces thus requires <strong>powerful function approximation algorithms</strong>, which is precisely what<strong> Neural Networks</strong> are. In the case of Deep Reinforcement Learning, neural nets are used as a <strong>replacement for the Q-table</strong> and provide an efficient solution to the <em>curse of dimensionality </em>introduced by large state spaces. 
Furthermore, we do not need to explicitly define the observation space.</p><h3>Deep Q-Networks &amp; Replay Buffers</h3><p>DQN uses two types of neural networks in parallel, starting with the “<strong><em>online</em></strong>” network, which is used for <strong>Q-value prediction</strong> and <strong>decision-making</strong>. On the other hand, the “<strong><em>target</em></strong>” network is used to <strong>create stable Q-targets</strong> to assess the performance of the online net via the loss function.</p><p>Similarly to Q-learning, DQN agents are defined by two functions: act and update.</p><h4>Act</h4><p>The act function implements an epsilon-greedy policy with respect to Q-values, which are estimated by the online neural network. In other words, the agent selects the action corresponding to the <strong>maximum predicted Q-value</strong> for a given state, with a set probability of acting randomly.</p><p>You might remember that Q-learning updates its Q-table <strong>after <em>every </em>step</strong>; in Deep Learning, however, it is common practice to compute updates using <strong>gradient descent</strong> on a <strong>batch of inputs</strong>.</p><p>For this reason, DQN stores experiences (tuples containing state, action, reward, next_state, done_flag) in a <strong>replay buffer</strong>. 
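Before moving on, the epsilon-greedy selection described above can be sketched as follows (a hedged, illustrative sketch: the stateless apply_fn stands in for the Haiku-transformed model, and all names are assumptions rather than the article's exact code):

```python
import jax
import jax.numpy as jnp

def act(key, params, apply_fn, state, epsilon, n_actions):
    # Epsilon-greedy selection over predicted Q-values (illustrative sketch)
    explore_key, action_key = jax.random.split(key)
    q_values = apply_fn(params, None, state)           # shape: (n_actions,)
    greedy_action = jnp.argmax(q_values)
    random_action = jax.random.randint(action_key, (), 0, n_actions)
    # With probability epsilon, explore; otherwise exploit
    explore = jax.random.uniform(explore_key) > 1.0 - epsilon
    return jax.lax.select(explore, random_action, greedy_action)

# Demo with a stand-in linear "network": params is just a weight matrix
apply_fn = lambda params, rng, state: state @ params
params = jnp.eye(4, 2)                  # 4 state features -> 2 actions
state = jnp.array([1.0, 0.0, 0.0, 0.0])
action = act(jax.random.PRNGKey(0), params, apply_fn, state, epsilon=0.1, n_actions=2)
```

With epsilon set to 0, the selection is purely greedy.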
To train the network, we’ll sample a batch of experiences from this buffer instead of using only the last experience <em>(more details in the Replay Buffer section)</em>.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*J36aRJCL3ocmA0ioFjrnBA.jpeg" /><figcaption>Visual representation of<strong> DQN’s action selection</strong> process (Made by the author)</figcaption></figure><p>Here’s a JAX implementation of the action-selection part of DQN:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ad1820a30301651c4e329820bce2e4cf/href">https://medium.com/media/ad1820a30301651c4e329820bce2e4cf/href</a></iframe><p>The only subtlety of this snippet is that the model attribute doesn’t contain any internal parameters, as would usually be the case in frameworks such as PyTorch or TensorFlow.</p><p>Here, the model is a <strong>function </strong>representing a <strong>forward pass</strong> through our architecture, but the <strong><em>mutable </em>weights are stored externally </strong>and passed as <strong>arguments</strong>. This explains why we can use jit while passing the self argument as <strong>static</strong> <em>(the model being stateless, like the other class attributes)</em>.</p><h4>Update</h4><p>The update function is responsible for training the network. It computes a <strong>mean squared error</strong> (MSE) loss based on the <strong>temporal-difference</strong> (TD) <strong>error</strong>:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/802/0*hFfMBLrnGMUsdSsP.png" /><figcaption>Mean Squared Error used in DQN</figcaption></figure><p>In this loss function, <strong><em>θ</em></strong> denotes the <strong>parameters of the online network</strong>, and <strong><em>θ</em>−</strong> represents the <strong>parameters of the target network</strong>. 
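For a single transition, this squared TD error can be sketched as follows (an illustrative sketch, not the article's exact code; gamma is the discount factor and all names are assumptions):

```python
import jax
import jax.numpy as jnp

def td_loss(q_online, q_target_next, action, reward, done, gamma=0.99):
    # Squared TD error for a single experience (illustrative sketch).
    # q_online:      Q-values of the online network for `state`
    # q_target_next: Q-values of the target network for `next_state`
    # The (1 - done) factor removes the bootstrap term at terminal states
    target = reward + gamma * (1.0 - done) * jnp.max(q_target_next)
    # The target is treated as a constant during backpropagation
    td_error = jax.lax.stop_gradient(target) - q_online[action]
    return td_error ** 2

loss = td_loss(
    q_online=jnp.array([0.0, 1.0]),
    q_target_next=jnp.array([2.0, 0.0]),
    action=1, reward=1.0, done=0.0, gamma=0.5,
)
```

Averaging this quantity over a batch of experiences (e.g. with vmap) yields the MSE loss above.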
The parameters of the target network are set to the online network’s parameters every <em>N</em> steps<em>, </em>similar to a <em>checkpoint </em>(<em>N is a hyperparameter).</em></p><p>This separation of parameters (with <em>θ</em> for the current Q-values and <em>θ</em>− for the target Q-values) is crucial to stabilize training.</p><p>Using the same parameters for both would be similar to aiming at a moving target, as <strong>updates to the network</strong> would <strong>immediately shift the target values</strong>. By <strong>periodically updating</strong> <strong><em>θ</em>−</strong> (i.e. freezing these parameters for a set number of steps), we ensure <strong>stable Q-targets</strong> while the online network continues to learn.</p><p>Finally, the <em>(1-done)</em> term <strong>adjusts the target</strong> for <strong>terminal states</strong>. Indeed, when an episode ends (i.e. ‘done’ is equal to 1), there is no next state. Therefore, the Q-value for the next state is set to 0.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*H7K68mSt_9EsvWX5ZOvY7Q.jpeg" /><figcaption>Visual representation of<strong> DQN’s parameter update </strong>process (Made by the author)</figcaption></figure><p>Implementing the update function for DQN is slightly more complex; let’s break it down:</p><ul><li>First, the _loss_fn function implements the squared error described previously for a <strong>single experience</strong>.</li><li>Then, _batch_loss_fn acts as a wrapper for _loss_fn and decorates it with vmap, applying the loss function to a <strong>batch of experiences</strong>. We then return the average error for this batch.</li><li>Finally, update acts as a final layer to our loss function, computing its <strong>gradient </strong>with respect to the online network parameters, the target network parameters, and a batch of experiences. 
We then use <strong>Optax </strong><em>(a JAX library commonly used for optimization)</em> to perform an optimizer step and update the online parameters.</li></ul><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/790a31d33db9a5b9411537a0f12ee2d2/href">https://medium.com/media/790a31d33db9a5b9411537a0f12ee2d2/href</a></iframe><p>Notice that, similarly to the replay buffer, the model and optimizer are <strong>pure functions</strong> modifying an <strong>external state</strong>. The following line serves as a good illustration of this principle:</p><pre>updates, optimizer_state = optimizer.update(grads, optimizer_state)</pre><p>This also explains why we can use a single model for both the online and target networks, as the parameters are stored and updated externally.</p><pre># target network predictions<br>self.model.apply(target_net_params, None, state)<br># online network predictions<br>self.model.apply(online_net_params, None, state)</pre><p>For context, the model we use in this article is a <em>multi-layer perceptron</em> defined as follows:</p><pre>N_ACTIONS = 2<br>NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]<br>online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)<br><br>@hk.transform<br>def model(x):<br>    # simple multi-layer perceptron<br>    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)<br>    return mlp(x)<br><br>online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))<br>target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))<br><br>prediction = model.apply(online_net_params, None, state)</pre><h4>Replay Buffer</h4><p>Now let us take a step back and look closer at replay buffers. They are widely used in reinforcement learning for a variety of reasons:</p><ul><li><strong>Generalization</strong>: By sampling from the replay buffer, we break the correlation between consecutive experiences by mixing up their order. 
This way, we avoid overfitting to specific sequences of experiences.</li><li><strong>Diversity</strong>: As the sampling is not limited to recent experiences, we generally observe a lower variance in updates and prevent overfitting to the latest experiences.</li><li><strong>Increased sample efficiency</strong>: Each experience can be sampled multiple times from the buffer, enabling the model to learn more from individual experiences.</li></ul><p>Finally, we can use several sampling schemes for our replay buffer:</p><ul><li><strong>Uniform sampling: </strong>Experiences are sampled uniformly at random. This type of sampling is straightforward to implement and allows the model to learn from experiences independently of the timestep at which they were collected.</li><li><strong>Prioritized sampling: </strong>This category includes different algorithms such as <strong>Prioritized Experience Replay </strong>(“PER”, <a href="https://arxiv.org/abs/1511.05952"><em>Schaul et al. 2015</em></a><em>) </em>or <strong>Gradient Experience Replay </strong>(“GER”, <a href="https://arxiv.org/abs/2110.01528"><em>Lahire et al., 2022</em></a><em>). </em>These methods attempt to prioritize the selection of experiences according to some metric related to their “<em>learning potential” </em>(the amplitude of the TD error for PER and the norm of the experience’s gradient for GER).</li></ul><p>For the sake of simplicity, we’ll implement a uniform replay buffer in this article. However, I plan to cover prioritized sampling extensively in the future.</p><p>As promised, the uniform replay buffer is quite easy to implement; however, there are a few complexities related to the use of JAX and functional programming. As always, we have to work with<strong> pure functions</strong> that are <strong>devoid of side effects</strong>. 
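For instance, adding an experience must return a new buffer state rather than mutate one in place. A minimal sketch of this functional style (field names and signatures are illustrative, not the article's exact code):

```python
import jax.numpy as jnp

def buffer_add(buffer_state, experience, idx, buffer_size):
    # Pure insertion: returns a NEW buffer state instead of mutating one.
    # `experience` is a tuple matching the order of the buffer's fields.
    idx = idx % buffer_size  # overwrite the oldest entries once full
    # `.at[...].set(...)` is JAX's out-of-place (functional) array update
    return {
        field: array.at[idx].set(value)
        for (field, array), value in zip(buffer_state.items(), experience)
    }

# Demo: a one-field buffer of capacity 4; index 5 wraps around to 1,
# and the original buffer_state is left untouched
buffer_state = {"rewards": jnp.zeros((4,))}
new_state = buffer_add(buffer_state, (1.0,), idx=5, buffer_size=4)
```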
In other words, we are not allowed to define the buffer as a class instance with a variable internal state.</p><p>Instead, we initialize a buffer_state dictionary that maps keys to empty arrays with predefined shapes, as JAX requires constant-sized arrays when jit-compiling code to XLA.</p><pre>buffer_state = {<br>    &quot;states&quot;: jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),<br>    &quot;actions&quot;: jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),<br>    &quot;rewards&quot;: jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),<br>    &quot;next_states&quot;: jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),<br>    &quot;dones&quot;: jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),<br>}</pre><p>We will use a UniformReplayBuffer class to interact with the buffer state. This class has two methods:</p><ul><li>add: Unwraps an experience tuple and maps its components to a specific index. idx = idx % self.buffer_size ensures that when the buffer is full, adding new experiences overwrites older ones.</li><li>sample: Samples a sequence of random indexes from the uniform random distribution. The sequence length is set by batch_size while the range of the indexes is [0, current_buffer_size-1]. This ensures that we do not sample empty arrays while the buffer is not yet full. 
Finally, we use JAX’s vmap in combination with tree_map to return a batch of experiences.</li></ul><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/023105f4de64298471d0d67c6cd74853/href">https://medium.com/media/023105f4de64298471d0d67c6cd74853/href</a></iframe><h3>Translating the <strong>CartPole </strong>environment to <strong>JAX</strong></h3><p>Now that our DQN agent is ready for training, we’ll quickly implement a vectorized CartPole environment using the same framework as introduced in an <a href="https://towardsdatascience.com/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">earlier article</a>. CartPole is a control environment with a <strong>large continuous observation space, </strong>which makes it relevant to test our DQN.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/600/0*lpRee3NrehU6l3PX.gif" /><figcaption>Visual representation of the CartPole Environment (credits and documentation: <a href="https://gymnasium.farama.org/environments/classic_control/cart_pole/">OpenAI Gymnasium</a>, MIT license)</figcaption></figure><p>The process is quite straightforward: we reuse most of <a href="https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py">OpenAI’s Gymnasium implementation</a> while making sure we use JAX arrays and lax control flow instead of Python or Numpy alternatives, for instance:</p><pre># Python implementation<br>force = self.force_mag if action == 1 else -self.force_mag<br># Jax implementation<br>force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag)<br><br># Python<br>costheta, sintheta = math.cos(theta), math.sin(theta)<br># Jax<br>cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)<br><br># Python<br>if not terminated:<br>  reward = 1.0<br>...<br>else: <br>  reward = 0.0<br># Jax<br>reward = jnp.float32(jnp.invert(done))</pre><p>For the sake of 
brevity, the full environment code is available here:</p><p><a href="https://github.com/RPegoud/jym/blob/main/src/envs/control/cartpole.py">jym/src/envs/control/cartpole.py at main · RPegoud/jym</a></p><h3>The <strong>JAX </strong>way to write <strong>efficient training loops</strong></h3><p>The last part of our implementation of DQN is the training loop <em>(also called rollout). </em>As mentioned in previous articles, we have to respect a specific format in order to take advantage of JAX’s speed.</p><p>The rollout function might appear daunting at first, but most of its complexity is purely syntactic as we’ve already covered most of the building blocks. Here’s a pseudo-code walkthrough:</p><pre>1. Initialization:<br>  * Create empty arrays that will store the states, actions, rewards <br>    and done flags for each timestep. Initialize the networks and optimizer<br>    with dummy arrays.<br>  * Wrap all the initialized objects in a val tuple<br><br>2. Training loop (repeat for i steps):<br>  * Unpack the val tuple<br>  * (Optional) Decay epsilon using a decay function<br>  * Take an action depending on the state and model parameters<br>  * Perform an environment step and observe the next state, reward <br>    and done flag<br>  * Create an experience tuple (state, action, reward, new_state, done)<br>    and add it to the replay buffer<br>  * Sample a batch of experiences depending on the current buffer size<br>    (i.e. 
sample only from experiences that have non-zero values)<br>  * Update the model parameters using the experience batch<br>  * Every N steps, update the target network&#39;s weights <br>    (set target_params = online_params)<br>  * Store the experience&#39;s values for the current episode and return <br>    the updated `val` tuple</pre><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ea8779e9ebf69d62835d6e063791864f/href">https://medium.com/media/ea8779e9ebf69d62835d6e063791864f/href</a></iframe><p>We can now run DQN for <strong>20,000 steps</strong> and observe its performance. After around 45 episodes, the agent achieves decent performance, balancing the pole for more than 100 steps consistently.</p><p>The <strong>green bars</strong> indicate that the agent managed to balance the pole for <strong>more than 200 steps</strong>, <strong>solving the environment</strong>. Notably, the agent set its record on the <strong>51st episode</strong>, with <strong>393 steps</strong>.</p><iframe src="https://cdn.embedly.com/widgets/media.html?src=https%3A%2F%2Fplotly.com%2F%7ERyan_pgd%2F15.embed%3Fautosize%3Dtrue&amp;display_name=Plotly&amp;url=https%3A%2F%2Fchart-studio.plotly.com%2F%7ERyan_pgd%2F15%2F&amp;image=https%3A%2F%2Fchart-studio.plotly.com%2Fstatic%2Fwebapp%2Fimages%2Fplotly-logo.8d56a320dbb8.png&amp;key=a19fcc184b9711e1b4764040d3dc5c07&amp;type=text%2Fhtml&amp;schema=plotly" width="600" height="400" frameborder="0" scrolling="no"><a href="https://medium.com/media/3c385a3226b6585e8d462f5d95447bc4/href">https://medium.com/media/3c385a3226b6585e8d462f5d95447bc4/href</a></iframe><p>The <strong>20,000 training steps</strong> were executed in <strong>just over a second</strong>, at a rate of <strong>15,807 steps per second</strong> <em>(on a </em><strong><em>single CPU</em></strong><em>)</em>!</p><p>These performances hint at JAX’s impressive scaling capabilities, allowing practitioners to run large-scale 
parallelized experiments with minimal hardware requirements.</p><pre>Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:01&lt;00:00, 15807.81it/s]</pre><p>We’ll take a closer look at <strong>parallelized rollout procedures</strong> to run <strong>statistically significant</strong> experiments and <strong>hyperparameter searches </strong>in a future article!</p><p>In the meantime, feel free to reproduce the experiment and dabble with hyperparameters using this notebook:</p><p><a href="https://github.com/RPegoud/jym/blob/main/notebooks/control/cartpole/dqn_cartpole.ipynb">jym/notebooks/control/cartpole/dqn_cartpole.ipynb at main · RPegoud/jym</a></p><h3>Conclusion</h3><p>As always, <strong>thanks for reading this far! </strong>I hope this article provided a decent introduction to Deep RL in JAX. Should you have any questions or feedback related to the content of this article, make sure to let me know, I’m always happy to have a little chat ;)</p><p>Until next time 👋</p><h3>Credits:</h3><ul><li><a href="https://gymnasium.farama.org/environments/classic_control/cart_pole/">Cartpole Gif</a>, OpenAI Gymnasium library, (MIT license)</li></ul><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=c1e45a179b92" width="1" height="1" alt=""><hr><p><a href="https://medium.com/data-science/a-gentle-introduction-to-deep-reinforcement-learning-in-jax-c1e45a179b92">A Gentle Introduction to Deep Reinforcement Learning in JAX</a> was originally published in <a href="https://medium.com/data-science">TDS Archive</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Implementing a Transformer Encoder from Scratch with JAX and Haiku ]]></title>
            <link>https://medium.com/data-science/implementing-a-transformer-encoder-from-scratch-with-jax-and-haiku-791d31b4f0dd?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/791d31b4f0dd</guid>
            <category><![CDATA[editors-pick]]></category>
            <category><![CDATA[transformers]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[nlp]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Tue, 07 Nov 2023 14:54:41 GMT</pubDate>
            <atom:updated>2023-11-07T14:54:41.373Z</atom:updated>
            <content:encoded><![CDATA[<h4>Understanding the fundamental building blocks of Transformers.</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*UfCDk9d2SgydLmUO" /><figcaption>Transformers, in the style of Edward Hopper (generated by Dall.E 3)</figcaption></figure><p>Introduced in 2017 in the seminal paper “<a href="https://arxiv.org/pdf/1706.03762.pdf"><strong><em>Attention is all you need</em></strong></a><strong><em>”</em></strong>[0], the Transformer architecture is arguably one of the most impactful breakthroughs in recent Deep Learning history, enabling the rise of large language models and even finding use in fields such as computer vision.</p><p>Succeeding former state-of-the-art architectures relying on <strong>recurrence </strong>such as Long Short-Term Memory (<strong>LSTM</strong>) networks or Gated Recurrent Units (<strong>GRU</strong>), <strong>Transformers </strong>introduce the concept of <strong>self-attention</strong>, coupled with an <strong>encoder/decoder </strong>architecture.</p><p>In this article, we’ll implement the first half of a Transformer, the <strong>Encoder</strong>, from scratch and step by step. We’ll use <strong>JAX </strong>as our main framework along with <strong>Haiku, </strong>one of DeepMind’s deep learning libraries.</p><p>In case you are unfamiliar with JAX or need a fresh reminder about its amazing functionalities, I’ve already covered the topic in the context of Reinforcement Learning in my <strong>previous article</strong>:</p><p><a href="https://towardsdatascience.com/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">Vectorize and Parallelize RL Environments with JAX: Q-learning at the Speed of Light⚡</a></p><p>We’ll go over each of the blocks that make up the encoder and learn to implement them efficiently. 
In particular, the outline of this article contains:</p><ul><li>The<strong> Embedding Layer</strong> and <strong>Positional Encodings</strong></li><li><strong>Multi-Head Attention</strong></li><li><strong>Residual Connections</strong> and <strong>Layer Normalization</strong></li><li><strong>Position-wise Feed-Forward Networks</strong></li></ul><p><em>Disclaimer: this article is not intended to be a complete introduction to these notions as we’ll focus on implementation first. If needed, please refer to the resources at the end of this post.</em></p><p><strong><em>As always, the fully commented code for this article as well as illustrated notebooks are available on </em></strong><a href="https://github.com/RPegoud/jab"><strong><em>GitHub</em></strong></a><strong><em>, feel free to star the repository if you enjoyed the article!</em></strong></p><p><a href="https://github.com/RPegoud/jab">GitHub - RPegoud/jab: A collection of foundational Deep Learning models implemented in JAX</a></p><h4>Main parameters</h4><p>Before we get started, we need to define a few parameters that will play a crucial role in the encoder block:</p><ul><li><strong>Sequence Length</strong> (seq_len): The number of tokens or words in a sequence.</li><li><strong>Embedding Dimension </strong>(embed_dim): The dimension of the embeddings, in other words, the number of numerical values used to describe a single token or word.</li><li><strong>Batch Size (</strong>batch_size<strong>): </strong>The size of a batch of inputs, i.e. the number of sequences processed at the same time.</li></ul><p>The input sequences to our encoder model will typically be of shape <strong>(</strong>batch_size<strong>, </strong>seq_len<strong>)</strong>. 
In this article, we’ll use batch_size=32 and seq_len=10, which means that our encoder will simultaneously process 32 sequences of 10 words.</p><p>Paying attention to the shape of our data at each step of the processing will enable us to better visualize and understand how the data flows in the encoder block. Here’s a high-level overview of our encoder, we’ll start from the bottom with the <strong>embedding layer</strong> and <strong>positional encodings</strong>:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/977/1*Mfz2VwpV4_pmBgXlRxVSLw.jpeg" /><figcaption>Representation of the <strong>Transformer Encoder block</strong> (made by the author)</figcaption></figure><h3>Embedding Layer and Positional Encodings</h3><p>As mentioned previously, our model takes batched sequences of tokens as inputs. Generating those tokens could be as simple as collecting a set of unique words in our dataset, and assigning an index to each of them. Then we would sample <strong>32</strong> <strong>sequences </strong>of <strong>10 words </strong>and replace each word with its index in the vocabulary. This procedure would provide us with an array of shape <strong>(</strong>batch_size<strong>, </strong>seq_len<strong>)</strong>, as expected.</p><p>We are now ready to get started with our Encoder. The first step is to create “<strong><em>positional embeddings</em></strong>” for our sequences. Positional embeddings are the <strong>sum </strong>of <strong>word embeddings</strong> and <strong>positional encodings</strong>.</p><h4>Word Embeddings</h4><p>Word embeddings allow us to encode the <strong>meaning </strong>and <strong>semantic relations</strong> <strong>between words</strong> in our vocabulary. In this article, the embedding dimension is fixed to <strong>64</strong>. This means that each word is represented by a<strong> 64-dimensional vector</strong> so that words with similar meanings have similar coordinates. 
Moreover, we can manipulate these vectors to <strong>extract relations between words</strong>, as depicted below.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/0*YG1Ram02GG4jewJi" /><figcaption>Example of analogies derived from word embeddings (image from developers.google.com)</figcaption></figure><p>Using Haiku, generating learnable embeddings is as simple as calling:</p><pre>hk.Embed(vocab_size, embed_dim)</pre><p>These embeddings will be updated along with other learnable parameters during model training <em>(more on that in a second)</em>.</p><h4>Positional Encodings</h4><p>As opposed to recurrent neural nets, Transformers can’t infer the position of a token given a shared hidden state as they <strong>lack recurrent</strong> or <strong>convolutional structures</strong>. Hence the introduction of <strong>positional encodings, vectors</strong> that convey a <strong>token’s position </strong>in the <strong>input sequence</strong>.</p><p>Essentially, each token is assigned a <strong>positional vector</strong> composed of <strong>alternating sine and cosine values</strong>. Those vectors match the dimensionality of word embeddings so that both can be summed.</p><p>In particular, the original Transformer paper uses the following functions:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/468/0*QMuqej-eyXCUlVKW.png" /></figure><figure><img alt="" src="https://cdn-images-1.medium.com/max/496/0*wUaDUdiDqfE48-EZ.png" /><figcaption>Positional Encoding functions (reproduced from “Attention is all you need”, Vaswani et al. 2017)</figcaption></figure><p>The below figures enable us to further understand the functioning of positional encodings. 
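For reference, the two formulas above can be transcribed directly (a NumPy sketch for clarity, assuming an even embed_dim as in the original formulation):

```python
import numpy as np

def positional_encoding(seq_len, embed_dim):
    # Sinusoidal positional encodings, transcribed from the formulas above
    # (assumes embed_dim is even, so sine and cosine columns interleave)
    pe = np.zeros((seq_len, embed_dim))
    position = np.arange(seq_len)[:, None]                    # (seq_len, 1)
    div_term = 10_000 ** (np.arange(0, embed_dim, 2) / embed_dim)
    pe[:, 0::2] = np.sin(position / div_term)   # even dimensions: sine
    pe[:, 1::2] = np.cos(position / div_term)   # odd dimensions: cosine
    return pe

pe = positional_encoding(seq_len=10, embed_dim=64)
```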
Let’s take a look at the first row of the uppermost plot: we can see <strong>alternating sequences of zeros and ones.</strong> Indeed, rows represent the position of a token in the sequence (the pos variable) while columns represent the embedding dimension (the i variable).</p><p>Therefore, when pos=0, the previous equations return sin(0)=0 for even embedding dimensions and cos(0)=1 for odd dimensions.</p><p>Moreover, we see that adjacent rows share similar values, whereas the first and last rows are wildly different. This property is helpful for the model to assess the <strong>distance between words</strong> in the sequence as well as their <strong>ordering</strong>.</p><p>Finally, the third plot represents the sum of positional encodings and embeddings, which is the output of the embedding block.</p><iframe src="https://cdn.embedly.com/widgets/media.html?src=https%3A%2F%2Fplotly.com%2F%7ERyan_pgd%2F11.embed%3Fautosize%3Dtrue&amp;display_name=Plotly&amp;url=https%3A%2F%2Fchart-studio.plotly.com%2F%7ERyan_pgd%2F11%2F&amp;image=https%3A%2F%2Fchart-studio.plotly.com%2Fstatic%2Fwebapp%2Fimages%2Fplotly-logo.8d56a320dbb8.png&amp;key=a19fcc184b9711e1b4764040d3dc5c07&amp;type=text%2Fhtml&amp;schema=plotly" width="600" height="400" frameborder="0" scrolling="no"><a href="https://medium.com/media/71e9f04e859d0202fffef34771d108fc/href">https://medium.com/media/71e9f04e859d0202fffef34771d108fc/href</a></iframe><p>Using Haiku, we define the embedding layer as follows. Similarly to other deep learning frameworks, Haiku allows us to define <strong>custom modules</strong> (here hk.Module) to <strong>store learnable parameters</strong> and <strong>define the behavior</strong> of our model’s components.</p><p>Each Haiku module needs to have an __init__ and a __call__ function. 
Here, the __call__ function simply computes the embeddings using the hk.Embed function and the positional encodings, before summing them.</p><p>The positional encoding function uses JAX functionalities such as vmap and lax.cond for performance. If you are unfamiliar with those functions, feel free to check out my <a href="https://medium.com/towards-data-science/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">previous post</a> where they are presented in more depth.</p><p>Put simply, vmap allows us to define a function for a <strong>single sample</strong> and <strong>vectorize it</strong> so that it can be applied to <strong>batches</strong> of data. The in_axes parameter is used to specify that we want to iterate over the first axis of the dim input, which is the embedding dimension. On the other hand, lax.cond is an XLA-compatible version of a Python if/else statement.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ebcac5ffcd7cc380167dcb3fdb61de75/href">https://medium.com/media/ebcac5ffcd7cc380167dcb3fdb61de75/href</a></iframe><h3>Self-attention and Multi-Head Attention</h3><p>Attention aims to compute the <strong>importance of each word in a sequence</strong>, <strong>relative to an input word</strong>. For example, in the sentence:</p><blockquote>“The black cat jumped on the sofa, lay down and fell asleep, as it was tired”.</blockquote><p>The word “<strong>it</strong>” could be quite ambiguous for the model, as <em>technically</em>, it could refer to both “<strong>cat</strong>” and “<strong>sofa</strong>”. 
A well-trained attention model would be able to understand that “<strong>it</strong>” refers to “<strong>cat</strong>” and therefore assign attention values to the rest of the sentence accordingly.</p><p>Essentially,<strong> attention values</strong> could be seen as <strong>weights</strong> that describe the <strong>importance </strong>of a certain word <strong>given the context of the input</strong> word. For instance, the attention vector for the word “<strong>jumped</strong>” would have high values for words like “<strong>cat</strong>” (<em>what </em>jumped?), “<strong>on</strong>”, and “<strong>sofa</strong>” (<em>where </em>did it jump?) as these words are <strong>relevant to its context</strong>.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*ySDEvRPl9WfONWRVOey44Q.png" /><figcaption>Visual representation of an<strong> attention vector</strong> (made by the author)</figcaption></figure><p>In the Transformer paper, attention is computed using <strong><em>Scaled Dot-Product Attention</em></strong>, which is summarized by the formula:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/535/0*QEogF8q4SRV4uvl7.png" /><figcaption>Scaled Dot-Product Attention (reproduced from “Attention is all you need”, <em>Vaswani et al. 2017</em>)</figcaption></figure><p>Here, Q, K, and V stand for <strong><em>Queries, Keys </em></strong>and <strong><em>Values</em></strong><em>. </em>These matrices are obtained by multiplying learned weight vectors WQ, WK, and WV with positional embeddings.</p><p>These names are mainly <strong>abstractions </strong>used to help understand how the information is processed and weighted in the attention block. They are an allusion to <strong>retrieval systems</strong> vocabulary [2] (e.g. searching for a video on YouTube).</p><p>Here’s an <strong>intuitive </strong>explanation:</p><ul><li><strong>Queries</strong>: They can be interpreted as a “<em>set of questions</em>” about all the positions in a sequence. 
For instance, interrogating the context of a word and trying to identify the most relevant parts of the sequence.</li><li><strong>Keys</strong>: They can be seen as holding information that the queries interact with; the compatibility between a query and a key determines how much attention the query should pay to the corresponding value.</li><li><strong>Values</strong>: Matching keys and queries allows us to decide which keys are relevant; values are the actual content paired with the keys.</li></ul><p>In the following figure, the query is a YouTube search, the keys are the video descriptions and metadata, while the values are the associated videos.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/872/1*ubT1YkrPprthMq9pKGbroA.jpeg" /><figcaption>Intuitive representation of the Queries, Keys, Values concept (made by the author)</figcaption></figure><p>In our case, queries, keys, and values come from the <strong>same source</strong> (as they’re derived from the input sequences), hence the name <strong>self-attention</strong>.</p><p>The computation of attention scores is usually executed <strong>multiple times in parallel</strong>, each time with a <strong>fraction of the embeddings</strong>. This mechanism is called “<strong>Multi-Head Attention</strong>” and enables each head to learn several different representations of the data in parallel, leading to a more <strong>robust </strong>model.</p><p>A single attention head would generally process arrays with shape (batch_size, seq_len, d_k) where d_k can be set as the ratio between the embedding dimension and the number of heads (d_k = embed_dim/n_heads). 
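<p>For a single head, the scaled dot-product formula can be sketched in a few lines of JAX (an unbatched, illustrative version, not the article’s Haiku module):</p>

```python
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head.
    q, k, v: arrays of shape (seq_len, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)           # (seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to one
    return weights @ v                         # (seq_len, d_k)

key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 5, 4))  # three (seq_len=5, d_k=4) matrices
out = scaled_dot_product_attention(q, k, v)  # shape (5, 4)
```

<p>Note that each row of the attention weights sums to one, and the per-head output keeps the shape (seq_len, d_k).</p>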
This way, concatenating the outputs of each head conveniently gives an array with shape <strong>(</strong>batch_size, seq_len, embed_dim<strong>)</strong>, the same shape as the input.</p><p>The computation of attention matrices can be broken down into several steps:</p><ul><li>First, we define <strong>learnable weight vectors</strong> WQ, WK, and WV. These vectors have shapes <strong>(</strong>n_heads, embed_dim, d_k<strong>)</strong>.</li><li>In parallel, we <strong>multiply </strong>the <strong>positional embeddings</strong> with the <strong>weight vectors</strong>. We obtain Q, K, and V matrices with shapes <strong>(</strong>batch_size, seq_len, d_k<strong>)</strong>.</li><li>We then <strong>scale </strong>the <strong>dot-product</strong> of Q and K (transposed). This scaling involves dividing the result of the dot-product by the square root of d_k and applying the softmax function to each row. Therefore, attention scores for an input token (i.e. a row) sum up to one; this helps prevent values from becoming too large and slowing down computation. The output has shape (batch_size, seq_len, seq_len).</li><li>Finally, we dot the result of the previous operation with V, making the shape of the output (batch_size, seq_len, d_k).</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*IrLV4_NhMZ8MtWeC9xt0fA.png" /><figcaption>Visual representation of matrix operations inside <strong>an attention block </strong>(made by the author)</figcaption></figure><ul><li>The outputs of each attention head can then be <strong>concatenated </strong>to form a matrix with shape (batch_size, seq_len, embed_dim). 
The Transformer paper also adds a <strong>linear layer</strong> at the end of the multi-head attention module, to <strong>aggregate </strong>and <strong>combine </strong>the learned representations from<strong> all the attention heads</strong>.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*FVx4Hffl1lfZo1Lm2Tbf0w.png" /><figcaption>Concatenation of multi-head attention matrices and linear layer (made by the author)</figcaption></figure><p>In Haiku, the Multi-Head Attention module can be implemented as follows. The __call__ function follows the same logic as the above graph while the class methods take advantage of JAX utilities such as vmap (to vectorize our operations over the different attention heads and matrices) and tree_map (to map matrix dot-products over weight vectors).</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/d3b7c8111f21d6f08c206b778f7be2f7/href">https://medium.com/media/d3b7c8111f21d6f08c206b778f7be2f7/href</a></iframe><h3>Residual Connections and Layer Normalization</h3><p>As you might have noticed on the Transformer graph, the multi-head attention block and the feed-forward net are followed by <strong>residual connections</strong> and <strong>layer normalization</strong>.</p><h4>Residual or skip connections</h4><p>Residual connections are a standard solution to the <strong>vanishing gradient problem</strong>, which occurs when gradients become too small to effectively update the model’s parameters.</p><p>As this issue naturally arises in particularly deep architectures, residual connections are used in a variety of complex models such as <strong>ResNet </strong><em>(</em><a href="https://arxiv.org/abs/1512.03385v1"><em>He et al.</em></a><em>, 2015)</em> in computer vision, <strong>AlphaZero </strong>(<a href="https://arxiv.org/abs/1712.01815v1"><em>Silver et al.</em></a><em>, 2017</em>) in reinforcement learning, and of course, 
<strong>Transformers</strong>.</p><p>In practice, residual connections simply forward the output of a specific layer to a following one, <strong>skipping one or more layers</strong> on the way. For instance, the residual connection around the multi-head attention is equivalent to summing the output of multi-head attention with the positional embeddings that were fed into it.</p><p>This enables gradients to flow more efficiently through the architecture during backpropagation and can usually lead to<strong> faster convergence </strong>and more <strong>stable training</strong>.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/522/1*_BL1Q4sT5kzmBQ7wBl7DnQ.jpeg" /><figcaption>Representation of <strong>residual connections</strong> in Transformers (made by the author)</figcaption></figure><h4>Layer Normalization</h4><p>Layer normalization helps ensure that the values propagated through the model do not “<strong><em>explode</em></strong>” (tend toward infinity), which could easily happen in attention blocks, where several matrices are multiplied during each forward pass.</p><p>Unlike batch normalization, which normalizes across the batch dimension and assumes that samples follow the same distribution,<strong> layer normalization operates</strong> <strong>across the features</strong>. 
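<p>Concretely, for activations of shape (batch_size, seq_len, embed_dim), this means standardizing each token’s feature vector independently. A bare-bones sketch, without Haiku:</p>

```python
import jax
import jax.numpy as jnp

def layer_norm(x, alpha, beta, eps=1e-6):
    """Normalize over the feature (last) axis, then rescale with learnable
    alpha (scale) and beta (offset), each of shape (embed_dim,)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return alpha * (x - mean) / jnp.sqrt(var + eps) + beta

x = jax.random.normal(jax.random.PRNGKey(1), (2, 3, 8))  # (batch, seq_len, embed_dim)
normed = layer_norm(x, alpha=jnp.ones(8), beta=jnp.zeros(8))
```

<p>After this operation, every token’s feature vector has roughly zero mean and unit variance, regardless of the statistics of the other sentences in the batch.</p>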
This approach is suitable for sentence batches where each sentence may have a <strong>unique distribution</strong> due to <strong>varying meanings</strong> and <strong>vocabularies</strong>.</p><p>By normalizing across <strong>features</strong>, such as <strong>embeddings </strong>or <strong>attention values</strong>, layer normalization<strong> standardizes data</strong> to a consistent scale<strong> without conflating distinct sentence characteristics</strong>, maintaining the unique distribution of each.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*NlJu3E6z-fZLExXGvTzcnA.jpeg" /><figcaption>Representation of <strong>Layer Normalization</strong> in the context of Transformers (made by the author)</figcaption></figure><p>The implementation of layer normalization is pretty straightforward: we initialize the learnable parameters alpha and beta and normalize along the desired feature axis.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/4259fcbe42c4c4f7e646d6526e3b2ff0/href">https://medium.com/media/4259fcbe42c4c4f7e646d6526e3b2ff0/href</a></iframe><h3><strong>Position-wise Feed-Forward Network</strong></h3><p>The last component of the encoder that we need to cover is the<strong> position-wise feed-forward network</strong>. 
This fully connected network takes the normalized outputs of the attention block as inputs and is used to introduce <strong>non-linearity</strong> and increase the <strong>model’s capacity</strong> to learn complex functions.</p><p>It is composed of two dense layers separated by a <a href="https://paperswithcode.com/method/gelu">gelu activation</a>:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/664d0958c76af2c37963a3ca9be4e35f/href">https://medium.com/media/664d0958c76af2c37963a3ca9be4e35f/href</a></iframe><p>After this block, we have another residual connection and layer normalization to complete the encoder.</p><h3>Wrapping up</h3><p>There we have it! By now you should be familiar with the main concepts of the Transformer encoder. Here’s the full encoder class; notice that in Haiku, we assign a name to each layer, so that learnable parameters are separated and easy to access. The __call__ function provides a good summary of the different steps of our encoder:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/f79e91cc89202e6ffaf523620c130ed2/href">https://medium.com/media/f79e91cc89202e6ffaf523620c130ed2/href</a></iframe><p>To use this module on actual data, we have to apply hk.transform to a function encapsulating the encoder class. Indeed, you might remember that JAX embraces the <strong>functional programming</strong> paradigm; therefore, Haiku follows the same principles.</p><p>We define a function containing an instance of the encoder class and return the output of a forward pass. 
Applying hk.transform returns a transformed object exposing two functions: init and apply.</p><p>The former enables us to initialize the module with a random key as well as some dummy data (notice that here we pass an array of zeros with shape (batch_size, seq_len)) while the latter allows us to process real data.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/61a044ea8f07ec0257e911e68f17f2f0/href">https://medium.com/media/61a044ea8f07ec0257e911e68f17f2f0/href</a></iframe><pre># Note: the two following syntaxes are equivalent<br># 1: Using transform as a decorator<br>@hk.transform<br>def encoder(x):<br>  ...<br>  return model(x) <br> <br>encoder.init(...)<br>encoder.apply(...)<br><br># 2: Applying transform separately<br>def encoder(x):<br>  ...<br>  return model(x)<br><br>encoder_fn = hk.transform(encoder)<br>encoder_fn.init(...)<br>encoder_fn.apply(...)</pre><p>In the next article, we’ll <strong>complete the transformer</strong> architecture by adding a <strong>decoder</strong>, which reuses most of the blocks we introduced so far, and learn how to <strong>train a model</strong> on a specific task using Optax!</p><h3><strong>Conclusion</strong></h3><p><strong>Thank you for reading this far!</strong> If you are interested in dabbling with the code, you can find it fully commented on GitHub, along with additional details and a walkthrough using a toy dataset.</p><p><a href="https://github.com/RPegoud/jab">GitHub - RPegoud/jab: A collection of foundational Deep Learning models implemented in JAX</a></p><p>If you’d like to dig deeper into Transformers, the following section contains some articles that helped me write this article.</p><p>Until next time 👋</p><h3>References and Resources:</h3><p>[1] <a href="https://arxiv.org/pdf/1706.03762.pdf"><strong><em>Attention is all you need</em></strong></a> (2017), Vaswani et al, Google</p><p>[2] <a 
href="https://stats.stackexchange.com/questions/421935/what-exactly-are-keys-queries-and-values-in-attention-mechanisms"><strong><em>What exactly are keys, queries, and values in attention mechanisms?</em></strong></a><strong><em> </em></strong><em>(2019)</em><strong><em> </em></strong>Stack Exchange</p><p>[3] <a href="http://jalammar.github.io/illustrated-transformer/"><strong><em>The Illustrated Transformer</em></strong></a><strong><em> </em></strong><em>(2018), </em><a href="http://jalammar.github.io/">Jay Alammar</a></p><p>[4] <a href="https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/"><strong><em>A Gentle Introduction to Positional Encoding in Transformer Models</em></strong></a><strong><em> </em></strong><em>(2023), </em><a href="https://machinelearningmastery.com/author/msaeed/"><strong>Mehreen Saeed</strong></a>, Machine Learning Mastery</p><ul><li><a href="https://jax.readthedocs.io/en/latest/index.html"><strong><em>JAX documentation</em></strong></a></li><li><a href="https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html"><strong><em>Haiku documentation</em></strong></a></li></ul><h3>Image Credits</h3><ul><li><a href="https://developers.google.com/machine-learning/crash-course/embeddings/translating-to-a-lower-dimensional-space?hl=fr">Word embeddings</a>, developers.google.com</li><li>Cat picture, <a href="https://unsplash.com/fr/photos/bulldog-francese-marrone-che-indossa-una-camicia-gialla-5PVXkqt2s9k">Karsten Winegeart</a>, Unsplash</li><li>Norway landscape, <a href="https://unsplash.com/fr/photos/corpo-de-agua-perto-da-montanha-LKOuYT5_dyw">Pascal Debrunner</a>, Unsplash</li><li>Dog picture, <a href="https://unsplash.com/fr/photos/chaton-tabby-argente-sur-marbre-7AIDE8PrvA0">Loan</a>, Unsplash</li></ul><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=791d31b4f0dd" width="1" height="1" alt=""><hr><p><a 
href="https://medium.com/data-science/implementing-a-transformer-encoder-from-scratch-with-jax-and-haiku-791d31b4f0dd">Implementing a Transformer Encoder from Scratch with JAX and Haiku 🤖</a> was originally published in <a href="https://medium.com/data-science">TDS Archive</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Vectorize and Parallelize RL Environments with JAX: Q-learning at the Speed of Light⚡]]></title>
            <link>https://medium.com/data-science/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5?source=rss-27fba63b402e------2</link>
            <guid isPermaLink="false">https://medium.com/p/49d07373adf5</guid>
            <category><![CDATA[parallel-computing]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[python]]></category>
            <category><![CDATA[reinforcement-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <dc:creator><![CDATA[Ryan Pégoud]]></dc:creator>
            <pubDate>Sun, 15 Oct 2023 15:41:03 GMT</pubDate>
            <atom:updated>2023-10-30T16:06:23.608Z</atom:updated>
            <content:encoded><![CDATA[<h4>In this article, we learn to vectorize an RL environment and train 30 Q-learning agents in parallel on a CPU, at 1.8 million iterations per second.</h4><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Sn0-yKB__HAuw7BssQhzkw.png" /><figcaption>Image by <a href="https://unsplash.com/fr/@googledeepmind">Google DeepMind</a> on <a href="https://unsplash.com/fr">Unsplash</a></figcaption></figure><p>In the previous story, we introduced <strong>Temporal-Difference Learning, </strong>particularly <strong>Q-learning</strong>, in the context of a GridWorld.</p><p><a href="https://towardsdatascience.com/temporal-difference-learning-and-the-importance-of-exploration-an-illustrated-guide-5f9c3371413a">Temporal-Difference Learning and the importance of exploration: An illustrated guide</a></p><p>While this implementation served the purpose of demonstrating the differences in performances and exploration mechanisms of these algorithms, <strong><em>it was painfully slow</em></strong>.</p><p>Indeed, the environment and agents were mainly coded in <strong>Numpy</strong>, which is by no means a standard in RL, even though it makes the code easy to understand and debug.</p><p>In this article, we’ll see how to scale up RL experiments by <strong>vectorizing environments</strong> and seamlessly <strong>parallelizing </strong>the training of dozens of agents using <strong>JAX</strong>. 
In particular, this article covers:</p><ul><li>JAX basics and useful features for RL</li><li>Vectorized environments and why they are so fast</li><li>Implementation of an environment, policy, and Q-learning agent in JAX</li><li>Single-agent training</li><li>How to parallelize agent training, and how easy it is!</li></ul><p><em>All the code featured in this article is available on </em><a href="https://github.com/RPegoud"><strong><em>GitHub</em></strong></a><em>:</em></p><p><a href="https://github.com/RPegoud/jym">GitHub - RPegoud/jym: JAX implementation of RL algorithms and vectorized environments</a></p><h3>JAX Basics</h3><p>JAX is <em>yet another</em> Python Deep Learning framework developed by Google and widely used by companies such as DeepMind.</p><blockquote>“JAX is <a href="https://github.com/hips/autograd">Autograd</a> (automatic differentiation) and <a href="https://www.tensorflow.org/xla">XLA</a> (Accelerated Linear Algebra, a TensorFlow compiler), brought together for high-performance numerical computing.” — <a href="https://jax.readthedocs.io/en/latest/index.html">Official Documentation</a></blockquote><p>As opposed to what most Python developers are used to, JAX doesn’t embrace the <strong>object-oriented programming</strong> (OOP) paradigm, but rather <strong>functional programming (FP) [1]</strong>.</p><p>Put simply, it relies on <strong><em>pure functions</em></strong> (<strong>deterministic </strong>and <strong>without side effects</strong>) and <strong><em>immutable data structures</em> (</strong>instead of changing the data in place, <strong>new data structures</strong> are <strong>created with the desired modifications) </strong>as primary building blocks. 
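<p>For instance, JAX arrays themselves are immutable: an update returns a brand-new array and leaves the original untouched.</p>

```python
import jax.numpy as jnp

q_values = jnp.zeros((4, 4))
updated = q_values.at[2, 3].set(1.0)  # returns a *new* array

assert float(q_values[2, 3]) == 0.0   # the original is unchanged
assert float(updated[2, 3]) == 1.0
```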
As a result, FP encourages a more functional and mathematical approach to programming, making it well-suited for tasks like numerical computing and machine learning.</p><p>Let’s illustrate the differences between those two paradigms by looking at pseudocode for a Q-update function:</p><ul><li>The <strong>object-oriented</strong> approach relies on a <strong><em>class instance</em></strong> containing various <strong><em>state variables</em></strong> (such as the Q-values). The update function is defined as a class method that <strong>updates the <em>internal state</em></strong> of the instance.</li><li>The <strong>functional programming</strong> approach relies on a <strong><em>pure function</em></strong>. Indeed, this Q-update is <strong>deterministic</strong> as the Q-values are passed as an argument. Therefore, any call to this function with the <strong>same inputs</strong> will result in the <strong>same outputs</strong> whereas a class method’s outputs may depend on the internal state of the instance. Also, <strong>data structures</strong> such as arrays are <strong>defined </strong>and <strong>modified</strong> in the <strong>global scope</strong>.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*OuKd01A2rPMKjfqVKOAydA.png" /><figcaption>Implementing a Q-update in <strong>Object-Oriented Programming </strong>and <strong>Functional Programming </strong>(made by the author)</figcaption></figure><p>As such, JAX offers a variety of <strong>function decorators</strong> that are particularly useful in the context of RL:</p><ul><li><a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap"><strong>vmap</strong></a><strong> (vectorized map)</strong>: Allows a function acting on a single sample to be applied on a <strong>batch</strong>. 
For instance, if <em>env.step()</em> is a function performing a step in a single environment, <em>vmap(env.step)()</em> is a function performing a step in <strong>multiple environments</strong>. In other words, vmap adds a <strong>batch dimension </strong>to a function.</li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*cG6t5RzjEDhEHZaLGRngFg.png" /><figcaption>Illustration of a <strong>step </strong>function vectorized using <strong>vmap </strong>(made by the author)</figcaption></figure><ul><li><strong>j</strong><a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit"><strong>it</strong></a><strong> (just-in-time compilation)</strong>: Allows JAX to perform a “<em>Just In Time compilation of a JAX Python function” </em>making it <strong>XLA-compatible</strong><em>. </em>Essentially, using jit allows us to <strong>compile functions</strong> and provides <strong>significant speed improvements</strong> (in exchange for some additional overhead when first compiling the function).</li><li><a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html#jax.pmap"><strong>pmap</strong></a><strong> (parallel map)</strong>: Similarly to vmap, pmap enables easy parallelization. However, instead of adding a batch dimension to a function, it replicates the function and executes it on <strong>several XLA devices</strong>. 
<em>Note: when applying pmap, jit is also applied </em><strong><em>automatically</em></strong><em>.</em></li></ul><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*hgFRWIDLA7V8vTV_7gAvZw.png" /><figcaption>Illustration of a <strong>step </strong>function parallelized using <strong>pmap </strong>(made by the author)</figcaption></figure><p>Now that we have laid down the basics of JAX, we’ll see how to obtain massive speed-ups by vectorizing environments.</p><h3>Vectorized Environments:</h3><p>First, what is a vectorized environment and what problems does vectorization solve?</p><p>In most cases, RL experiments are <strong>slowed down</strong> by <strong>CPU-GPU data transfers</strong>. Deep Learning RL algorithms such as <strong>Proximal Policy Optimization</strong> (PPO) use Neural Networks to approximate the policy.</p><p>As always in Deep Learning, Neural Networks use <strong>GPUs</strong> at <strong>training </strong>and <strong>inference </strong>time. However, in most cases, <strong>environments </strong>run on the <strong>CPU</strong> (even in the case of multiple environments being used in parallel).</p><p>This means that the usual RL loop of selecting actions via the policy (Neural Networks) and receiving observations and rewards from the environment requires <strong>constant back-and-forths</strong> between the GPU and the CPU, which <strong>hurts performance</strong>.</p><p>In addition, using frameworks such as PyTorch without <em>“jitting” </em>might cause some overhead, since the GPU might have to wait for Python to send back observations and rewards from the CPU.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Bje2CxmcwEd-iNViCXUagQ.png" /><figcaption>Usual RL batched training setup in <strong>PyTorch </strong>(made by the author)</figcaption></figure><p>On the other hand, JAX enables us to easily run batched environments on the GPU, removing the friction caused by GPU-CPU data transfer.</p><p>Moreover, as jit 
compiles our JAX code to XLA, the execution is no longer (or at least less) affected by the inefficiency of Python.</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*6DFSLgGDU4M7kYSlu9WSyQ.png" /><figcaption>RL batched training setup in <strong>JAX </strong>(made by the author)</figcaption></figure><p>For more details and exciting applications to <strong>meta-learning RL research</strong>, I highly recommend this blog post by <a href="https://chrislu.page/blog/meta-disco/">Chris Lu</a>.</p><iframe src="https://cdn.embedly.com/widgets/media.html?type=text%2Fhtml&amp;key=a19fcc184b9711e1b4764040d3dc5c07&amp;schema=twitter&amp;url=https%3A//twitter.com/_chris_lu_/status/1643992216413831171%3Fs%3D20&amp;image=https%3A//i.embed.ly/1/image%3Furl%3Dhttps%253A%252F%252Fabs.twimg.com%252Ferrors%252Flogo46x38.png%26key%3Da19fcc184b9711e1b4764040d3dc5c07" width="500" height="281" frameborder="0" scrolling="no"><a href="https://medium.com/media/51f6a09d44b486f82f1dac941ba1c012/href">https://medium.com/media/51f6a09d44b486f82f1dac941ba1c012/href</a></iframe><h3>Environment, Agent, and Policy implementations:</h3><p>Let’s take a look at the implementation of the different parts of our RL experiment. Here’s a high-level overview of the basic functions we’ll need:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*9BLvUdYdKGbr2II-6aT1Rw.png" /><figcaption>Class methods required for a simple RL setup (made by the author)</figcaption></figure><h4>The environment</h4><p>This implementation follows the scheme provided by <a href="https://medium.com/@ngoodger_7766?source=post_page-----9f74338898ba--------------------------------">Nikolaj Goodger</a> in his great article on writing environments in JAX.</p><p><a href="https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba">Writing an RL Environment in JAX</a></p><p>Let’s start with a <strong>high-level view</strong> of the environment and its methods. 
This is a general plan for implementing an environment in JAX:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/adc19912c513d5d16903c068ee5fb428/href">https://medium.com/media/adc19912c513d5d16903c068ee5fb428/href</a></iframe><p>Let’s take a closer look at the class methods <em>(as a reminder, functions starting with “_” are </em><strong><em>private </em></strong><em>and shall not be called outside of the scope of the class)</em>:</p><ul><li><strong>_get_obs</strong>: This method converts the environment state to an observation for the agent. In a <strong>partially observable</strong> or <strong>stochastic </strong>environment, the processing functions applied to the state would go here.</li><li><strong>_reset</strong>: As we’ll be running multiple agents in parallel, we need a method for individual resets on the completion of an episode.</li><li><strong>_reset_if_done</strong>: This method will be called at each step and trigger _reset if the “done” flag is set to True.</li><li><strong>reset</strong>: This method is called at the beginning of the experiment to get the initial state of each agent, as well as the associated random keys</li><li><strong>step</strong>: Given a state and an action, the environment returns an observation (new state), a reward, and the updated “done” flag.</li></ul><p>In practice, a generic implementation of a GridWorld environment would look like this:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/61cae4744a9500cf0fa1c2b3c47b8905/href">https://medium.com/media/61cae4744a9500cf0fa1c2b3c47b8905/href</a></iframe><p>Notice that, as mentioned earlier, all class methods follow the <strong>functional programming</strong> paradigm. Indeed, we never update the internal state of the class instance. 
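<p>Stripped of the class machinery, the heart of such a pure step function can be sketched as follows (a simplified 5x5 GridWorld, not the article’s full implementation):</p>

```python
import jax
import jax.numpy as jnp

MOVEMENTS = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])  # right, left, down, up

@jax.jit
def step(state, action):
    """Pure step: returns a new state instead of mutating anything."""
    return jnp.clip(state + MOVEMENTS[action], 0, 4)  # stay inside the grid

# vmap adds a batch dimension: one call steps N agents at once
batched_step = jax.vmap(step)

states = jnp.zeros((3, 2), dtype=jnp.int32)  # three agents at (0, 0)
actions = jnp.array([0, 2, 1])               # right, down, left
new_states = batched_step(states, actions)
```

<p>Because the function is pure, jit can compile it and vmap can batch it without any hidden state getting in the way.</p>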
Furthermore, the <strong>class attributes</strong> are all <strong>constants</strong> that won’t be modified after instantiation.</p><p>Let’s take a closer look:</p><ul><li><strong>__init__: </strong>In the context of our GridWorld, the available actions are <strong>[</strong>0, 1, 2, 3<strong>]</strong>. These actions are translated into a 2-dimensional array using <em>self.movements </em>and added to the state in the step function.</li><li><strong>_get_obs: </strong>Our environment is <strong>deterministic </strong>and <strong>fully observable</strong>, therefore the agent receives the state directly instead of a processed observation.</li><li><strong>_reset_if_done: </strong>The argument <em>env_state</em> corresponds to the (state, key) tuple where key is a <a href="https://jax.readthedocs.io/en/latest/jax.random.html"><em>jax.random.PRNGKey</em></a><em>. </em>This function simply returns the initial state if the <em>done </em>flag is set to True, however, we cannot use conventional Python control flow within JAX jitted functions. Using <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#cond"><em>jax.lax.cond</em></a><em> </em>we essentially get an expression equivalent to:</li></ul><pre>def cond(condition, true_fun, false_fun, operand):<br>  if condition: # if done flag == True<br>    return true_fun(operand)  # return self._reset(key)<br>  else:<br>    return false_fun(operand) # return env_state</pre><ul><li><strong>step: </strong>We convert the action to a movement and add it to the current state (<em>jax.numpy.clip </em>ensures that the agent stays within the grid). We then update the <em>env_state </em>tuple before checking if the environment needs to be reset. As the step function is used frequently throughout training, jitting it allows significant performance gains. 
The <em>@partial(jit, static_argnums=(0,)) </em>decorator signals that the “<em>self”</em> argument of the class method should be considered <strong>static</strong>. In other words, the <strong>class properties are constant</strong> and won’t change during successive calls to the step function.</li></ul><h4>Q-Learning Agent</h4><p>The Q-learning agent is defined by the <strong>update </strong>function, as well as a static <strong>learning rate</strong> and <strong>discount factor</strong>.</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/e4fa1d223ac28115e273440c1c292382/href">https://medium.com/media/e4fa1d223ac28115e273440c1c292382/href</a></iframe><p>Once again, when jitting the update function, we pass the “self” argument as static. Also, notice that the <em>q_values </em>matrix is updated with <em>set()</em>, which returns a new array (JAX arrays are immutable), and its value is not stored as a class attribute.</p><h4>Epsilon-Greedy Policy</h4><p>Finally, the policy used in this experiment is the standard <strong>epsilon-greedy policy</strong>. One important detail is that it uses <strong>random tie-breaks</strong>, which means that if the maximal Q-value is not unique, the action will be <strong>sampled uniformly</strong> from the <strong>maximal Q-values</strong> <em>(using argmax would always return the first action with maximal Q-value).</em> This is especially important if Q-values are initialized as a matrix of zeros, as the action 0 (move right) would always be selected.</p><p>Otherwise, the policy can be summarized by this snippet:</p><pre>action = lax.cond(<br>            explore, # if p &lt; epsilon<br>            _random_action_fn, # select a random action given the key<br>            _greedy_action_fn, # select the greedy action w.r.t Q-values<br>            operand=subkey, # use subkey as an argument for the above funcs<br>        )<br>return action, subkey</pre><p>Note that when we use a <strong><em>key </em></strong>in JAX <em>(e.g. 
here we sampled a random float and used random.choice) </em>it is common practice to split the key afterward <em>(i.e. “move on to a new random state”, more details </em><a href="https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html#random-numbers-in-jax"><em>here</em></a><em>).</em></p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/07f0a02c79b8c098bdd782d1156ca04d/href">https://medium.com/media/07f0a02c79b8c098bdd782d1156ca04d/href</a></iframe><h3>Single-agent training loop:</h3><p>Now that we have all the required components, let’s train a single agent.</p><p>Here’s a <strong><em>Pythonic</em> </strong>training loop: as you can see, we are essentially selecting an action using the policy, performing a step in the environment, and updating the Q-values, until the end of an episode. Then we repeat the process for <strong><em>N </em></strong>episodes. As we’ll see in a minute, this way of training an agent is quite <strong>inefficient</strong>; however, it summarizes the key steps of the algorithm in a readable way:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ea2f949e4a8a39891b27ab768e94500f/href">https://medium.com/media/ea2f949e4a8a39891b27ab768e94500f/href</a></iframe><p>On a single CPU, we complete 10,000 episodes in 11 seconds, at a rate of 881 episodes and 21,680 steps per second.</p><pre>100%|██████████| 10000/10000 [00:11&lt;00:00, 881.86it/s]<br>Total Number of steps: 238 488<br>Number of steps per second: 21 680</pre><p>Now, let’s replicate the same training loop using JAX syntax. 
Here’s a high-level description of the <strong>rollout</strong> function:</p><figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*qNcfSp0T2OeW7aXvWPpXSA.png" /><figcaption>Training rollout function using <strong>JAX syntax </strong>(made by the author)</figcaption></figure><p>To summarize, the rollout function:</p><ol><li><strong>Initializes </strong>the <strong>observations</strong>, <strong>rewards</strong>, and <strong>done </strong>flags as empty arrays with a dimension equal to the number of time steps using <em>jax.numpy.zeros. </em>The <strong>Q-values</strong> are initialized as an empty matrix with shape <strong>[</strong>timesteps<strong>+1</strong>, grid_dimension_x, grid_dimension_y, n_actions<strong>]</strong>.</li><li>Calls the <strong><em>env.reset()</em> </strong>function to get the initial state.</li><li>Uses the<strong><em> jax.lax.fori_loop()</em></strong> function to call a <strong><em>fori_body() </em></strong>function <strong><em>N</em></strong> times, where <strong><em>N</em></strong> is the <strong><em>timestep </em></strong>parameter.</li><li>The <strong><em>fori_body() </em></strong>function behaves similarly to the previous Python loop. After selecting an action, performing a step, and computing the Q-update, we update the obs, rewards, done, and q_values arrays at the current time step <em>(the Q-update targets the time step </em><strong><em>t+1</em></strong><em>)</em>.</li></ol><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/5e0d8b4186660c3c723fff427e76d747/href">https://medium.com/media/5e0d8b4186660c3c723fff427e76d747/href</a></iframe><p>This additional complexity leads to an <strong>85x speed-up</strong>: we now train our agent at roughly <strong>1.83 million steps per second</strong>. 
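</p><p>To make the pattern concrete, here is a self-contained sketch of the fori_loop structure (the environment transition and policy are stubbed out with a clipped random walk, and the Q-update is omitted; all names are illustrative):</p>

```python
import jax
import jax.numpy as jnp
from jax import random


def rollout(key, timesteps: int, n_actions: int = 4):
    # 1. preallocate trajectory arrays, one slot per time step
    obs = jnp.zeros((timesteps, 2), dtype=jnp.int32)
    rewards = jnp.zeros(timesteps)
    done_flags = jnp.zeros(timesteps, dtype=bool)

    # 2. initial state (stands in for env.reset())
    state = jnp.array([4, 4], dtype=jnp.int32)
    movements = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])

    def fori_body(t, carry):
        obs, rewards, done_flags, state, key = carry
        # random policy stub: a real agent would act epsilon-greedily w.r.t. Q-values
        key, subkey = random.split(key)
        action = random.randint(subkey, (), 0, n_actions)
        # environment step stub: move and clip to the grid
        state = jnp.clip(state + movements[action], 0, 4)
        done = jnp.all(state == 0)
        # functional "in-place" updates at index t via .at[].set()
        obs = obs.at[t].set(state)
        rewards = rewards.at[t].set(jnp.where(done, 1.0, 0.0))
        done_flags = done_flags.at[t].set(done)
        return obs, rewards, done_flags, state, key

    # 3. run the body N times with jit-compatible control flow
    carry = (obs, rewards, done_flags, state, key)
    obs, rewards, done_flags, _, _ = jax.lax.fori_loop(0, timesteps, fori_body, carry)
    return obs, rewards, done_flags


obs, rewards, done_flags = rollout(random.PRNGKey(42), timesteps=100)
```

<p>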
Note that here, the training is done on a <strong><em>single CPU</em></strong> as the environment is simplistic.</p><p>However, <strong>end-to-end vectorization scales even better</strong> when applied to <strong>complex environments</strong> and <strong>algorithms benefitting from multiple GPUs</strong> (<a href="https://chrislu.page/blog/meta-disco/">Chris Lu’s article</a> reports a whopping <strong>4000x speed-up</strong> between a CleanRL PyTorch implementation of PPO and a JAX reproduction).</p><pre>100%|██████████| 1000000/1000000 [00:00&lt;00:00, 1837563.94it/s]<br>Total Number of steps: 1 000 000<br>Number of steps per second: 1 837 563</pre><p>After training our agent, we plot the maximal Q-value for each cell (i.e. <em>state</em>) of the GridWorld and we observe that it has effectively learned to go from the initial state (bottom right corner) to the objective (top left corner).</p><iframe src="https://cdn.embedly.com/widgets/media.html?src=https%3A%2F%2Fplotly.com%2F%7ERyan_pgd%2F1.embed%3Fautosize%3Dtrue&amp;display_name=Plotly&amp;url=https%3A%2F%2Fchart-studio.plotly.com%2F%7ERyan_pgd%2F1%2F&amp;image=https%3A%2F%2Fchart-studio.plotly.com%2Fstatic%2Fwebapp%2Fimages%2Fplotly-logo.8d56a320dbb8.png&amp;key=a19fcc184b9711e1b4764040d3dc5c07&amp;type=text%2Fhtml&amp;schema=plotly" width="600" height="400" frameborder="0" scrolling="no"><a href="https://medium.com/media/30e43fb24a5519050ecfc1e370f31219/href">https://medium.com/media/30e43fb24a5519050ecfc1e370f31219/href</a></iframe><h3><strong>Parallel agents training loop:</strong></h3><p>As promised, now that we’ve written the functions required to train a <strong>single agent</strong>, we have little to no work left to train <strong>multiple agents</strong> in <strong>parallel </strong>on batched environments!</p><p>Thanks to <strong>vmap</strong> we can quickly transform our previous functions to work on batches of data. 
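</p><p>As a sketch of what this transformation looks like (using a toy step function with the same (state, key) tuple convention; the shapes and names are illustrative):</p>

```python
import jax.numpy as jnp
from jax import jit, random, vmap


# toy step function operating on a single agent: env_state is a (state, key) tuple
def step(env_state, action):
    state, key = env_state
    movements = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
    new_state = jnp.clip(state + movements[action], 0, 4)
    done = jnp.all(new_state == 0)
    reward = jnp.where(done, 1.0, 0.0)
    return (new_state, key), new_state, reward, done


# batch the function over the leading axis of every input and output
v_step = jit(vmap(step, in_axes=((0, 0), 0), out_axes=((0, 0), 0, 0, 0)))

n_agents = 30
states = jnp.tile(jnp.array([4, 4]), (n_agents, 1))   # (30, 2) batched states
keys = random.split(random.PRNGKey(0), n_agents)      # one PRNG key per agent
actions = jnp.zeros(n_agents, dtype=jnp.int32)        # one action per agent

(new_states, keys), observations, rewards, done_flags = v_step((states, keys), actions)
```

<p>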
We only have to specify the batched input and output axes, for instance for <strong><em>env.step:</em></strong></p><ul><li><strong>in_axes</strong> = ((0,0), 0) specifies the batched axis of each input, the inputs being the <em>env_state </em>tuple (axes (0, 0)) and an <em>action </em>(axis 0).</li><li><strong>out_axes </strong>= ((0, 0), 0, 0, 0) specifies the batched axis of each output, the output being ((env_state), obs, reward, done).</li><li>Now, we can call <strong><em>v_step </em></strong>on an <strong>array </strong>of <em>env_states </em>and <em>actions </em>and receive an <strong>array </strong>of processed <em>env_states</em>, <em>observations</em>, <em>rewards</em>, and <em>done flags.</em></li><li>Note that we also <strong>jit</strong> all batched functions for performance (arguably, jitting <em>env.reset() </em>is unnecessary given that it is only called once in our training function).</li></ul><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/ff345d5bfa6606bf53df5bbb041decbc/href">https://medium.com/media/ff345d5bfa6606bf53df5bbb041decbc/href</a></iframe><p>The last adjustment we have to make is to <strong>add a batch dimension</strong> to our arrays to account for each agent’s data.</p><p>By doing this, we obtain a function that allows us to train <strong>multiple agents in parallel</strong>, with minimal adjustments compared to the single-agent function:</p><iframe src="" width="0" height="0" frameborder="0" scrolling="no"><a href="https://medium.com/media/f3d04e39e05060ddf9b7e2cddfe44228/href">https://medium.com/media/f3d04e39e05060ddf9b7e2cddfe44228/href</a></iframe><p>We get similar performance with this version of our training function:</p><pre>100%|██████████| 100000/100000 [00:02&lt;00:00, 49036.11it/s]<br>Total Number of steps: 100 000 * 30 = 3 000 000<br>Number of steps per second: 49 036 * 30 = 1 471 080</pre><p>And that’s it! 
Thanks for reading this far! I hope this article provided a helpful introduction to implementing vectorized environments in <strong>JAX</strong>.</p><p>If you enjoyed the read, please consider <strong>sharing </strong>this article and <strong>starring </strong>my GitHub repository. Thanks for your support! 🙏</p><p><a href="https://github.com/RPegoud/jym">GitHub - RPegoud/jym: JAX implementation of RL algorithms and vectorized environments</a></p><p>Finally, for those interested in digging a little deeper, here’s a list of <strong>useful resources</strong> that helped me get started with JAX and write this article:</p><h3>A curated list of awesome JAX articles and resources:</h3><p>[1] Coderized, (functional programming) <a href="https://www.youtube.com/watch?v=HlgG395PQWw&amp;t=254s"><em>The purest coding style, where bugs are near impossible</em></a>, YouTube</p><p>[2] Aleksa Gordić, <a href="https://www.youtube.com/watch?v=SstuvS-tVc0&amp;list=PLBoQnSflObckOARbMK9Lt98Id0AKcZurq"><em>JAX From Zero to Hero YouTube Playlist</em></a><em> (2022), The AI Epiphany</em></p><p>[3] Nikolaj Goodger, <a href="https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba"><em>Writing an RL Environment in JAX</em></a><em> (2021)</em></p><p>[4] Chris Lu, <a href="https://chrislu.page/blog/meta-disco/"><em>Achieving 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL</em></a><em> (2023), </em><a href="https://www.ox.ac.uk/">University of Oxford</a>, <a href="https://www.foersterlab.com/">Foerster Lab for AI Research</a></p><p>[5] Nicholas Vadivelu, <a href="https://github.com/n2cholas/awesome-jax"><em>Awesome-JAX</em></a><em> (2020)</em>, a list of JAX libraries, projects, and resources</p><p>[6] JAX Official Documentation, <a href="https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html"><em>Training a Simple Neural Network, with PyTorch Data Loading</em></a></p><hr><p><a href="https://medium.com/data-science/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5">Vectorize and Parallelize RL Environments with JAX: Q-learning at the Speed of Light⚡</a> was originally published in <a href="https://medium.com/data-science">TDS Archive</a> on Medium, where people are continuing the conversation by highlighting and responding to this story.</p>]]></content:encoded>
        </item>
    </channel>
</rss>