<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <generator uri="http://jekyllrb.com" version="4.3.3">Jekyll</generator>
  
  
  <link href="http://sdbuchanan.com/feed.xml" rel="self" type="application/atom+xml" />
  <link href="http://sdbuchanan.com/" rel="alternate" type="text/html" />
  <updated>2026-01-26T10:05:52+00:00</updated>
  <id>http://sdbuchanan.com//</id>

  
    <title type="html">Sam D. Buchanan</title>
  

  
    <subtitle>Home</subtitle>
  

  
    <author>
        <name>Sam Buchanan</name>
      
        <email>sdbuchanan@berkeley.edu</email>
      
      
    </author>
  

  
  
    <entry>
      
      <title type="html">SPMD in JAX #3: Infrastructure Buildout</title>
      
      
      <link href="http://sdbuchanan.com/blog/jax-3/" rel="alternate" type="text/html" title="SPMD in JAX #3: Infrastructure Buildout" />
      
      <published>2025-12-10T00:00:00+00:00</published>
      <updated>2025-12-10T00:00:00+00:00</updated>
      <id>http://sdbuchanan.com/blog/jax-3</id>
      <content type="html" xml:base="http://sdbuchanan.com/blog/jax-3/">&lt;p&gt;In this third post in a series on programming with JAX for training
transformers, we pick up where we left off at the end of the &lt;a href=&quot;/blog/jax-2/&quot;&gt;previous post&lt;/a&gt;.
We’ll extend our previous small transformer model to the NanoGPT scale, incorporating
actual text data, distributed training, and a real optimizer, and perform some basic
evaluation! We’ll end up with a reasonable substrate for building future experiments on
top of.&lt;/p&gt;

&lt;p&gt;As in &lt;a href=&quot;/blog/jax-1/&quot;&gt;previous&lt;/a&gt; &lt;a href=&quot;/blog/jax-2/&quot;&gt;posts&lt;/a&gt;
in the series, we make good use of the Google scaling playbook &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;
for profiling.&lt;/p&gt;

&lt;ul id=&quot;markdown-toc&quot;&gt;
  &lt;li&gt;&lt;a href=&quot;#package-conversion-and-improvements&quot; id=&quot;markdown-toc-package-conversion-and-improvements&quot;&gt;Package Conversion and Improvements&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#config-upgrades&quot; id=&quot;markdown-toc-config-upgrades&quot;&gt;Config Upgrades&lt;/a&gt;        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#small-modifications-to-our-config-dataclass&quot; id=&quot;markdown-toc-small-modifications-to-our-config-dataclass&quot;&gt;Small Modifications to our Config Dataclass&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#model-improvements-and-refactors&quot; id=&quot;markdown-toc-model-improvements-and-refactors&quot;&gt;Model Improvements and Refactors&lt;/a&gt;        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#splash-attention-integration&quot; id=&quot;markdown-toc-splash-attention-integration&quot;&gt;Splash Attention Integration&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#optimizer-improvements&quot; id=&quot;markdown-toc-optimizer-improvements&quot;&gt;Optimizer Improvements&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#distributed-training&quot; id=&quot;markdown-toc-distributed-training&quot;&gt;Distributed Training&lt;/a&gt;        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#high-level-paradigm&quot; id=&quot;markdown-toc-high-level-paradigm&quot;&gt;High-Level Paradigm&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#setting-up-the-mesh-and-initializing-the-distributed-backend&quot; id=&quot;markdown-toc-setting-up-the-mesh-and-initializing-the-distributed-backend&quot;&gt;Setting up the Mesh and Initializing the Distributed Backend&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#modifications-for-data-loading&quot; id=&quot;markdown-toc-modifications-for-data-loading&quot;&gt;Modifications for Data Loading&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#remaining-cleanup&quot; id=&quot;markdown-toc-remaining-cleanup&quot;&gt;Remaining Cleanup&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#launching-everything&quot; id=&quot;markdown-toc-launching-everything&quot;&gt;Launching Everything&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#adding-logging&quot; id=&quot;markdown-toc-adding-logging&quot;&gt;Adding Logging&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#actual-data&quot; id=&quot;markdown-toc-actual-data&quot;&gt;Actual Data&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#checking-the-results&quot; id=&quot;markdown-toc-checking-the-results&quot;&gt;Checking the Results&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#conclusions&quot; id=&quot;markdown-toc-conclusions&quot;&gt;Conclusions&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#acknowledgments&quot; id=&quot;markdown-toc-acknowledgments&quot;&gt;Acknowledgments&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h1 id=&quot;package-conversion-and-improvements&quot;&gt;Package Conversion and Improvements&lt;/h1&gt;

&lt;p&gt;We’re running in Python 3.13, and we’ve upgraded to JAX 0.8.0 from
&lt;a href=&quot;/blog/jax-1/&quot;&gt;previous&lt;/a&gt; &lt;a href=&quot;/blog/jax-2/&quot;&gt;posts&lt;/a&gt;.
Now that we’ve written a good chunk of model, training, and config
code, we’ll also be working on top of a packaged version of our code.&lt;/p&gt;

&lt;p&gt;The code is hosted in this GitHub repository:
&lt;a href=&quot;https://github.com/sdbuch/baremetal-gpt&quot;&gt;baremetal-gpt&lt;/a&gt;. There are instructions to
install the repository as a package using &lt;a href=&quot;https://astral.sh/uv&quot;&gt;uv&lt;/a&gt; there.&lt;/p&gt;

&lt;p&gt;We’ll start by talking through some refactoring that leads us from the state of our code
in the &lt;a href=&quot;/blog/jax-2/&quot;&gt;previous post&lt;/a&gt; to the state of the code in the
GitHub repository.&lt;/p&gt;

&lt;h2 id=&quot;config-upgrades&quot;&gt;Config Upgrades&lt;/h2&gt;

&lt;p&gt;To run experiments with our code in a distributed (i.e., multi-host) setting, we’ll
benefit from having a robust way to dispatch different experimental configurations from
the command line.
We’ll use &lt;a href=&quot;https://hydra.cc&quot;&gt;Hydra&lt;/a&gt; for this purpose.
Hydra is not the most minimal solution, but it’s the most feature-complete library for
managing configs that I know of (e.g., specifying lists as arguments), and for our
simple purposes we’ll be able to adopt it with only a small amount of added code and
complexity.&lt;/p&gt;

&lt;h3 id=&quot;small-modifications-to-our-config-dataclass&quot;&gt;Small Modifications to our Config Dataclass&lt;/h3&gt;

&lt;p&gt;Hydra supports registering config dataclasses as so-called “Structured Configs”.
&lt;a href=&quot;https://hydra.cc/docs/tutorials/structured_config/minimal_example/&quot;&gt;This tutorial&lt;/a&gt;
gives a good description of the core functionalities we need.
We only need to make a few small modifications to our previous dataclass:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Use only supported types: instead of random class objects (e.g., for dtypes or
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PartitionSpec&lt;/code&gt;s), we need to create &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Enum&lt;/code&gt;s for passing values, and disambiguate
them at call time. There are other minor caveats (e.g., can’t arbitrarily nest lists;
restrictions on use of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Optional&lt;/code&gt;).&lt;/li&gt;
  &lt;li&gt;Make the config mutable (remove &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;frozen=True&lt;/code&gt;).&lt;/li&gt;
  &lt;li&gt;Re-register the config as static with JAX. Hydra will wrap our dataclass so that it
ends up duck-typed as a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; object, but we will need to re-register the wrapped
config as static. Since the dataclass is no longer frozen, we can do this with a
post-init function, which we’ll now call when we launch our training loop.&lt;/li&gt;
  &lt;li&gt;Create a function to register the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; dataclass in a Hydra &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ConfigStore&lt;/code&gt;
instance. In general, we can create arbitrarily stratified configs for easier
management (e.g., separate configs for model, experiment, data, etc.). We’ll call
this function in our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; method.&lt;/li&gt;
  &lt;li&gt;Refactor the training loop into a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt; file with a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; method that we’ll
call when we want to launch an experiment, and add the necessary Hydra setup code.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;We will also make a larger refactor, which will help us later with setting up config
YAML files for experiments: we will change from a flat config to a hierarchical config,
with different parameters split into different sub-configs (e.g., for the optimizer, for
data, …). This is not an objectively superior design – for example, it’s arguably a
bit easier to write new code and integrate static type checking with a flat config –
but we’ll follow it in order to add a bit more encapsulation to distinct config
parameters.&lt;/p&gt;

&lt;p&gt;Here’s a representative snipped from the new &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.py&lt;/code&gt; with the modifications discussed
above. There are many forward-looking modifications too, which we just show
out-of-context selections of.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# config.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# We omit the sub-configs, and other nitty-gritty details...
# Config is now top-level (orchestration parameters, etc.)
&lt;/span&gt;

&lt;span class=&quot;nd&quot;&gt;@dataclass&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kw_only&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;unsafe_hash&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Overall config class, containing top-level orchestration parameters&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1337&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logger_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LoggerType&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LoggerType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WANDB&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;project_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;bmgpt-debug&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;val_log_interval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1000&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# log validation metrics every &amp;lt;this many&amp;gt; batches
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MISSING&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;val_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;EvaluationConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# validation metrics
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;EvaluationConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptimizerConfig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;OptimizerConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ModelConfig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ModelConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;inference&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;InferenceConfig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;InferenceConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ShardingConfig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;field&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;default_factory&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ShardingConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;register_configs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ConfigStore&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;instance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_optimizer&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;OptimizerConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_model&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ModelConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;inference&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_inference&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;InferenceConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_sharding&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ShardingConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_config&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;node&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;config_post_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Input validation.&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;And here’s the relevant parts of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt; that we modify. There is some extra code for
distributed initialization, which we discuss in more detail later in the post.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pathlib&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Path&lt;/span&gt;

&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hydra&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bmgpt.config&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config_post_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;register_configs&lt;/span&gt;

&lt;span class=&quot;nf&quot;&gt;register_configs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;

&lt;span class=&quot;nd&quot;&gt;@hydra.main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;version_base&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_path&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;Path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;configs&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;absolute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;resolve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_config&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;try&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Launch distributed and register configs
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distributed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;initialize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree_util&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;register_static&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;except&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;RuntimeError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# This implies the distributed backend has already been initialized
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;pass&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;config_post_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__name__&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;__main__&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The modifications are quite simple: we register our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; dataclass in the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ConfigStore&lt;/code&gt; instance as &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'config'&lt;/code&gt;, then decorate our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; method in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt; to
use that appropriate config. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'config_path'&lt;/code&gt; argument in that decorator specifies a directory where we store
different &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;.yaml&lt;/code&gt; config files specifying different experimental setups (e.g., different
model architectures, different datasets, different distributed training environments,
etc.). Then we just call our post-init config method right at the start of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt;: right
now, it just validates some config inputs.&lt;/p&gt;

&lt;p&gt;When we want to launch an experiment now, we can run a command like the following (train
loop refactoring follows later in the blog):&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;uv run &lt;span class=&quot;nt&quot;&gt;--extra&lt;/span&gt; tpu train model.num_layers&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;1
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Here, we are using &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;uv&lt;/code&gt; optional dependencies so that we can run on either CPU or TPU
depending on the platform, as well as a build script in our package (see the
&lt;a href=&quot;https://github.com/sdbuch/baremetal-gpt/blob/main/pyproject.toml&quot;&gt;pyproject.toml&lt;/a&gt;).
Super easy! Anything that isn’t specified follows our defaults from the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt;
dataclass definition.&lt;/p&gt;

&lt;p&gt;As the experiments we want to run become more diverse and complex, it becomes harder and
harder to productively specify all default overrides via the command line. To address
this, we can write YAML configs for different experiments as we mentioned above. For
example, we can make a file at &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;configs/experiments/text_toy.yaml&lt;/code&gt;, with:&lt;/p&gt;

&lt;div class=&quot;language-yaml highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# @package _global_&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;defaults&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;/dataset/staircase@train_dataset&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;100&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;na&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;NUMBER_STAIRCASE&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;TEST&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;1&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# At inference time, this functions like prompt size&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;16&lt;/span&gt;
    &lt;span class=&quot;na&quot;&gt;evaluator&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;AUTOREGRESSIVE_ROLLOUTS&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;transformer_type&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;DISCRETE&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;is_causal&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;True&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;10&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;See the Hydra docs for more details about how to parse this file – it has some
boilerplate to ensure it overrides the config defaults we specify in our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.py&lt;/code&gt;.
We’ll discuss some of the enums that appear here (you can identify them by the all-caps
text) later on. Then to run an experiment, we can simply override as:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;uv run &lt;span class=&quot;nt&quot;&gt;--extra&lt;/span&gt; tpu train +experiment&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;text_toy
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In addition, any other command line arguments we want to override with can be added!&lt;/p&gt;

&lt;h2 id=&quot;model-improvements-and-refactors&quot;&gt;Model Improvements and Refactors&lt;/h2&gt;

&lt;p&gt;We’ll add support for the Pallas Splash Attention kernel, which will help us with model
scaling later!&lt;/p&gt;

&lt;h3 id=&quot;splash-attention-integration&quot;&gt;Splash Attention Integration&lt;/h3&gt;

&lt;p&gt;Splash Attention stands for “sparse Flash Attention” &lt;a class=&quot;citation&quot; href=&quot;#Pagliardini2023-nr&quot;&gt;(Pagliardini et al., 2023)&lt;/a&gt;.
Flash Attention is a well-known
memory-efficient implementation of the attention operation &lt;a class=&quot;citation&quot; href=&quot;#Dao2022-io&quot;&gt;(Dao et al., 2022)&lt;/a&gt;; it
also is often used to refer to the associated CUDA kernel and surrounding PyTorch
package for the algorithm.
Compared to flash attention, splash attention adds better support for sparse attention
masks, which arise frequently in practice: for example, when we want to pack many
relatively short sequences into one long sequence, or precisely mask between document
boundaries, we need to make sure not to attend across these boundaries.&lt;/p&gt;

&lt;p&gt;The types of optimizations that flash/splash attention entail require a level of
control over hardware-level execution that is not possible within the standard JAX-XLA
programming pipeline. So JAX exposes a domain-specific language for writing these
kernels, called Pallas! The &lt;em&gt;authoring&lt;/em&gt; of Pallas kernels is out-of-scope for this post,
but we will be interested in learning how to &lt;em&gt;call&lt;/em&gt; pre-written Pallas kernels so that
we can integrate splash attention.&lt;/p&gt;

&lt;p&gt;The interfaces for achieving this are naturally a bit rough around the edges, with scant
documentation and relatively few examples for how to proceed. With the help of multiple
LLMs, I eventually pieced together a solution based on example code available in the JAX
Pallas tests. Here’s a high-level overview, with the code to follow:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;We need to configure the Pallas kernel itself based on how we’ll use it at runtime:
specifically, the sequence length, number of heads, base mask type, and parallelism
configuration. The latter has to be enforced manually with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;shard_map&lt;/code&gt;. Once we figure
out where
everything needs to go, this works more-or-less okay with
our current model implementation structure, which uses &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Explicit&lt;/code&gt; mesh axes and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt;
as much as possible to simplify ‘inner’ model code. However, it currently doesn’t work
in JAX &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.8.1&lt;/code&gt;. It should work in the next release, thanks to a &lt;a href=&quot;https://github.com/jax-ml/jax/issues/33548&quot;&gt;bug
report&lt;/a&gt; I submitted and a fix from Yash
Katariya (thanks!).&lt;/li&gt;
  &lt;li&gt;Once we have the kernel (represented as a tuple of two &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Callable&lt;/code&gt;s, one corresponding
to the actual kernel and one corresponding to a wrapper with the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;shard_map&lt;/code&gt;
configuration), we need to thread it through our top-level &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_transformer&lt;/code&gt; call down to
the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; call. A more ‘elegant’ approach to doing this (at least at the call site)
might be to use some metaprogramming, but we will go with the simplest solution.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The code for the kernel setup is in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;splash_helpers.py&lt;/code&gt;:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# splash_helpers.py
&lt;/span&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;functools&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;partial&lt;/span&gt;

&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bmgpt.config&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax.experimental.pallas.ops.tpu.splash_attention&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;BlockSizes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;CausalMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;MultiHeadMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;make_splash_mha&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_splash_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;head_shards&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q_seq_shards&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_splash&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# None ends up calling jax-xla attention
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# see _attn
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# s is Q len (seq_len @ train; variable/1 at prefill/decode)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# t is K len (s + cache_capacity)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_causal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;MultiHeadMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;
                &lt;span class=&quot;nc&quot;&gt;CausalMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;MultiHeadMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;FullMask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;or&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# splash attention kernel requires block size to be a multiple of 128
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;raise&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;NotImplementedError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Splash block size needs to be a multiple of 128&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_sizes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;BlockSizes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_q&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_kv&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_kv_compute&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_q_dkv&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_kv_dkv&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_kv_dkv_compute&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_q_dq&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_kv_dq&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;splash_sharding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_splash_mha&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;head_shards&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;head_shards&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q_seq_shards&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_seq_shards&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_sizes&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_sizes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kernel_spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;manual_sharding_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;splash_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shard_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;in_specs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_specs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;splash_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;check_vma&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;splash_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;segment_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;segment_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;segment_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;nf&quot;&gt;return &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;splash_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The function argument &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cache_capacity&lt;/code&gt; specifies
the maximum KV cache size we’ll use (we’ve split this off from &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_seq_len&lt;/code&gt; which
filled a dual function encompassing this previously).
This allows us to specify a cache size of 0 for training, for example.
Within the function body, we have a new configuration parameter &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;use_splash&lt;/code&gt; to denote whether splash
attention should be used or not. For example, we always fall back to XLA attention for
autoregressive inference, since the splash attention Pallas kernel requires the block
size to be a multiple of 128 (and for decode, we almost necessarily have to process a
query block size of 1). In this case, the kernel creation factory just returns &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt;,
and we check for this in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt;.&lt;sup id=&quot;fnref:7&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:7&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;The “sparse” part of splash attention comes from the proper use of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;segment_ids&lt;/code&gt;
argument: distinct documents/sub-sequences can be assigned different &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;segment_ids&lt;/code&gt;.
We’ll see an example below, in how to set this up for use during training with respect
to the KV cache.&lt;/p&gt;

&lt;p&gt;First, at the level of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt;, we now create our kernels when we set up. For
training, this leads to a simple &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel = make_splash_kernel(config,
config.train_dataset, 0, mesh)&lt;/code&gt; statement. We then pass the kernel to any model call we
have: for example, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;logits, _ = jax.vmap( partial(_transformer, config, kernel, params,
cache_params=cache_params))(inputs, train_state.kv_cache)&lt;/code&gt; in the training loop step.
Here, we’ve refactored the way we use the KV cache – a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CacheParams&lt;/code&gt; class is defined
simply as follows:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CacheParams&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;enabled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;This way, as we’ll see momentarily, we can easily trace with respect to the size
argument (i.e., for inference) while still omitting all KV cache maintenance code at
compile time if we want to train (i.e., by setting &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;enabled=False&lt;/code&gt;).&lt;/p&gt;

&lt;p&gt;The kernel argument and this new &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cache_params&lt;/code&gt; argument get threaded down to the level
of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt;, which now reads as follows (abbreviating unchanged components):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# model.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_causal_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;(seq_len_q, seq_len_k) bool array with True for past token positions&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q_positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k_positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;(1, seq_len_k) bool array with True for actual cache+context positions&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k_positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 2 x cache_capacity x n x h
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CacheParams&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# s: sequence length
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# d: embedding dim (config.d_model)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# n: attention heads (config.num_heads)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# h: head dim (config.d_head)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# x_seq: s x d
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sd,d3nh-&amp;gt;3nsh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# skipping rope...
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enabled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;k_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dynamic_update_slice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;v_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dynamic_update_slice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k_cache_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_cache_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Cache read scheme: to enable same mask for the same s value (Q seq len),
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#  we concatenate the full cache to K, and mask empty entries
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# For efficient training, set cache size zero + cache_params.enabled=False
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Attention
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# t = s + cache_capacity
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q_segment_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;~&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;~&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_segment_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;segment_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;SegmentIds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_segment_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_segment_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;splash_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;splash_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.25&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.25&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;segment_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Make mask
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_causal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_causal_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cache_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;o&quot;&gt;~&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_mask&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;~&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;_make_cache_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# full attention
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# broadcast over heads
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# Scale and causal mask
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nsh,nth-&amp;gt;nst&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;where&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# type: ignore[reportArgumentType]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nst,nth-&amp;gt;nsh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nsh,hnd-&amp;gt;sd&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# ... skipping to the cache init function
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sharding_batch_layer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sharding_batch_layer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_batch_layer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cache_capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The splash Pallas kernel expects the shape to be &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;N x S x H&lt;/code&gt;, so we refactor some code
compared to what we had before. The key change here is that in order to repeatedly reuse
the same splash Pallas kernel as we grow the cache, we write the code so that the KV
sequence length is always fixed as the cache capacity plus &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;S&lt;/code&gt; (the Q sequence length).
The cache capacity is specified dynamically at initialization, allowing us to use the
same &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; object for different training/inference scenarios (e.g., cache size 0 for
training, and some fixed positive integer for autoregressive inference).
This means we need to construct a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;SegmentIds&lt;/code&gt; object that masks out all the unused
portions of the KV cache, which is performed correctly by the above code. Previously, we
were just updating the cache with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.lax.dynamic_update_slice&lt;/code&gt;, rather than
concatenating, but with the splash kernel structure, this would require us to change the
mask after every additional cache update (requiring a new kernel!).&lt;/p&gt;

&lt;p&gt;That’s it! We’ve integrated splash attention, allowing us to scale much more effectively
to larger model sizes down the road.&lt;/p&gt;

&lt;h2 id=&quot;optimizer-improvements&quot;&gt;Optimizer Improvements&lt;/h2&gt;

&lt;p&gt;We were using 100% vanilla SGD in the previous post. Let’s upgrade that to a much
stronger optimizer, AdamW! &lt;a class=&quot;citation&quot; href=&quot;#adamw-paper&quot;&gt;(Loshchilov &amp;amp; Hutter, 2017)&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In our previous post, we used a very simple tree-mapped update, facilitated by the
simple structure of constant-learning-rate SGD:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;value_and_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For implementing Adam and AdamW (the “W” denotes “decoupled weight decay”), it’s better
to use a more flexible API: both because the updates become harder to express in a
single concise &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;lambda&lt;/code&gt;, and because we generally want to be able to apply different
updates to different parameters.&lt;sup id=&quot;fnref:1&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt; Here’s how we’ll do this, at a basic level:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;We’ll define a handfull of optimizer updates that share the same API.&lt;/li&gt;
  &lt;li&gt;We’ll allow one update to be selected via the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; class (i.e., with the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Enum&lt;/code&gt;
we mentioned above), and then we’ll &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.tree.map&lt;/code&gt; it over the model pytree.&lt;/li&gt;
  &lt;li&gt;We’ll be able to provide a pytree of arguments to the update, as necessary, allowing
to do different things for different parameters.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;For example, here’s how this looks for the previous constant-learning-rate SGD update.
Since AdamW requires us to save &lt;em&gt;state&lt;/em&gt; information (i.e., first-order and second-order
momentum buffers), we’ll incorporate that into the API.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Factory for getting the update, used externally
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;opt_update_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;match&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ADAMW&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;adamw_update&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sgd_update&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1st moment EMA
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 2nd moment EMA
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# step number
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_adam_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sgd_update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wd_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wd_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Apply weight decay
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weight_decay&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The update API takes parameters, gradients, state, and additional flags, and returns the
update direction and the updated state. We add the update direction to the parameters
externally since this is a useful debugging abstraction (and matches the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;optax&lt;/code&gt; API).&lt;/p&gt;

&lt;p&gt;Here’s how it looks to apply the optimizer now. It gets a bit more cumbersome to do this
with only pytrees, since the update returns a tuple – meaning the output pytree
structure is different from the model/grad pytree structure. The latter is a &lt;em&gt;prefix&lt;/em&gt; of
the former, though, so we can still extract what we need fairly easily.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# In train.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Any&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model_params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;adam_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_adam_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adam_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;opt_update_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;    &lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;donate_argnums&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;value_and_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Transpose the output tree to get update tree and state tree
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;This is pretty straightforward, once we’ve written it!&lt;/p&gt;

&lt;p&gt;Now, to add a new optimizer – AdamW in our case – we just need to define the
appropriate &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;opt_update&lt;/code&gt; and possibly define new state initialization functions. We’ve
already done the latter above for Adam. Here’s a simple implementation of AdamW that
agrees with the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;optax&lt;/code&gt; implementation (not bitwise):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;adamw_update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wd_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;beta1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;beta1&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;beta2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;beta2&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;eps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eps_adam&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;weight_decay&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weight_decay&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;OptState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;mu_debias&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;nu_debias&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nu&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mu_debias&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nu_debias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wd_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Apply weight decay
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weight_decay&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;param&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The slightly tricky part remaining is how to pass the right weight decay mask pytree.
It’s best to do this with model &lt;strong&gt;metadata&lt;/strong&gt;, since the format we store the parameters in
doesn’t necessarily match how they’re used. For example, we store our attention QKV
parameters as a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;D x 3 x N x H&lt;/code&gt; array, where &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;D&lt;/code&gt; is the model dimension, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;N&lt;/code&gt; the number
of attention heads, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;H&lt;/code&gt; the projected attention head dimension. We &lt;em&gt;use&lt;/em&gt; these
parameters as &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;D x H&lt;/code&gt; matrices, but we can’t safely infer this from &lt;em&gt;only&lt;/em&gt; the shape
information! Similarly, since we &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Blocks&lt;/code&gt; initialization over the layer
dimension &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;L&lt;/code&gt; when constructing the model, our MLP output bias parameters are shaped
like &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;L x D&lt;/code&gt;. This is a matrix, so when iterating over the model pytree, we can’t just
check whether the current leaf has &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ndim &amp;gt; 1&lt;/code&gt;!&lt;/p&gt;

&lt;p&gt;This is a situation where we clearly see the advantage of the bigger, standard model
libraries (e.g. Equinox): there, metadata and parameters can easily be stored together.
For us, we’d need to depart from our appealingly simple &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedTuple&lt;/code&gt; model architecture
to do this (e.g. register custom pytree leaves). So we’ll instead use a slightly
unpleasant approach: we’ll just add a function to our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model.py&lt;/code&gt; file that returns a
model “spec” associated to the architecture, containing all relevant metadata.&lt;sup id=&quot;fnref:8&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:8&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in model.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Any&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Make the spec (we need some way to pass metadata around)
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_spec_from_str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;param_str&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__str__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;matrix_axes_dict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.w_qkv&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.w_o&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.w_up&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.w_down&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.w&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;        &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;matrix_axes_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map_with_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_spec_from_str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The keys in the dictionary used to generate the output pytree correspond to the
parameter names in our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Transformer&lt;/code&gt; model. We can get these from the pytree using
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.tree.map_with_path&lt;/code&gt;, which gives the pytree structure (easily converted into a
string, like &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;.blocks.norm_attn.gamma&lt;/code&gt;) in addition to the leaves.&lt;/p&gt;

&lt;p&gt;The returned pytree looks like the following:&lt;/p&gt;

&lt;div class=&quot;language-text highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Transformer(emb=Embedding(w=(-2, -1)), blocks=Block(norm_attn=LayerNorm(gamma=None, beta=None), attn=Attn(w_qkv=(-4, -1), w_o=(-3, -1)), norm_mlp=LayerNorm(gamma=None, beta=None), mlp=Mlp(w_up=(-2, -1), bias_up=None, w_down=(-2, -1), bias_down=None)), unemb=Unembedding(w=(-2, -1)))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Now, back to the matter at hand: we need to generate a mask pytree that tells us which
parameters in the model to apply weight decay to. To apply it only to the matrix
parameters in the model, we can do the following:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;bool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Since &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt; is not truthy, this will give us the mask we desire! Then we use it as
above, where we already provided &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;weight_decay_mask&lt;/code&gt; as an argument to the update
calculation tree map (but didn’t specify how to generate it).&lt;/p&gt;

&lt;p&gt;Here’s the entire updated train step. We add gradient clipping to it, since this doesn’t
require too much effort:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in optimizers.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;grad_norm_and_clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Gradient norms in param precision (NOTE: might want fp32?)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;grad_norms_squared&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;reduce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;operator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad_norms_squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;truncated_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;select&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clip_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;ones_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;return &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;truncated_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;grad_norms_squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;It’s quite straightforward overall!&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;  &lt;span class=&quot;c1&quot;&gt;# Initialize state, configure forward pass and optimization
&lt;/span&gt;  &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CacheParams&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enabled&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_splash_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;opt_update_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;bool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


  &lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;donate_argnums&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
              &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
          &lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
          &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;take_along_axis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

      &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;value_and_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;grad_clipped&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;grad_norm_and_clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
          &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;grad_clipped&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;c1&quot;&gt;# Transpose the output tree to get update tree and state tree
&lt;/span&gt;      &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
          &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
      &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

      &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;batch_loss&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;grad_norm&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
      &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;distributed-training&quot;&gt;Distributed Training&lt;/h2&gt;

&lt;p&gt;Next, let’s set up distributed training. On the whole, this is actually very easy to do
on Google’s TPU VMs!&lt;sup id=&quot;fnref:2&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;4&lt;/a&gt;&lt;/sup&gt; This is because:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;The VMs are already configured with the IP addresses / ICI information of the other
hosts in the slice.&lt;/li&gt;
  &lt;li&gt;We’re using &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.sharding.AxisType.Explicit&lt;/code&gt; sharding, and our model is relatively
simple. So communication gets handled by the XLA compiler! We don’t actually need to
use &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pmap&lt;/code&gt; or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pjit&lt;/code&gt; anywhere.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;For our purposes, the only part that requires care is data management, particularly data
loading. We’ll describe this below for our simple synthetic data distribution, as well
as the other changes we need to make for distributed training.&lt;/p&gt;

&lt;h3 id=&quot;high-level-paradigm&quot;&gt;High-Level Paradigm&lt;/h3&gt;

&lt;p&gt;In JAX distributed training, each host (or “process”) is physically connected to some
number of chips via PCIe. On &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v4&lt;/code&gt; TPUs, which we’ve been using, each host
is connected to 4 chips (each chip with 2 cores). When working with TPU VMs, we can
think of one host as corresponding to one VM that we can &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ssh&lt;/code&gt; into; in our tests
previously, we’ve been working on a single host.&lt;/p&gt;

&lt;p&gt;Different hosts are arranged into “slices”. Within a slice, chips are interconnected
via high-bandwidth optical links, letting them communicate relatively efficiently! When
it comes to data movement from the host to the device, though, each host can only
directly communicate with the chips it is physically connected to. These are called the
“addressable” devices of that host.&lt;/p&gt;

&lt;p&gt;With &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.distributed&lt;/code&gt;, we can programmatically handle these differences while only
writing a single program! Our program gets run identically on each host, but we have
access to certain functions that can allow us to branch and shard differently depending
on which host we are. Key functions that we’ll see below are:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.distributed.initialize()&lt;/code&gt;, which synchronizes information about all
interconnected devices among the different hosts;&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.process_index()&lt;/code&gt;, which returns a different integer for each host (allowing
branching);&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.Array.addressable_shards&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.Array.is_fully_addressable&lt;/code&gt; (etc.), which can be
used to programmatically determine &lt;em&gt;which devices&lt;/em&gt; can directly modify a given array.
For example, if &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; is a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.Array&lt;/code&gt; that is fully replicated across devices (with
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.P()&lt;/code&gt; as its partition spec), then &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x.is_fully_addressable&lt;/code&gt; is &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;False&lt;/code&gt; in a
multi-host setting! We want to avoid having our program modify an array in a way that
is not compatible with either its sharding or the devices with respect to which it is
addressable.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For more detailed information, see Chapter 2 of the TPU scaling book &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt; as well as relevant sections of the JAX documentation &lt;a class=&quot;citation&quot; href=&quot;#jax-multi-process&quot;&gt;(JAX Team, 2025)&lt;/a&gt;.&lt;/p&gt;

&lt;h3 id=&quot;setting-up-the-mesh-and-initializing-the-distributed-backend&quot;&gt;Setting up the Mesh and Initializing the Distributed Backend&lt;/h3&gt;

&lt;p&gt;One multi-host footgun I learned about the hard way is that setting a global mesh with
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.set_mesh&lt;/code&gt;, as we have been doing so far, makes it hard to create arrays that exist
&lt;em&gt;only on the local process&lt;/em&gt; (more on that below). So we switch to creating a mesh from
our config options at the start of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; function, then instantiate it as a context
manager when we need it. This turns out to be pretty easy—we don’t need to pass the
mesh to all parts of our model’s forward pass and add it as an additional parameter to
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;out_sharding&lt;/code&gt; (etc.)!&lt;/p&gt;

&lt;p&gt;Here’s the new setup (which we saw part of above):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in config.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;mesh_from_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_axis_names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AxisType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Explicit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# in train.py
&lt;/span&gt;&lt;span class=&quot;nd&quot;&gt;@hydra.main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;version_base&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_path&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;Path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;configs&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;absolute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;resolve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_config&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;try&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Launch distributed and register configs
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distributed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;initialize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree_util&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;register_static&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;except&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;RuntimeError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# This implies the distributed backend has already been initialized
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;pass&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;config_post_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;mesh_from_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;That’s it! We just have to be sure to set the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.mesh_shape&lt;/code&gt; parameter
appropriately when we launch a job (see below), as well as all our desired sharding
settings. Our config makes simple data parallel training over the entire list of devices
the default.&lt;/p&gt;

&lt;h3 id=&quot;modifications-for-data-loading&quot;&gt;Modifications for Data Loading&lt;/h3&gt;

&lt;p&gt;We refactor the data generation code into &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data.py&lt;/code&gt;. At a high level, here’s what we end
up doing in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt; to set up the data loading:&lt;sup id=&quot;fnref:4&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:4&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;5&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in train.py
# ...
# Randomness
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;key_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_eval&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Data
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_distributed_batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We abstract the previous dataset-specific code into this factory function, which calls
the appropriate sub-functions (letting us reuse this for val/eval):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in data.py
&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dataset_dataloader_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;match&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetName&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MNIST&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;return &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;load_mnist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataloader_without_replacement&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetName&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NUMBER_STAIRCASE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;return &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;nf&quot;&gt;make_number_staircase_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;dataloader_with_replacement&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetName&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SHAKESPEARE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;return &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;nf&quot;&gt;load_shakespeare&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;dataloader_with_replacement&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Helper to just get the global batch iterator, if we don't need the local data/loader
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_distributed_batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataset_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataloader&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dataset_dataloader_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_dataset_on_device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dataloader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataset_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Distributed data loading has a large number of potential footguns; the JAX Advanced
Guides listing has a useful entry about this &lt;a class=&quot;citation&quot; href=&quot;#jax-distributed-data-loading&quot;&gt;(JAX Team, 2025)&lt;/a&gt;,
and see also the article about explicit sharding we used in previous blogs &lt;a class=&quot;citation&quot; href=&quot;#jax-explicit-sharding&quot;&gt;(JAX Team, 2025)&lt;/a&gt;.
In the present setting, we take advantage of the fact that we can load the entire
dataset into memory on each host. The following code does this:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_number_staircase_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Simple data sequence!
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Small, so mega replicate for ease
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;text&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;012345678987654321&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xdev&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xte&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;split_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;match&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;SplitType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;TRAIN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;SplitType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;TEST&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xte&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;SplitType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;VAL&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xdev&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;split_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_fraction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dev_fraction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_fraction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Xdev&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_fraction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_fraction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dev_fraction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Xte&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_fraction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dev_fraction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xdev&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xte&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The tricky part here is to note that as configured, the above code generates the data
array so that it is &lt;em&gt;locally addressable&lt;/em&gt; by every process: each host creates its own
(identical) &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data&lt;/code&gt; array. More precisely, if we evaluate &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data.is_locally_addressable&lt;/code&gt;,
it will be &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;True&lt;/code&gt;. This must be contrasted with the case of a &lt;em&gt;fully replicated&lt;/em&gt; data
array, which we don’t want here—our approach will be to piece together a global batch
out of local batches on each device, and having a locally addressable data array is
(currently) the canonical way to do this!&lt;/p&gt;

&lt;p&gt;Next, we need to make sure each host has access to an independent and identically
distributed stream of batches. This gets handled by the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;get_dataset_on_device&lt;/code&gt; and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dataloader&lt;/code&gt; functions.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in data.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Simple next-token prediction dataloader for text training:
# - Each host has the whole dataset (data input)
# - Random sampling without replacement to draw batches (consumes key)
# - Targets are the next token (expect data to be (num_data,) shape)
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dataloader_with_replacement&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataloaderOutputType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;fold_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;process_index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;count&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;fold_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;process_count&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),),&lt;/span&gt;
            &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;seqs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;yield&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# Helper to map a dataloader with make_array_from_process_local_data
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_dataset_on_device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataloader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataloaderOutputType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Mesh&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DataloaderOutputType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_array_from_process_local_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dataloader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We &lt;strong&gt;must&lt;/strong&gt; pass the mesh we created earlier here in order to correctly create the
global batch out of the local shards that the dataloader generates. Explicit sharding
seems to propagate to outputs, so the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data.at&lt;/code&gt; operations return locally-addressable
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch&lt;/code&gt; arrays on each process, which &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;make_array_from_process_local_data&lt;/code&gt; pieces
together into the global array. To generate independent local batches on each process,
we are sure to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;fold_in&lt;/code&gt; the process index to the RNG in order to create a distinct
stream of random offsets.&lt;/p&gt;

&lt;p&gt;This quick-and-dirty data loader, which does not shuffle nor guarantee sampling without
replacement for batches across processes, is fine for our purposes on this synthetic
dataset. It should also be okay for use with much larger datasets, when training for a
relatively small number of tokens. But it’s worth keeping in mind that something more
robust would be needed in general.&lt;sup id=&quot;fnref:5&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:5&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;6&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;h3 id=&quot;remaining-cleanup&quot;&gt;Remaining Cleanup&lt;/h3&gt;

&lt;p&gt;All that’s left is to judiciously insert our mesh context manager at key points of the
training loop. We saw above (when discussing optimizers) that we need to add it to the
model initialization call.
Other than that, we just need to add it to the model call in our training loop:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# in train.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;next&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cur_metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For simplicity’s sake, we also wrap our sampling code with it.&lt;/p&gt;

&lt;h3 id=&quot;launching-everything&quot;&gt;Launching Everything&lt;/h3&gt;

&lt;p&gt;Given our DIY ethos, we handle launching distributed jobs on our TPU VM with an ad-hoc
launch script:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c&quot;&gt;# run.sh&lt;/span&gt;

&lt;span class=&quot;nv&quot;&gt;TPU_NAME&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$1&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;
&lt;span class=&quot;nb&quot;&gt;shift
&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;SSH_FLAGS&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s1&quot;&gt;'-A -o ForwardAgent=yes'&lt;/span&gt;
&lt;span class=&quot;nv&quot;&gt;COMMANDS&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;if [ ! -d &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&quot;&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;baremetal-gpt&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&quot;&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt; ]; then git clone git@github.com:sdbuch/baremetal-gpt; fi &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; export HYDRA_FULL_ERROR=1 &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; export WANDB_ENTITY='&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$WANDB_ENTITY&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;' &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; export WANDB_API_KEY='&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$WANDB_API_KEY&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;' &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; export HF_TOKEN='&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$HF_TOKEN&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;' &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; cd baremetal-gpt &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git fetch &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git checkout -f main &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git pull &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; uv sync --extra tpu &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; uv run train &lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$@&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;

gcloud compute tpus tpu-vm ssh &lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$TPU_NAME&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--ssh-flag&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$SSH_FLAGS&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--command&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$COMMANDS&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--worker&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;all
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For logging purposes, we need the Weights &amp;amp; Biases default entity and your API key to be
exported as environment variables locally. We use Github for code synchronization, and
provide the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tpu&lt;/code&gt; extra to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;uv&lt;/code&gt; when building since we’re running on a TPU VM.&lt;/p&gt;

&lt;p&gt;Here’s an example call on the local machine, for data-parallel training:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;./deploy/run.sh tpu-v4-32 &lt;span class=&quot;nv&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;10 &lt;span class=&quot;nv&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;4 &lt;span class=&quot;nv&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;300 &lt;span class=&quot;nv&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;3e-4 &lt;span class=&quot;s1&quot;&gt;'mesh_shape=[16]'&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;For this to work well, when I first create the TPU VM,&lt;sup id=&quot;fnref:3&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;7&lt;/a&gt;&lt;/sup&gt; I make sure the
following startup script is run on each host in the configuration:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c&quot;&gt;# startup.sh&lt;/span&gt;
curl &lt;span class=&quot;nt&quot;&gt;-LsSf&lt;/span&gt; https://astral.sh/uv/install.sh | sh  &lt;span class=&quot;c&quot;&gt;# Download uv&lt;/span&gt;
&lt;span class=&quot;nb&quot;&gt;sudo cp&lt;/span&gt; /root/.local/bin/uv /usr/local/bin/uv
&lt;span class=&quot;nb&quot;&gt;sudo cp&lt;/span&gt; /root/.local/bin/uvx /usr/local/bin/uvx
ssh-keyscan &lt;span class=&quot;nt&quot;&gt;-H&lt;/span&gt; github.com &lt;span class=&quot;o&quot;&gt;&amp;gt;&amp;gt;&lt;/span&gt; /etc/ssh/ssh_known_hosts  &lt;span class=&quot;c&quot;&gt;# Code copying&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;This installs &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;uv&lt;/code&gt; and adds &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;github.com&lt;/code&gt; to the known hosts, so that I can connect to
Github via &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ssh&lt;/code&gt; (agent forwarding is enabled for authentication). If we don’t add it to
the known hosts list, it will ask to confirm connection and the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;gcloud&lt;/code&gt; command will
fail.&lt;/p&gt;

&lt;p&gt;That’s it! This gives a pretty easy minimal setup for distributed training with JAX on
TPU VMs.&lt;/p&gt;

&lt;h2 id=&quot;adding-logging&quot;&gt;Adding Logging&lt;/h2&gt;

&lt;p&gt;Custom logging is easy to add, as well. We add an &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Enum&lt;/code&gt; to the config to select the
logger, then add two basic loggers: one for printing to the command line, and one for
using Weights and Biases.&lt;/p&gt;

&lt;p&gt;There are two things to keep in mind here:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;In a distributed setting, by default we only want one process (say, the one with
index &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0&lt;/code&gt;) to log metrics.&lt;sup id=&quot;fnref:6&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:6&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;8&lt;/a&gt;&lt;/sup&gt;&lt;/li&gt;
  &lt;li&gt;We want to set up appropriate host-device pipelining, so as we alluded to in the
previous post (taking a trick from the JAX training cookbook) we should log the
metrics from the &lt;em&gt;previous&lt;/em&gt; batch on every step. We can incorporate automatic
buffering into our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Logger&lt;/code&gt; class to do this.&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# config.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;LoggerType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Enum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;PRINT&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;WANDB&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;wandb&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;


&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# loggers.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;logger_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logger_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LoggerType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;match&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logger_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LoggerType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PRINT&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PrintLogger&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;case&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LoggerType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WANDB&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WandbLogger&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_run_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;base&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Prepend current UTC time to config-specified run name&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;timestamp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;datetime&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;now&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;timezone&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;utc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;strftime&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;%Y%m%d-%H%M%S&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;timestamp&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;base&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Logger base class. Implements a depth-1 buffer for device-host pipelining&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;project_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;project_name&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_run_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;isinstance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DictConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config_dict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;OmegaConf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;to_container&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;resolve&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config_dict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;asdict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config_dict&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_log_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__enter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__exit__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exc_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exc_value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;traceback&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;pass&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;warn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;pass&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_log_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;buffered_metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_log_data&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;buffered_metrics&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;flush_buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_log_data&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;({})&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_log_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;PrintLogger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__enter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Project: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;project_name&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Run: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__enter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;if &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;buffered_dict&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metric&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metric&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;buffered_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;items&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sep&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\t&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;warn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;WandbLogger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_master&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;process_index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__enter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_master&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;wandb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;project&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;project_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__enter__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__exit__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exc_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exc_value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;traceback&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_master&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;wandb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;finish&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;__exit__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exc_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exc_value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;traceback&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_master&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;if &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;buffered_dict&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;wandb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;buffered_dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;warn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_master&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;message&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Then we simply wrap our training loop with this context manager:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# train.py
# ...
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Logger&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;logger_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logger_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h1 id=&quot;actual-data&quot;&gt;Actual Data&lt;/h1&gt;

&lt;p&gt;Now that we have a reasonable backend, let’s get it working on some text data, and try
it at GPT-2 scale! Within the context of the above refactors, we just need to write the
data preprocessing code and the dataloader. We also need to add evaluation code – we’ve
done this in the Github, but not detailed it above.&lt;/p&gt;

&lt;p&gt;We’ll do a minimal demo with Andrej Karpathy’s Tiny Shakespeare dataset, available in
the nanoGPT repo. We adapt Karpathy’s script for downloading and tokenizing the dataset,
making use of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tiktoken&lt;/code&gt; library for this purpose:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# data/download_and_tokenize_tiny_shakespeare.py
&lt;/span&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;

&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax.numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;requests&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tiktoken&lt;/span&gt;

&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;From https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare/prepare.py&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&quot;&quot;&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# download the tiny shakespeare dataset
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_file_path&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;join&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dirname&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__file__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;input.txt&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;exists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_file_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data_url&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_file_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;encoding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;utf-8&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;write&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;requests&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data_url&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_file_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;encoding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;utf-8&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;read&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;train_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;val_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# encode with tiktoken gpt2 bpe
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tiktoken&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get_encoding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;gpt2&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;train_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;enc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;encode_ordinary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;val_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;enc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;encode_ordinary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;train has &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt; tokens&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;val has &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt; tokens&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# export to bin files
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;val_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;save&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;data/tiny-shakespeare/train.npy&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;save&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;data/tiny-shakespeare/val.npy&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# train.bin has 301,966 tokens
# val.bin has 36,059 tokens
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In a distributed setting, we run this on each host, with a command like:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c&quot;&gt;# deploy/download_shakespeare.sh&lt;/span&gt;

&lt;span class=&quot;nv&quot;&gt;TPU_NAME&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$1&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;
&lt;span class=&quot;nb&quot;&gt;shift
&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;SSH_FLAGS&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s1&quot;&gt;'-A -o ForwardAgent=yes'&lt;/span&gt;
&lt;span class=&quot;nv&quot;&gt;COMMANDS&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;if [ ! -d &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&quot;&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;baremetal-gpt&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&quot;&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt; ]; then git clone git@github.com:sdbuch/baremetal-gpt; fi &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; export HF_TOKEN='&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$HF_TOKEN&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;' &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; cd baremetal-gpt &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git fetch &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git checkout -f main &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; git pull &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; uv sync --extra tpu &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;
    &amp;amp;&amp;amp; uv run data/tiny-shakespeare/download_and_tokenize_tiny_shakespeare.py &lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$@&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;

gcloud compute tpus tpu-vm ssh &lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$TPU_NAME&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--ssh-flag&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$SSH_FLAGS&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--command&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;nv&quot;&gt;$COMMANDS&lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;se&quot;&gt;\&lt;/span&gt;
  &lt;span class=&quot;nt&quot;&gt;--worker&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;all
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Then the data is ready for loading. We write a python loader for the data to go with our
previously-written factor, using the same &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dataloader_with_replacement&lt;/code&gt; as for the
synthetic data:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# data.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;load_shakespeare&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DatasetConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;path&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;load&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.npy&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Then running the experiment is simply a matter of setting up the right experimental
config. We use the following, which runs almost all the evaluation code we wrote (to
test it!):&lt;/p&gt;

&lt;div class=&quot;language-yaml highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# configs/experiment/tiny-shakespeare.yaml&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# @package _global_&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;defaults&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;/dataset/shakespeare@train_dataset&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;val_log_interval&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;250&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;3000&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;val_list&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;na&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;SHAKESPEARE&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.path}&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;VAL&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.seq_len}&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# At inference time, this functions like prompt size&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.global_batch_size}&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;200&lt;/span&gt;
    &lt;span class=&quot;na&quot;&gt;evaluator&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;NLL&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;na&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;SHAKESPEARE&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.path}&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;VAL&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;128&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# At inference time, this functions like prompt size&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;16&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;use_splash&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;False&lt;/span&gt;
    &lt;span class=&quot;na&quot;&gt;evaluator&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;AUTOREGRESSIVE_ROLLOUTS&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;pi&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;na&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;SHAKESPEARE&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.path}&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;VAL&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;128&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# At inference time, this functions like prompt size&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;16&lt;/span&gt;
      &lt;span class=&quot;na&quot;&gt;use_splash&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;False&lt;/span&gt;
    &lt;span class=&quot;na&quot;&gt;evaluator&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;AUTOREGRESSIVE_ROLLOUTS&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;transformer_type&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;DISCRETE&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;${train_dataset.seq_len}&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# no longer than train&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;is_causal&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;True&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;50257&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# gpt2 tokenizer&lt;/span&gt;

&lt;span class=&quot;na&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;weight_decay&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;1e-1&lt;/span&gt;


&lt;span class=&quot;na&quot;&gt;inference&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;tokenizer&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;GPT2&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;max_tokens_to_generate&lt;/span&gt;&lt;span class=&quot;pi&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;128&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The tokenizer parameters, as suggested, are just used for rollouts at inference time.
We use some Hydra interpolation in the config above to avoid retyping the same values.
Here’s the overall training code we run (which is pretty short!):&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# train.py
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# ...
&lt;/span&gt;&lt;span class=&quot;nd&quot;&gt;@hydra.main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;version_base&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_path&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;Path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;configs&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;absolute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;resolve&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;base_config&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;try&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Launch distributed and register configs
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distributed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;initialize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree_util&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;register_static&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;except&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;RuntimeError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# This implies the distributed backend has already been initialized
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;pass&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;config_post_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;mesh_from_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Logger&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;logger_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logger_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Randomness
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_eval&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Data
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_distributed_batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Initialize state, configure forward pass and optimization
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CacheParams&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enabled&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_splash_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model_spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;opt_update_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;bool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;spec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;donate_argnums&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;take_along_axis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;value_and_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;grad_clipped&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;grad_norm_and_clip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;grad_clipped&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;weight_decay_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Transpose the output tree to get update tree and state tree
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update__opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;batch_loss&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;grad_norm&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;global_grad_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Training loop
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;do_evals&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eval_loop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;if &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val_log_interval&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;c1&quot;&gt;# Calculate val metrics
&lt;/span&gt;                &lt;span class=&quot;n&quot;&gt;key_val&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;do_evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Run evals (testing)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;key_eval&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;do_evals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;eval_loop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;EvaluationConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;flush_buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evaluation&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;eval_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_splash_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evaluation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_e&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_distributed_batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;evaluation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;evaluation_fn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;evaluator_factory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;evaluation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;evaluation_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_iter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;flush_buffer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Check out the &lt;a href=&quot;https://github.com/sdbuch/baremetal-gpt/blob/33b76c0e4412913717c7bce7b50308fd4ae13781/src/bmgpt/evaluators.py&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;evaluators.py&lt;/code&gt; file in the Github&lt;/a&gt; for the full structure of the evaluation code.&lt;/p&gt;

&lt;h2 id=&quot;checking-the-results&quot;&gt;Checking the Results&lt;/h2&gt;

&lt;p&gt;The default model we train has 12 layers, embedding dimension 768, 12 heads per
attention layer, and a MLP expansion factor of 4. With GPT2 vocabulary size, this ends
up at about 160M parameters. Our Tiny Shakespeare config follows nanoGPT, setting
the defaults as:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# configs/dataset/shakespeare.yaml
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;SHAKESPEARE&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;data/tiny-shakespeare&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TRAIN&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;So we process 32K tokens per batch. The total number of tokens in the training set is
only 301,966! So without additional regularization, we expect to overfit this dataset
very rapidly with our model.&lt;/p&gt;

&lt;p&gt;We test this out, via the launch command&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;deploy&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sh&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tpu&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v4&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;deploy&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v4&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;experiment&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tiny&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shakespeare&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;on a 16-chip TPU v4 VM (4 hosts). Following the experiment config we saw above, we train
for 3000 steps, and log val metrics every 250 steps.&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/tiny-shakespeare-batch-loss.png&quot; alt=&quot;Batch loss, first 500 steps of Tiny Shakespeare training&quot; /&gt;
    &lt;figcaption&gt;Plot of batch loss (negative log-likelihood) versus training step for
    the first 500 steps of Tiny Shakespeare training. The model overfits the training
    data.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;We see that the model rapidly overfits the training data. It takes about 10 steps per
epoch, so the plot shows about 50 epochs worth of training – significantly more than we
encounter in standard language model training datasets.&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/tiny-shakespeare-grad-norm.png&quot; alt=&quot;Gradient norm, first 500 steps of Tiny Shakespeare training&quot; /&gt;
    &lt;figcaption&gt;Plot of gradient norm versus training step for
    the first 500 steps of Tiny Shakespeare training. Training is stable, even without
    learning rate scheduling.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;The model trains stably (we are performing gradient clipping to norm 1.0 by default),
even without learning rate scheduling. Scheduling would improve the initial stability;
since we are heavily overfitting, there does not seem to be any need for LR cooldown at
the end of training.&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/tiny-shakespeare-val-loss.png&quot; alt=&quot;Validation loss, logged every 250 steps, over 3000 steps of training&quot; /&gt;
    &lt;figcaption&gt;Plot of validation loss (negative log likelihood) versus training step for
    the 3000 step Tiny Shakespeare training trajectory. The model is heavily
    overfit.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Our evaluation code lets us log various metrics, including the validation set negative
log likelihood every 250 steps of training. We verify that the model is significantly
overfit to the training set – the validation loss actually increases monotonically!
Taking a look at some of the trained model’s rollouts verifies that they are indeed not
generalizing very well (although they’re quite fun to read):&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-txt&quot;&gt;################################
Prompt:
################################
 father? Gentles, methinks you frown:
And wherefore gaze this goodly company,
As if they saw some wondrous monument,
Some comet or unusual prodigy?

BAPTISTA:
Why, sir, you know this is your wedding-day:
First were we sad, fearing you would not come;
Now sadder, that you come so unprovided.
Fie, doff this habit, shame to your estate,
An eye-sore to our solemn festival!

TRANIO:
And tells us, what occasion of import
Hath all
################################
Generated text:
################################
 the castle wall:
And wheither is come to me, and sit by fortune's side,
The right is not! God's my son,
Immoderate but a little distressed widow;
For joyful wife's son is no weary way.

ROMEO:
What say'st thou, that didst me yet?

MERCUTIO:
Good dream of mine own.

ROMEO:
Give me that mattock and the wrenching iron.
Hold, take this letter; early, and my lord

Till then be urged to my advancement?

Nurse:
&lt;/code&gt;&lt;/pre&gt;

&lt;p&gt;The model seems to be regurgitating tokens from Romeo and Juliet, with some random
digressions sprinkled in!
For these rollouts, we generate up to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_seq_len&lt;/code&gt; tokens (based on the training config; for us,
it’s 256) off of a 128 token prompt. This avoids out-of-distribution RoPE bugs and KV
cache errors.&lt;/p&gt;

&lt;h1 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h1&gt;

&lt;p&gt;That’s all for our base infra buildout! Some of the key takeaways include:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Correct configuration of splash attention Pallas kernel;&lt;/li&gt;
  &lt;li&gt;Flexible infrastructure for training and inference across different datasets and
evaluation types;&lt;/li&gt;
  &lt;li&gt;Basic training improvements (with fairly robust infrastructure, except perhaps for our
model spec) that will let us experiment with new optimizers in the future!&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;As we’ve gone from the notebook-based code in the previous blog to the full repository,
we’ve seen along the way that building reasonably general experiment infrastructure
requires a bit of a different approach than writing one-off research code. It also
requires a decent amount of time! There’s a good chance a lot of this can be automated
in a fast-paced environment with LLMs – I have not used LLMs for writing any code in
these blogs – but writing the code by hand is hard to beat for learning and for
precision of implementation (as &lt;a href=&quot;https://www.youtube.com/watch?v=lXUZvyajciY&quot;&gt;Karpathy suggested in his recent conversation with
Dwarkesh!&lt;/a&gt;). As a small piece of evidence
of this, we’ve ended up with a very concise and readable &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train.py&lt;/code&gt; script: not much
more than 150 lines!&lt;/p&gt;

&lt;p&gt;Some remaining infrastructure-related tweaks we should add in the future:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Model checkpointing: we’ll add this at some point soon, likely just using Orbax.&lt;/li&gt;
  &lt;li&gt;Further parallelism configurations, in particular basic FSDP.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Look forward to less infrastructure and closer-to-research experimentation in future
iterations of this blog series.&lt;/p&gt;

&lt;h1 id=&quot;acknowledgments&quot;&gt;Acknowledgments&lt;/h1&gt;

&lt;p&gt;Thanks to the &lt;a href=&quot;https://sites.research.google/trc/about/&quot;&gt;TRC program&lt;/a&gt; for
compute support.&lt;/p&gt;

&lt;div class=&quot;footnotes&quot; role=&quot;doc-endnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:7&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;The reason we actually check for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt; is that we generally need to call our
model with different configurations (e.g., when doing autoregressive inference, or
when evaluating with a different parallelism configuration), and it is suboptimal to
create an entirely different &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; object for each of these, or to manually edit
the global &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; object for these as we run. So instead, we check for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt;, and
just pass this as the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel&lt;/code&gt; parameter to the model as we need to. &lt;a href=&quot;#fnref:7&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:1&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;For example, we typically only apply weight decay to matrix-like parameters in the
network (e.g., not to biases). &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:8&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Note that even this approach is not perfect, as for example the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;w_out&lt;/code&gt; matrix
associated to the attention output projection actually does reduce over one of the
dimensions marked as non-matrix. We should revisit this for future improvements
(i.e., Muon-type optimizers!). &lt;a href=&quot;#fnref:8&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;At least in our setting, where we’re not using multislice operation. I think the
distributed backend is just as easy to set up for multislice, though! &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:4&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Creating the training/val/test split &lt;em&gt;on each distributed process&lt;/em&gt; is probably not
a good idea, even though we seed each host’s RNG with the same seed, and generate the
splits with the same key. It won’t matter for our purposes here because the
synthetic data distribution is so simple (even if every process had a different
training split, it would be okay, because there isn’t a material distinction between
the training and test set for this distribution). We’ll fix this with proper
processing of data when we move to text data. &lt;a href=&quot;#fnref:4&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:5&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;We could use &lt;a href=&quot;https://github.com/google/grain&quot;&gt;Grain&lt;/a&gt; for this purpose! &lt;a href=&quot;#fnref:5&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Here, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tpu-v4-32&lt;/code&gt;, for a four-host &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v4&lt;/code&gt; setup: the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;32&lt;/code&gt; denotes the number of
cores, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;chips = cores / 2&lt;/code&gt;, and the number of chips dictates the mesh shape. &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:6&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;For more complex debugging, say in interpretability settings, we need to add a bit
more custom code in order to have each process log its respective metrics (say,
per-device gradients before all reduce, etc.). &lt;a href=&quot;#fnref:6&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Sam Buchanan</name>
          
            <email>sdbuchanan@berkeley.edu</email>
          
          
        </author>
      

      
        <category term="jax" />
      
        <category term="transformers" />
      

      

      
        <summary type="html">In this third post in a series on programming with JAX for training transformers, we pick up where we left off at the end of the previous post. We’ll extend our previous small transformer model to the NanoGPT scale, incorporating actual text data, distributed training, and a real optimizer, and perform some basic evaluation! We’ll end up with a reasonable substrate for building future experiments on top of.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">A Faster Manifold Muon with ADMM</title>
      
      
      <link href="http://sdbuchanan.com/blog/manifold-muon/" rel="alternate" type="text/html" title="A Faster Manifold Muon with ADMM" />
      
      <published>2025-10-20T00:00:00+00:00</published>
      <updated>2025-10-20T00:00:00+00:00</updated>
      <id>http://sdbuchanan.com/blog/manifold-muon</id>
      <content type="html" xml:base="http://sdbuchanan.com/blog/manifold-muon/">&lt;p&gt;In a recent blog post, Jeremy
Bernstein (Thinking Machines) gave an algorithm for optimizing matrix-valued
parameters—say, weights in a transformer—that ensures that both the update
directions and the parameters themselves are well-conditioned &lt;a class=&quot;citation&quot; href=&quot;#bernstein2025manifolds&quot;&gt;(Bernstein, 2025)&lt;/a&gt;.
This algorithm, called &lt;em&gt;manifold Muon&lt;/em&gt;, has an inner loop that is relatively slow to
converge, and requires some hyperparameter tuning to get right.&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/admm-dual-ascent.png&quot; alt=&quot;Comparing
    dual ascent to ADMM rate of convergence&quot; /&gt;
    &lt;figcaption&gt;Comparison of the default manifold Muon inner loop solver (dual ascent)
    to a solver based on ADMM, for one weight matrix in a neural network being trained.
    ADMM converges much more rapidly, with minimal hyperparameter tuning!&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;In this post, we’ll describe a modification of the manifold Muon algorithm that greatly
speeds up convergence—by more than a factor of two in wall-clock time in our experiments
on 1x H100! To realize this, we apply a venerable algorithm from convex optimization
called ADMM (the “alternating direction method of multipliers”) to the optimization
problem at the heart of the manifold Muon update.&lt;/p&gt;

&lt;p&gt;The rest of the post will re-introduce the manifold Muon subproblem, give some intuition
for why the &lt;a href=&quot;https://github.com/thinking-machines-lab/manifolds&quot;&gt;existing
implementation&lt;/a&gt; leaves some room for
improvement on the table, and walk through the ADMM derivation for manifold Muon. We’ll
verify the improved algorithm on the same toy model used in the linked Github
repository. We hope this improvement will facilitate larger-scale experimentation
with manifold Muon!&lt;/p&gt;

&lt;p&gt;If you haven’t already, you’re encouraged to read &lt;a href=&quot;https://thinkingmachines.ai/blog/modular-manifolds/&quot;&gt;Jeremy’s blog
post&lt;/a&gt; on manifold Muon before
proceeding. It gives an excellent technical overview and derivation of the method, and
accessible, well-written intuition on the deep links between optimization and geometry.
See also Jianlin Su’s work &lt;a class=&quot;citation&quot; href=&quot;#kexuefm-11221&quot;&gt;(苏剑林, 2025)&lt;/a&gt;, which gives a more thorough
technical derivation.&lt;/p&gt;

&lt;h1 id=&quot;background-manifold-muon&quot;&gt;Background: Manifold Muon&lt;/h1&gt;

&lt;p&gt;Manifold Muon is a first-order method for optimizing a matrix-valued parameter $\vW \in
\bbR^{m \times n}$ (we’ll assume $m \geq n$ throughout), such that it satisfies the
quadratic constraint $\vW^\top \vW = \vI$. Given a putative update direction $\vG \in
\bbR^{m \times n}$—say, back-propagated gradients of some loss with respect to
$\vW$—and a step size $\eta &amp;gt; 0$, we choose an update $\vA$ by solving the
convex optimization problem&lt;sup id=&quot;fnref:2&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

\[\begin{equation}\label{eq:manifold-muon}
    \min_{\vA \in \bbR^{m \times n}}\, \ip{\vA}{\vG}
    \quad
    \text{subject to}
    \quad
    \norm{\vA} \leq \eta,
    \quad
    \vA^\top \vW + \vW^\top \vA  = \Zero,
    \end{equation}\]

&lt;p&gt;then perform the update&lt;/p&gt;

\[\vW \mapsto \mathrm{msign} \left( \vW + \vA  \right) ,\]

&lt;p&gt;where $\mathrm{msign}(\vX) = \vX (\vX^\top \vX)^{-1/2}$ is the matrix sign function.&lt;sup id=&quot;fnref:1&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;The manifold Muon subproblem \eqref{eq:manifold-muon} does not
have a closed-form solution in general, so Jeremy proposed
an iterative solver for it, based on the equivalent &lt;em&gt;unconstrained&lt;/em&gt; formulation&lt;sup id=&quot;fnref:3&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

\[\begin{equation}\label{eq:manifold-muon-dual}
    \min_{\vLambda \in \bbR^{n \times n}}\,
    \norm*{\vG + \vW\left( \vLambda + \vLambda^\top \right)}_*.
    \end{equation}\]

&lt;p&gt;One can obtain an optimal solution $\vA_\star$ to \eqref{eq:manifold-muon} from an
optimal solution $\vLambda_\star$ to \eqref{eq:manifold-muon-dual} via $\vA_\star =
-\eta \mathop{\mathrm{msign}}( \vG + \vW(\vLambda_\star + \vLambda_\star^\top))$.&lt;/p&gt;

&lt;p&gt;In practice, we solve \eqref{eq:manifold-muon-dual} with subgradient descent:
we pick an iteration count $K \in \bbN$ and step sizes $\gamma_k &amp;gt; 0$ for $k = 1, \dots, K$,
and repeatedly perform the update&lt;/p&gt;

\[\begin{equation}\label{eq:subgradient}
\begin{split}
    \vGamma_{k} &amp;amp;=
    \mathop{\mathrm{msign}}( \vG + \vW(\vLambda_k + \vLambda_k^\top))^\top \vW
    + \vW^\top\mathop{\mathrm{msign}}( \vG + \vW(\vLambda_k + \vLambda_k^\top)); \\
    \vLambda_{k+1} &amp;amp;= \vLambda_k
    -
    \gamma_k \vGamma_k.
    \end{split}
\end{equation}\]

&lt;p&gt;The step sizes $\gamma_k$ are chosen to decay as $k \to K$ to guarantee convergence.
The gradient $\vGamma_k$ is computed efficiently on modern accelerators using
Newton-Schulz iterations, or optimal algorithms such as the “Polar Express” algorithm
&lt;a class=&quot;citation&quot; href=&quot;#Amsel2025-qj&quot;&gt;(Amsel et al., 2025)&lt;/a&gt;, as in the &lt;a href=&quot;https://github.com/thinking-machines-lab/manifolds/&quot;&gt;Github for Jeremy’s
blog&lt;/a&gt;.&lt;/p&gt;

&lt;h1 id=&quot;drawbacks-of-subgradient-descent&quot;&gt;Drawbacks of Subgradient Descent&lt;/h1&gt;

&lt;p&gt;The subgradient descent iteration \eqref{eq:subgradient} is guaranteed to converge, but
it does so &lt;strong&gt;very slowly&lt;/strong&gt;!
In fact, its worst-case rate of convergence is $O(1/\sqrt{k})$ &lt;a class=&quot;citation&quot; href=&quot;#book-wright-ma&quot;&gt;(Wright &amp;amp; Ma, 2022)&lt;/a&gt;,
in contrast to the standard $O(1/k)$ rate for gradient descent on smooth objectives.&lt;/p&gt;

&lt;p&gt;Even worse, this is &lt;em&gt;typical&lt;/em&gt; behavior for this algorithm, not just worst-case.
It stems from the fact that subgradient descent is an algorithm for &lt;em&gt;nonsmooth&lt;/em&gt;
problems, and nonsmoothness leads to a ‘chattering’ behavior of iterates. Here is a
quick simulation on a canonical example, the absolute value function $\abs{\spcdot}$ in
dimension one, to illustrate:&lt;sup id=&quot;fnref:5&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:5&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;4&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;math&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;iters&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;grads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;it&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;iters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;it&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/subgradient_descent.gif&quot; alt=&quot;Subgradient
    descent on the absolute value function&quot; /&gt;
    &lt;figcaption&gt;Subgradient descent on the absolute value function, with $1/\sqrt{k}$ LR
    decay.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;The nuclear norm objective in \eqref{eq:manifold-muon-dual} is nondifferentiable
at every point $\vLambda$ for which $\vG + \vW(\vLambda + \vLambda^\top)$ is not full
rank. Even if the iterates $\vLambda_k$ never witness such a point of
nondifferentiability, the lack of continuity of the gradient near these points leads to
the same ‘chattering’ behavior of the iterates as in the toy example above.&lt;/p&gt;

&lt;p&gt;Fortunately, we can do better, with a bit of algorithmic cleverness (and without too
much more compute)!&lt;/p&gt;

&lt;h1 id=&quot;a-faster-algorithm-from-admm-with-splitting&quot;&gt;A Faster Algorithm from ADMM with Splitting&lt;/h1&gt;

&lt;p&gt;ADMM is an optimization algorithm for solving problems of the form&lt;/p&gt;

\[\begin{equation}\label{eq:admm}
\begin{split}
    \min_{\vx \in \bbR^{m}, \vz \in \bbR^{n}}\, &amp;amp;f(\vx) + g(\vz) \\
    \text{subject to}\,\, &amp;amp;\vA \vx + \vB \vz = \vc,
    \end{split}
\end{equation}\]

&lt;p&gt;where $f$ and $g$ are convex.
It was originally developed in the 1970s, and was re-popularized among a generation or two
of researchers by the excellent monograph of Boyd, Parikh et al. &lt;a class=&quot;citation&quot; href=&quot;#Boyd2011-yb&quot;&gt;(Boyd et al., 2011)&lt;/a&gt;,
which we recommend as a general reference beyond the rapid overview we give here.
The key empirical property that makes ADMM very well suited for use in solving the
manifold Muon subproblem \eqref{eq:manifold-muon-dual} is its &lt;em&gt;rapid initial
convergence&lt;/em&gt; to a solution of reasonable quality—the algorithm’s worst-case
convergence rate is only $O(1/k)$, but this initial rapid convergence often allows it to
be stopped much earlier.&lt;/p&gt;

&lt;p&gt;The ADMM algorithm solves problems of the form \eqref{eq:admm} as follows. For a penalty
parameter $\rho &amp;gt; 0$, we define the “augmented Lagrangian”&lt;/p&gt;

\[\sL_{\rho}(\vx, \vz, \vlambda)
    = f(\vx) + g(\vz) + \ip{\vlambda}{\vA \vx + \vB \vz - \vc} + \frac{\rho}{2} \norm*{
        \vA \vx + \vB \vz - \vc
    }_2^2,\]

&lt;p&gt;then iteratively perform the updates&lt;/p&gt;

\[\begin{equation}\label{eq:admm-update}
\begin{split}
    \vx_{k+1} &amp;amp;= \argmin_{\vx}\, \sL_{\rho}(\vx, \vz_k, \vlambda_k) \\
    \vz_{k+1} &amp;amp;= \argmin_{\vz}\, \sL_{\rho}(\vx_{k+1}, \vz, \vlambda_k) \\
    \vlambda_{k+1} &amp;amp;= \vlambda_k + \rho \left( \vA \vx_{k+1} + \vB \vz_{k+1} - \vc \right).
    \end{split}
\end{equation}\]

&lt;p&gt;The algorithm is reminiscent of a block coordinate descent method, with gradient ascent
on the dual variable $\vlambda$ (with step size $\rho$).
It converges for any setting of the penalty parameter $\rho$. Empirically, this
parameter can be tuned to accelerate convergence.&lt;/p&gt;

&lt;p&gt;ADMM is ideally suited for problems where the minimization subproblems in the $\vx$ and
$\vz$ updates in \eqref{eq:admm-update} have closed-form solutions. For the manifold
Muon subproblem \eqref{eq:manifold-muon-dual}, we need to pass two related obstacles:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;How do we apply ADMM to \eqref{eq:manifold-muon-dual}? It doesn’t seem to have the
ADMM-amenable structure of \eqref{eq:admm}.&lt;/li&gt;
  &lt;li&gt;Having resolved this, do the minimization subproblems in the ADMM update
\eqref{eq:admm-update} applied to \eqref{eq:manifold-muon-dual} have closed-form
solutions that can be computed efficiently on GPUs/TPUs?&lt;/li&gt;
&lt;/ol&gt;

&lt;h2 id=&quot;the-immortal-splitting-trick&quot;&gt;The Immortal Splitting Trick&lt;/h2&gt;

&lt;p&gt;To resolve the first issue, we apply a trick known as &lt;em&gt;variable splitting&lt;/em&gt;. We add an
auxiliary variable $\vX \in \bbR^{m \times n}$ to the manifold Muon subproblem
\eqref{eq:manifold-muon-dual} along with an extra constraint to give the equivalent
problem&lt;/p&gt;

\[\begin{equation}\label{eq:manifold-muon-split}
\begin{split}
    \min_{\vLambda \in \bbR^{n \times n}, \vX \in \bbR^{m \times n}}\, &amp;amp;\norm*{\vX}_* \\
    \text{subject to}\,\, &amp;amp;\vX = \vG + \vW(\vLambda + \vLambda^\top).
\end{split}
\end{equation}\]

&lt;p&gt;This may initially seem counterintuitive—didn’t we work hard to reduce the original
manifold Muon problem \eqref{eq:manifold-muon} to an unconstrained form? But the
advantage is that this problem is now amenable to ADMM, since it has the form
of \eqref{eq:admm} after we re-interpret matrices as vectors!&lt;/p&gt;

&lt;p&gt;For this problem, given that we’ve already used $\vLambda$ for a dual variable, we’ll
write the ADMM dual variable as $\vOmega \in \bbR^{m \times n}$, giving the augmented
Lagrangian&lt;/p&gt;

\[\begin{equation}\label{eq:augmented-lagrangian}
\begin{split}
    \sL_{\rho}(\vLambda, \vX, \vOmega)
    = \norm*{\vX}_* &amp;amp;+ \ip{\vOmega}{\vX - \vW(\vLambda + \vLambda^\top) - \vG}  \\
    &amp;amp;+ \frac{\rho}{2} \norm*{ {\vX - \vW(\vLambda + \vLambda^\top) - \vG} }_{\frob}^2.
    \end{split}
\end{equation}\]

&lt;p&gt;Now we can work on the ADMM subproblems for this augmented Lagrangian.&lt;/p&gt;

&lt;h2 id=&quot;admm-subproblems-for-manifold-muon&quot;&gt;ADMM Subproblems for Manifold Muon&lt;/h2&gt;

&lt;p&gt;In deriving ADMM algorithms for specific problems, it’s often useful to “complete the
square” to more easily see the solutions to the minimization subproblems. We rewrite
\eqref{eq:augmented-lagrangian} as&lt;/p&gt;

\[\begin{equation}
\begin{split}
    \sL_{\rho}(\vLambda, \vX, \vOmega)
    &amp;amp;= \norm*{\vX}_* + \frac{\rho}{2}\norm*{\frac{1}{\rho}\vOmega}_{\frob}^2
    - \frac{\rho}{2}\norm*{\frac{1}{\rho}\vOmega}_{\frob}^2 +
    \ip{\vOmega}{\vX - \vW(\vLambda + \vLambda^\top) - \vG}  \\
    &amp;amp;\qquad\qquad+ \,\frac{\rho}{2} \norm*{ {\vX - \vW(\vLambda + \vLambda^\top) - \vG} }_{\frob}^2 \\
    &amp;amp;= \norm*{\vX}_*
    - \frac{\rho}{2}\norm*{\frac{1}{\rho}\vOmega}_{\frob}^2 +
      \frac{\rho}{2} \norm*{ \frac{1}{\rho} \vOmega + {\vX - \vW(\vLambda + \vLambda^\top) - \vG} }_{\frob}^2.
    \end{split}
\end{equation}\]

&lt;p&gt;Now the dependence on $\vLambda$ is an easy-to-parse quadratic function, and the
dependence on $\vX$ is also simple.&lt;/p&gt;

&lt;h3 id=&quot;vlambda-subproblem&quot;&gt;$\vLambda$ Subproblem&lt;/h3&gt;

&lt;p&gt;To solve&lt;/p&gt;

\[\argmin_{\vLambda}\, \sL_{\rho}(\vLambda, \vX_k, \vOmega_k)
    =
    \argmin_{\vLambda}\,
    \norm*{ \frac{1}{\rho} \vOmega_k + {\vX_k - \vW(\vLambda + \vLambda^\top) - \vG} }_{\frob}^2,\]

&lt;p&gt;it’s helpful to define $\Sym(\vLambda) = \half ( \vLambda + \vLambda^\top)$. This is
a linear operator, and in fact is an &lt;a href=&quot;https://en.wikipedia.org/wiki/Projection_(linear_algebra)#Projection_matrix&quot;&gt;orthogonal
projection&lt;/a&gt;.
Then because $\vW^\top \vW = \vI$, we have&lt;/p&gt;

\[\argmin_{\vLambda}\, \sL_{\rho}(\vLambda, \vX_k, \vOmega_k)
    =
    \argmin_{\vLambda}\,
    \norm*{ \Sym(\vLambda) - \frac{1}{2}\vW^\top\left(
    \frac{1}{\rho} \vOmega_k + \vX_k - \vG
    \right)
    }_{\frob}^2.\]

&lt;p&gt;It then follows that&lt;/p&gt;

\[\frac{1}{2}\Sym\left(\vW^\top\left(
    \frac{1}{\rho} \vOmega_k + \vX_k - \vG
    \right)\right)
    \in
    \argmin_{\vLambda}\, \sL_{\rho}(\vLambda, \vX_k, \vOmega_k),\]

&lt;p&gt;and that this is the minimum Euclidean norm solution to this problem.&lt;/p&gt;

&lt;p&gt;In particular, we can solve the $\vLambda$ subproblem in closed form, using only matrix
multiplies and transposes!&lt;/p&gt;

&lt;h3 id=&quot;vx-subproblem&quot;&gt;$\vX$ Subproblem&lt;/h3&gt;

&lt;p&gt;We have to solve&lt;/p&gt;

\[\begin{split}
    \argmin_{\vX}\, &amp;amp;\sL_{\rho}(\vLambda_{k+1}, \vX, \vOmega_k)
    =
    \\
    &amp;amp;\qquad\frac{1}{\rho}\norm*{\vX}_*
    + \frac{1}{2} \norm*{ \vX + \frac{1}{\rho} \vOmega_k - \vW(\vLambda_{k+1} + \vLambda_{k+1}^\top) - \vG }_{\frob}^2.
\end{split}\]

&lt;p&gt;This optimization problem actually has a closed-form solution called &lt;em&gt;singular value
thresholding&lt;/em&gt;, which can be derived by considering the SVD of everything that doesn’t
depend on $\vX$ in the quadratic term.
Namely, if we define $\vM = \vW(\vLambda_{k+1} + \vLambda_{k+1}^\top) + \vG - \frac{1}{\rho} \vOmega_k$
for concision and let $\vU$, $\diag(\vs)$, $\vV$
denote an economy SVD of $\vM$,
then it can be shown that &lt;a class=&quot;citation&quot; href=&quot;#book-wright-ma&quot;&gt;(Wright &amp;amp; Ma, 2022)&lt;/a&gt;&lt;/p&gt;

\[\argmin_{\vX}\, \sL_{\rho}(\vLambda_{k+1}, \vX, \vOmega_k)
    =
    \vU \diag \left( \max \set*{ \vs - \frac{1}{\rho}, 0 } \right)\vV^\top,\]

&lt;p&gt;with the natural broadcasting interpretation. Thus, the $\vX$ subproblem can be solved
in closed-form with a singular value decomposition!
Unfortunately, this is not enough for our purposes: SVD algorithms are not as
well-suited for scaling up on GPUs/TPUs as algorithms that rely purely on matrix
multiplications.&lt;/p&gt;

&lt;p&gt;Fortunately, it turns out that we can compute singular value thresholding &lt;em&gt;using only
the $\mathrm{msign}$ algorithm&lt;/em&gt;, letting us exploit efficient algorithms for this
operation that have already been developed for Muon!&lt;sup id=&quot;fnref:6&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:6&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;5&lt;/a&gt;&lt;/sup&gt; There are likely to be better
custom solutions for this problem, but this quick-and-dirty approach will get us on our
way. Here’s how:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;
    &lt;p&gt;First, we make a simple observation: &lt;em&gt;if $s_i - 1/\rho \geq 0$ for every singular
value $s_i$, then&lt;/em&gt;&lt;/p&gt;

\[\vU \diag \left( \max \set*{ \vs - \frac{1}{\rho}, 0 } \right)\vV^\top
 =
 \vM - \frac{1}{\rho} \vU \vV^\top.\]

    &lt;p&gt;Since $\vU\vV^\top = \mathrm{msign}(\vM)$, we’re done!&lt;/p&gt;
  &lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;We cannot rely on every singular value being larger than $1/\rho$ in general. But,
observe that those singular values that are smaller simply need to be set to $0$,
which we can achieve by other means.
More precisely, since $n \leq m$, consider the matrix
$\mathrm{msign}(\vM^\top \vM - \frac{1}{\rho^2} \vI)$.
Let $\vV_{\geq}$ denote the submatrix of (right) singular vectors of $\vM$
associated to singular values that are greater than or equal to $1/\rho$, and
$\vV_{&amp;lt;}$ the complementary submatrix. Then one has&lt;/p&gt;

\[\mathrm{msign}(\vM^\top \vM - \frac{1}{\rho^2} \vI)
=
\vV_{\geq} \vV_{\geq}^\top - \vV_{&amp;lt;} \vV_{&amp;lt;}^\top.\]

    &lt;p&gt;Each of the (unsigned) factors in this sum is an orthogonal projection matrix. Hence
we have&lt;/p&gt;

\[\frac{1}{2} \left( \vI + \mathrm{msign}(\vM^\top \vM - \frac{1}{\rho^2} \vI) \right)
=
\frac{1}{2} \left(\vV_{\geq} \vV_{\geq}^\top + \left(\vI - \vV_{&amp;lt;} \vV_{&amp;lt;}^\top \right) \right)
=
\vV_{\geq} \vV_{\geq}^\top.\]

    &lt;p&gt;The matrix $\vV_{\geq} \vV_{\geq}^\top$ &lt;em&gt;projects onto the right singular vectors of
$\vM$ associated to singular values at least as large as $1/\rho$&lt;/em&gt;. This is exactly
what we need to implement the zero clipping operation in singular value
thresholding!&lt;/p&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;In summary, we’ve shown how to compute singular value thresholding using only matrix
sign operations. Zooming out, we’ve shown that the solution to the $\vX$ subproblem in
our manifold Muon ADMM algorithm is&lt;/p&gt;

\[\argmin_{\vX}\, \sL_{\rho}(\vLambda_{k+1}, \vX, \vOmega_k)
    =
    \frac{1}{2}\left( \vM - \frac{1}{\rho} \mathrm{msign}(\vM) \right)
    \left( \vI + \mathrm{msign}\left(\vM^\top \vM - \frac{1}{\rho^2} \vI\right) \right).\]

&lt;p&gt;This can be computed using only matrix multiplications and matrix sign operations, which
themselves are efficiently approximated using matrix multiplications. Almost as
hardware-efficient as we could hope for!&lt;/p&gt;

&lt;h2 id=&quot;admm-for-manifold-muon-final-algorithm&quot;&gt;ADMM for Manifold Muon: Final Algorithm&lt;/h2&gt;

&lt;p&gt;Plugging our derivations above into the basic ADMM skeleton
\eqref{eq:admm-update} leads to the following final algorithm, here in math
form:&lt;sup id=&quot;fnref:7&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:7&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;6&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

\[\begin{equation}\label{eq:admm-update-mfld-muon}
\begin{split}
    \vLambda_{k+1} &amp;amp;=
    \frac{1}{2}\Sym\left(\vW^\top\left(
    \frac{1}{\rho} \vOmega_k + \vX_k - \vG
    \right)\right) \\
    \vM_{k+1} &amp;amp;=
    2\vW\vLambda_{k+1} + \vG - \frac{1}{\rho} \vOmega_k
    \\
    \vX_{k+1} &amp;amp;=
    \frac{1}{2}\left( \vM_{k+1} - \frac{1}{\rho} \mathrm{msign}(\vM_{k+1}) \right)
    \left( \vI + \mathrm{msign}\left(\vM_{k+1}^\top \vM_{k+1} - \frac{1}{\rho^2} \vI\right) \right) \\
    \vOmega_{k+1} &amp;amp;= \vOmega_k + \rho \left( \vX_{k+1} - 2\vW\vLambda_{k+1} - \vG \right).
    \end{split}
\end{equation}\]

&lt;p&gt;When we’re done iterating, just as we did for dual ascent,
we return $\vA_\star =
-\eta \mathop{\mathrm{msign}}( \vG + 2\vW\vLambda_\star)$.&lt;/p&gt;

&lt;p&gt;In code form, implementing the ADMM update we have derived above in PyTorch and
combining it with the remainder of the manifold Muon algorithm leads to the
following. This code is available in &lt;a href=&quot;https://github.com/sdbuch/thinky-manifolds&quot;&gt;my fork of Jeremy’s
Github repository&lt;/a&gt;, and it’s
a modification of Jeremy’s original code. Disclaimer: it has not been heavily
optimized, beyond choosing sensible defaults for the hyperparameters via
experiments we’ll share below!&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@torch.no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;manifold_muon_admm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;eta&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;4.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Ensure that W and G are both tall matrices
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;should_tranpose&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;should_tranpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Initialize the lagrangian, slack, and dual variable
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;Lambda&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.25&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Lambda&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Omega&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Solve the dual problem with ADMM to find the update direction A
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Update for Lambda (orthonormal least-squares solve)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mT&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Omega&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Lambda_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.25&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;P&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mT&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Update for X (singular value thresholding)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Lambda_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Omega&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;eye&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;P_pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;msign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mT&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;eye&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;X_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;msign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;P_pos&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Update for Omega (dual ascent)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;Omega_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Omega&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rho&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Lambda_upd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;Lambda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Omega&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Lambda_upd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X_upd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Omega_upd&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Calculate A from final ADMM solution
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# (at convergence, G + 2 * W @ Lambda \approx X)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;msign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Lambda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Descend on the primal problem
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;new_W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;eta&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Retract to the manifold
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;new_W&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;msign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;new_W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Restore the shape of the solution and return
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;should_tranpose&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_W&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In this code, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;msign&lt;/code&gt; calls Jeremy’s implementation of the “Polar Express”
algorithm &lt;a class=&quot;citation&quot; href=&quot;#Amsel2025-qj&quot;&gt;(Amsel et al., 2025)&lt;/a&gt;, which is a fairly aggressive algorithm for
efficiently approximating the matrix sign function. It’s worth keeping this in
mind, as a source of numerical inexactness in our subsequent experiments: for
ADMM to be useful to us, it has to be able to cope well with these
approximations!
There are also some ‘reparameterizations’ of learning rates relative to the
notation we’ve used above, but the algorithm is equivalent.&lt;/p&gt;

&lt;h2 id=&quot;intuition-for-the-speedup-over-dual-ascent&quot;&gt;Intuition for the Speedup over Dual Ascent&lt;/h2&gt;

&lt;p&gt;As a final sanity check, we can use our derivations above to visualize a toy
case of the manifold Muon ADMM algorithm, and contrast it with the pathological
behavior of subgradient descent we visualized before.
If we pick $m = n = 1$, $W = 1$ and $G = 0$ and simplify with algebra, the
$\Lambda_k$ and $\Omega_k$ variables actually become redundant, and we end up
with the very simple equivalent ADMM update&lt;/p&gt;

\[X_{k+1} = \sign(X_k) \max \set{\abs{X_k} - 1/\rho, 0}.\]

&lt;p&gt;It should be intuitively clear that this iteration converges stably to zero: we reduce
the magnitude of $X_k$ by $1/\rho$ until it drops below zero, then it stays there!
On an intuitive level, this is because the $\vX$ update in the ADMM algorithm
we derived for manifold Muon is &lt;em&gt;continuous&lt;/em&gt;, whereas the subgradient descent update is
&lt;em&gt;discontinuous&lt;/em&gt;.&lt;sup id=&quot;fnref:10&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:10&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;7&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/admm_descent.gif&quot; alt=&quot;ADMM
    (simplified) on the absolute value function&quot; /&gt;
    &lt;figcaption&gt;ADMM on the absolute value function with splitting (after algebraic
    simplifications), with $\rho = 16.0$.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;h1 id=&quot;experimental-results&quot;&gt;Experimental Results&lt;/h1&gt;

&lt;p&gt;We base our experiments on the experimental setup in the &lt;a href=&quot;https://github.com/thinking-machines-lab/manifolds/&quot;&gt;Thinking Machines
Github repository&lt;/a&gt; for the
blog post: a simple three-layer MLP is trained on CIFAR-10 for five epochs, and
per-epoch timing information is reported.&lt;/p&gt;

&lt;p&gt;Here’s a summary of what we’ll look at below:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;We perform a sweep over the ADMM penalty parameter $\rho$ and the number of
ADMM iterations for each manifold Muon solve. This gives us a sensible choice of
hyperparameters for ADMM for this setup, and lets us verify performance.&lt;/li&gt;
  &lt;li&gt;Since the model being trained is so small-scale, we
also plot full dual ascent and ADMM inner loop loss curves for each iteration
of the outer loop and eyeball these to assess how fast each algorithm is
converging.&lt;sup id=&quot;fnref:8&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:8&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;8&lt;/a&gt;&lt;/sup&gt; We believe this ranking will be more indicative of properties of
the algorithms that persist at scale (since raw test error may not be a suitable
proxy, based on what we’ll see below).&lt;/li&gt;
  &lt;li&gt;Having convinced ourselves of sensible ADMM performance at different choices
of inner loop iterations, we look at timing!&lt;/li&gt;
  &lt;li&gt;We finally look at some other relevant metrics for ADMM, such as feasibility
(how well are we satisfying the splitting constraint).&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The code for these experiments is available in &lt;a href=&quot;https://github.com/sdbuch/thinky-manifolds/tree/admm&quot;&gt;the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;admm&lt;/code&gt; branch of my fork
of Jeremy’s Github repository&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;parameter-sweep&quot;&gt;Parameter Sweep&lt;/h2&gt;

&lt;p&gt;The main takeaway of this experiment is that the model training might be a bit too
small scale! In general, for this model, manifold Muon leads to a fairly similar
test error for all methods examined, including running zero inner loop
iterations. This suggests a need to examine behaviors further at larger scale.
Nevertheless, we sweep over values in $\set{0, 5, 10, 25, 50, 100}$ for the
number of inner loop iterations to run, and for ADMM, we sweep over $\rho \in
\set{2, 4, 8, 16}$. We also run dual ascent as a baseline, with the same
initialization as in Jeremy’s repo.&lt;sup id=&quot;fnref:9&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:9&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;9&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/accuracy_comparison_rhos.png&quot; alt=&quot;Plot
    of test accuracy vs. total inner loop step count for different ADMM penalty
    parameters (or dual ascent).&quot; /&gt;
    &lt;figcaption&gt; Plot
    of test accuracy vs. total inner loop step counts for different ADMM penalty
    parameters (or dual ascent). &lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;From this plot, we see that test performance is not directly related with the number of
steps we run the inner optimizer for. To try to have some generalizable signal towards
larger-scale experiments, we will focus instead on assessing convergence of the inner
loop with different representative hyperparameter values. Based on this plot, we select
$(K, \rho) = (10, 4.0)$, $(100, \mathrm{DA})$ (for dual ascent), and as a baseline
$(100, 8.0)$. This will give us a sample of ADMM’s speed of convergence, as well as its
performance run for a large number of iterations.&lt;/p&gt;

&lt;h2 id=&quot;inner-loop-convergence&quot;&gt;Inner Loop Convergence&lt;/h2&gt;

&lt;p&gt;We compare the inner loop losses for the hyperparameter settings above.&lt;/p&gt;

&lt;p&gt;We can choose to visualize the inner loop loss for one of three possible weight
matrices (first layer, second layer, third layer), for any one of the outer loop
iterations (we do a total of five epochs over the data).&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/fc1_weight_e0_d00_dual_loss.png&quot; alt=&quot;Dual ascent losses
    for layer 1 weight, epoch 0, step 0&quot; /&gt;
    &lt;img src=&quot;/assets/blog/fc2_weight_e0_d00_dual_loss.png&quot; alt=&quot;Dual ascent losses
    for layer 2 weight, epoch 0, step 0&quot; /&gt;
    &lt;img src=&quot;/assets/blog/fc3_weight_e0_d00_dual_loss.png&quot; alt=&quot;Dual ascent losses
    for layer 3 weight, epoch 0, step 0&quot; /&gt;
    &lt;figcaption&gt;Dual ascent losses for each weight matrix in the network, for the
    nonzero-steps hyperparameter choices shown in the table above. Traces can be
    identified based on where they terminate.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Notice that both nonsmoothness and initial scale of the learning rate for dual ascent
being too large both seem to play a role in the slow convergence of dual ascent in the
plots above, which are evaluated for each weight matrix at the first epoch and first
iteration of optimization. Dual ascent needs all 100 iterations to reach a good loss
value, whereas the ADMM algorithms converge much more quickly. The ADMM algorithms only
require setting the penalty parameter, and are not very sensitive to it in this
experiment, as long as it is sufficiently large: we found smaller values of the penalty
parameter to lead to instabilities in inner loop loss convergence, although outer loop
performance was not significantly affected.&lt;sup id=&quot;fnref:11&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:11&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;10&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;For the ADMM iteration counts, ten iterations is aggressive, but reasonable: only for
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;fc3&lt;/code&gt; does it end up a bit far from convergence, and this seems to be due to the setting
of the penalty parameter we’ve chosen more than anything else.&lt;/p&gt;

&lt;p&gt;These plots are more or less representative for all iterations of training. In later
iterations, the loss decrease is often less dramatic than at the first iteration, but
trends persist: dual ascent is slower, ten ADMM iterations is aggressive, but not
unreasonable (see below).&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/fc1_weight_e1_d42_dual_loss.png&quot; alt=&quot;Dual ascent losses
    for layer 1 weight, epoch 1, step 42&quot; /&gt;
    &lt;figcaption&gt;Dual ascent losses for weights in the first layer of the network, for
    the nonzero-steps hyperparameter choices shown in the table above. Traces can be
    identified based on where they terminate.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;h2 id=&quot;timing-and-other-diagnostics&quot;&gt;Timing and Other Diagnostics&lt;/h2&gt;

&lt;p&gt;We profile the code with the pre-implemented timing commands in the original repository.
It would be good to do this in a more rigorous setting, also after optimizing the code a
bit more (e.g., adding compilation to unroll and fuse the inner loop where possible).
But this test gives us a rough sense of how much we stand to improve.&lt;/p&gt;

&lt;p&gt;Here are outputs for the dual ascent algorithm (100 iterations), run on a 1x H100 VM.&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;ubuntu@flaky-muskox-1:~/thinky-manifolds&lt;span class=&quot;nv&quot;&gt;$ &lt;/span&gt;uv run src/main.py
Training with: manifold_muon
Epochs: 5 &lt;span class=&quot;nt&quot;&gt;---&lt;/span&gt; LR: 0.1
Epoch 1, Loss: 1.6895257502186054, Time: 11.8445 seconds
Epoch 2, Loss: 1.4129536565469236, Time: 11.7680 seconds
Epoch 3, Loss: 1.2799630311070656, Time: 11.7151 seconds
Epoch 4, Loss: 1.183229492635143, Time: 11.6794 seconds
Epoch 5, Loss: 1.1131725724862547, Time: 12.1164 seconds
Accuracy of the network on the 10000 &lt;span class=&quot;nb&quot;&gt;test &lt;/span&gt;images: 52.58 %
Accuracy of the network on the 50000 train images: 65.936 %
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;And here are outputs for the ADMM algorithm (10 iterations).&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;ubuntu@flaky-muskox-1:~/thinky-manifolds&lt;span class=&quot;nv&quot;&gt;$ &lt;/span&gt;uv run src/main.py &lt;span class=&quot;nt&quot;&gt;--update&lt;/span&gt; manifold_muon_admm
Training with: manifold_muon_admm
Epochs: 5 &lt;span class=&quot;nt&quot;&gt;---&lt;/span&gt; LR: 0.1
Epoch 1, Loss: 1.6916365696459401, Time: 5.1556 seconds
Epoch 2, Loss: 1.4166164446850211, Time: 5.1545 seconds
Epoch 3, Loss: 1.2806039537702287, Time: 5.0617 seconds
Epoch 4, Loss: 1.1837433649569142, Time: 5.0579 seconds
Epoch 5, Loss: 1.1136368123852476, Time: 5.0620 seconds
Accuracy of the network on the 10000 &lt;span class=&quot;nb&quot;&gt;test &lt;/span&gt;images: 52.62 %
Accuracy of the network on the 50000 train images: 66.086 %
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We see a 2.3x speedup—not bad! We’ve cut the iteration count by a factor of 10, but
ADMM uses more FLOPs per iteration because of the extra variables from splitting. With
further optimizations (e.g., better fusing), it’s likely the ADMM runtime would better
reflect the total iteration count and improve further.&lt;/p&gt;

&lt;p&gt;Another natural question is the extent to which we’re obtaining ADMM iterates
$(\vLambda_\star, \vX_\star)$ that satisfy the desired ADMM constraint $\vX = \vG +
\vW(\vLambda + \vLambda^\top)$. We show a plot below from epoch 0, iteration 0 for the
first layer (other layers/steps are similar, or better). It demonstrates that the
residual norm is always rather small, but still non-negligible before convergence. Ten
iterations of ADMM is not quite converged for this layer, but strikes a reasonable tradeoff between convergence and efficiency.&lt;/p&gt;

&lt;figure&gt;
    &lt;img src=&quot;/assets/blog/fc1_weight_e0_d00_feasibility_residual.png&quot; alt=&quot;Feasibility
    (RMSE) for layer 1, epoch 0, step 0&quot; /&gt;
    &lt;figcaption&gt;Feasibility (RMSE) for the ADMM algorithm, for
    the nonzero-steps hyperparameter choices shown in the table above. Traces can be
    identified based on where they terminate.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;h1 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h1&gt;

&lt;p&gt;We’ve managed to speed up manifold Muon by a significant amount over the dual ascent
baseline by deriving and implementing ADMM for this problem! ADMM converges faster and
requires less hyperparameter tuning, allowing us to safely use a much smaller number of
inner loop iterations.&lt;/p&gt;

&lt;p&gt;You can download the code and try it yourself &lt;a href=&quot;https://github.com/sdbuch/thinky-manifolds&quot;&gt;from my
fork&lt;/a&gt;. I’m excited to run larger-scale
experiments with this algorithm to see how it works!&lt;/p&gt;

&lt;p&gt;As a closing thought: here’s a quote from Jeremy’s original blog post:&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;Manifold Muon increased the wall clock time per step compared to AdamW, although this could be improved by running fewer steps of dual ascent or adding momentum to the algorithm and running dual ascent online. Depending on other systems bottlenecks, the overhead may not be an issue.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;I think the switch from dual ascent to ADMM dovetails with all of these considerations!
In particular, the results on this experiment (highly clustered performance) suggest
that incremental / online manifold Muon might be a natural choice for efficiency’s sake,
and our ADMM algorithm should work out-of-the-box in this regime.&lt;/p&gt;

&lt;h1 id=&quot;cite&quot;&gt;Cite&lt;/h1&gt;

&lt;p&gt;If you found this post or the code useful, please consider citing it:&lt;/p&gt;

&lt;div class=&quot;language-bibtex highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nc&quot;&gt;@misc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;buchanan2025mmuonadmm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;author&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;{Buchanan, Sam}&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;title&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;{A Faster Manifold Muon with {ADMM}}&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;year&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;m&quot;&gt;2025&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;na&quot;&gt;howpublished&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;{\url{https://sdbuchanan.com/blog/manifold-muon/}}&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h1 id=&quot;acknowledgments&quot;&gt;Acknowledgments&lt;/h1&gt;

&lt;p&gt;Thanks to &lt;a href=&quot;https://jeremybernste.in/&quot;&gt;Jeremy Bernstein&lt;/a&gt; for helpful discussions.
Thanks to the &lt;a href=&quot;https://sites.research.google/trc/about/&quot;&gt;TRC program&lt;/a&gt;,
&lt;a href=&quot;https://hyperbolic.ai&quot;&gt;Hyperbolic&lt;/a&gt;, and &lt;a href=&quot;https://mithril.ai&quot;&gt;Mithril&lt;/a&gt;
for compute.&lt;/p&gt;

&lt;h1 id=&quot;postscript-other-algorithms-than-admm&quot;&gt;Postscript: Other Algorithms than ADMM&lt;/h1&gt;

&lt;p&gt;After writing this post, it came to my attention that Franz Louis Cesista had also
written about an algorithmic improvement to the baseline dual ascent method for manifold
Muon using the primal-dual hybrid gradient method (PDHG) &lt;a class=&quot;citation&quot; href=&quot;#cesista2025steepestdescentfinsler&quot;&gt;(Cesista, 2025)&lt;/a&gt;. PDHG is sometimes referred to as “linearized ADMM”.&lt;/p&gt;

&lt;p&gt;PDHG for manifold Muon has some similar benefits to the ADMM approach we derived above:
it involves proximal steps to “smooth out” the nonsmooth aspects of the subproblem
\eqref{eq:manifold-muon}.&lt;sup id=&quot;fnref:12&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:12&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;11&lt;/a&gt;&lt;/sup&gt; At the same time, PDHG is classically most commonly used
in situations where the ADMM subproblems are intractable (e.g., in image processing, for
total-variation-regularized image denoising), since it involves three hyperparameters
that require careful tuning in practice &lt;a class=&quot;citation&quot; href=&quot;#Goldstein2015-uo&quot;&gt;(Goldstein et al., 2015)&lt;/a&gt;, as opposed to
than ADMM’s one.&lt;/p&gt;

&lt;p&gt;It would be interesting to run a thorough comparison between PDHG and ADMM for manifold
Muon. The JAX code in Cesista’s blog post is missing some implementations, but could
likely be integrated into our (PyTorch) Github repo with a bit of effort.&lt;/p&gt;

&lt;h1 id=&quot;postscript-other-quadratic-constraints&quot;&gt;Postscript: Other Quadratic Constraints&lt;/h1&gt;

&lt;p&gt;Another recent blog post by Ben
Keigwin, Dhruv Pai, and Nathan Chen at Tilde Research describes how the essential
structure of the manifold Muon algorithm is preserved under replacements of the
constraint $\vW^\top \vW = \vI$ by similar constraints (e.g., constraining only the
diagonal) &lt;a class=&quot;citation&quot; href=&quot;#keigwin2025gramspace&quot;&gt;(Keigwin et al., 2025)&lt;/a&gt;. As a corollary, the ADMM algorithm we derived
above can be applied to constraints from this family with little extra effort! We leave
this as an exercise to
the reader.&lt;/p&gt;

&lt;div class=&quot;footnotes&quot; role=&quot;doc-endnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:2&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;Here and below, we write $\ip{\vA}{\vB} = \tr(\vA^\top \vB) = \sum_{i, j} A_{ij}
B_{ij}$ to denote the Frobenius inner product on matrices, which treats them as if
they were vectors,
and $\norm{\spcdot}$ to denote the operator norm (or spectral norm, or Schatten
$\infty$-norm), which is the largest singular value of its (matrix) argument. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:1&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;As written, the matrix sign function is only defined for inputs that have no zero
singular values. It turns out that as long as $\vA$ satisfies the tangent constraint
$ \vA^\top \vW + \vW^\top \vA = \Zero$, the update $\vW -\eta \vA$ &lt;em&gt;never&lt;/em&gt; has zero
singular values! This can be derived straightforwardly via contradiction, and
provides a strong motivation for the need to &lt;em&gt;exactly&lt;/em&gt; solve the manifold Muon
subproblem. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;Here and below, we write $\norm{\spcdot}_*$ to denote the nuclear norm (or Schatten
1-norm), which is the sum of the singular values of its (matrix) argument. &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:5&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;Less naive algorithms for nonsmooth optimization can ameliorate this pathological
behavior, sometimes to a significant extent. See, for example, work of Damek Davis
and collaborators &lt;a class=&quot;citation&quot; href=&quot;#Davis2024-ef&quot;&gt;(Davis et al., 2024)&lt;/a&gt;. These algorithms definitely merit
a closer look in the context of manifold Muon! But we will stick our neck out and
suggest that it might be hard to beat ADMM’s solution quality vs. iteration count
tradeoff, based on the experiments to come—although we would be happy to be proven
wrong. &lt;a href=&quot;#fnref:5&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:6&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;This is a relatively straightforward argument, but at the same time, I haven’t
seen it anywhere in the compressed sensing literature previously (and this
literature applied singular value thresholding to &lt;em&gt;almost everything&lt;/em&gt; one could).
If you know of a reference that has applied this previously, please let me know! &lt;a href=&quot;#fnref:6&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:7&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;After writing down these updates all together, it becomes clear that
$\vLambda_{k+1}$ is always a symmetric matrix. We use this invariant to
simplify some steps in the algorithm here and below! &lt;a href=&quot;#fnref:7&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:10&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;More precisely, it’s straightforward to prove (using the SVD) that singular value
thresholding is a $1$-Lipschitz continuous function of the input matrix with respect
to the operator norm. &lt;a href=&quot;#fnref:10&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:8&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;It may be obvious, but it’s worth repeating: the Manifold Muon runs an
iterative optimization solver after &lt;em&gt;every forward/backward pass&lt;/em&gt; of the
model being trained, for every weight matrix in the model.
This leads to some interesting debugging issues, such
as small differences in the final solution obtained by different manifold
Muon solvers at iteration zero leading to significant differences in the
behavior of subsequent gradients/inner loop solves (even though everything
is convex!). These differences seem to mostly be restricted to inner loop
loss, as final val error does not seem to vary much, but it will be
interesting to study this behavior at larger scales. &lt;a href=&quot;#fnref:8&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:9&quot; role=&quot;doc-endnote&quot;&gt;

      &lt;p&gt;For the sake of comparison, we also use this initialization for our
$\vLambda$ in our ADMM algorithm. &lt;a href=&quot;#fnref:9&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:11&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;This is likely because of the fact that we
compute &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;msign&lt;/code&gt; with an approximate algorithm, leading to unstable behavior when we
aggressively threshold singular values in the $\vX$ update of ADMM. &lt;a href=&quot;#fnref:11&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:12&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;In contrast to ADMM, it proceeds directly on the Lagrangian for
\eqref{eq:manifold-muon}, and hence involves proximal steps on the (characteristic
function of the) operator norm ball, rather than on the nuclear norm, as in ADMM. &lt;a href=&quot;#fnref:12&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Sam Buchanan</name>
          
            <email>sdbuchanan@berkeley.edu</email>
          
          
        </author>
      

      
        <category term="optimization" />
      

      

      
        <summary type="html">In a recent blog post, Jeremy Bernstein (Thinking Machines) gave an algorithm for optimizing matrix-valued parameters—say, weights in a transformer—that ensures that both the update directions and the parameters themselves are well-conditioned (Bernstein, 2025). This algorithm, called manifold Muon, has an inner loop that is relatively slow to converge, and requires some hyperparameter tuning to get right.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">SPMD in JAX #2: Transformers in Bare-Metal JAX</title>
      
      
      <link href="http://sdbuchanan.com/blog/jax-2/" rel="alternate" type="text/html" title="SPMD in JAX #2: Transformers in Bare-Metal JAX" />
      
      <published>2025-09-12T00:00:00+00:00</published>
      <updated>2025-09-12T00:00:00+00:00</updated>
      <id>http://sdbuchanan.com/blog/jax-2</id>
      <content type="html" xml:base="http://sdbuchanan.com/blog/jax-2/">&lt;p&gt;In this second post in a series on programming with JAX for training
models like transformers, we write and train a transformer with a decent set of
bells and whistles, then benchmark and scale it as much as we can on our TPU v4
half-cube!&lt;/p&gt;

&lt;p&gt;As in previous posts in the series, we make good use of the Google scaling
playbook &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;, as well as some useful JAX tutorials, especially &lt;a class=&quot;citation&quot; href=&quot;#jax-training-cookbook&quot;&gt;(JAX Team, 2025)&lt;/a&gt;.&lt;/p&gt;

&lt;ul id=&quot;markdown-toc&quot;&gt;
  &lt;li&gt;&lt;a href=&quot;#setup&quot; id=&quot;markdown-toc-setup&quot;&gt;Setup&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#design-choices&quot; id=&quot;markdown-toc-design-choices&quot;&gt;Design Choices&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#model-architecture-overview&quot; id=&quot;markdown-toc-model-architecture-overview&quot;&gt;Model Architecture Overview&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#building-the-model-first-cut&quot; id=&quot;markdown-toc-building-the-model-first-cut&quot;&gt;Building the Model: First Cut&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#a-quick-note-on-configuration&quot; id=&quot;markdown-toc-a-quick-note-on-configuration&quot;&gt;A Quick Note on Configuration&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#brief-setup&quot; id=&quot;markdown-toc-brief-setup&quot;&gt;Brief Setup&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#data&quot; id=&quot;markdown-toc-data&quot;&gt;Data&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#defining-the-model-mlp&quot; id=&quot;markdown-toc-defining-the-model-mlp&quot;&gt;Defining the Model: MLP&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#defining-the-model-embeddings-and-placeholders&quot; id=&quot;markdown-toc-defining-the-model-embeddings-and-placeholders&quot;&gt;Defining the Model: Embeddings and Placeholders&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#defining-the-model-building-blocks-to-transformer&quot; id=&quot;markdown-toc-defining-the-model-building-blocks-to-transformer&quot;&gt;Defining the Model: Building Blocks to Transformer&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#defining-the-model-initialization&quot; id=&quot;markdown-toc-defining-the-model-initialization&quot;&gt;Defining the Model: Initialization&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#testing-the-basic-model&quot; id=&quot;markdown-toc-testing-the-basic-model&quot;&gt;Testing the Basic Model&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#a-minimal-training-loop&quot; id=&quot;markdown-toc-a-minimal-training-loop&quot;&gt;A Minimal Training Loop&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#sampling-from-the-model-first-cut&quot; id=&quot;markdown-toc-sampling-from-the-model-first-cut&quot;&gt;Sampling from the Model: First Cut&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#building-the-model-attention-and-positional-encodings&quot; id=&quot;markdown-toc-building-the-model-attention-and-positional-encodings&quot;&gt;Building the Model: Attention and Positional Encodings&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#building-the-model-sampling-with-a-cache&quot; id=&quot;markdown-toc-building-the-model-sampling-with-a-cache&quot;&gt;Building the Model: Sampling with a Cache&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#conclusions&quot; id=&quot;markdown-toc-conclusions&quot;&gt;Conclusions&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#accumulated-configs&quot; id=&quot;markdown-toc-accumulated-configs&quot;&gt;Accumulated Configs&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#acknowledgments&quot; id=&quot;markdown-toc-acknowledgments&quot;&gt;Acknowledgments&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;setup&quot;&gt;Setup&lt;/h3&gt;

&lt;p&gt;As usual, we’ll test things out in Python 3.13 and JAX 0.7.1 on a single TPU v4
host.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax.numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;devices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h1 id=&quot;design-choices&quot;&gt;Design Choices&lt;/h1&gt;

&lt;p&gt;A project like this is for learning and pedagogy—it’s inevitably behind the
research/infra frontiers. So the code we write is going to lean into this
somewhat. Notably,&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;&lt;em&gt;We’ll implement everything in &lt;strong&gt;bare-metal JAX&lt;/strong&gt;, rather than using
neural network libraries like Flax/nnx/Equinox.&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;This will let us expose and
understand all the implementation details, both for infra and architecture, that
we might normally overlook. On the flip side, we’ll take some performance hits
for this (which we’ll attempt to profile and characterize), and it will lead to some
slightly unfortunate code repetition.&lt;sup id=&quot;fnref:1&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt; We’ll also be reinventing the wheel
throughout – but for pedagogy’s sake!&lt;/p&gt;

&lt;h2 id=&quot;model-architecture-overview&quot;&gt;Model Architecture Overview&lt;/h2&gt;

&lt;p&gt;Our approach to implementing the network architecture is a modification of the approach
used in the JAX training cookbook &lt;a class=&quot;citation&quot; href=&quot;#jax-training-cookbook&quot;&gt;(JAX Team, 2025)&lt;/a&gt;, which uses a data
structure like those in Google’s &lt;a href=&quot;https://github.com/google/ml_collections&quot;&gt;ML
Collections&lt;/a&gt; library to store the model. At a
high level, this approach consists of these standard JAX neural net patterns:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Storing the model parameters as a &lt;a href=&quot;https://docs.jax.dev/en/latest/pytrees.html&quot;&gt;pytree&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;Defining separate functions for initialization and the model’s forward pass which take
such a pytree as input (as well as other relevant inputs).
This pattern promotes functionally-pure code for the model’s initialization and forward pass,
making it easy to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; for performance.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Our implementation tries to move this slightly in the direction of
&lt;a href=&quot;https://docs.kidger.site/equinox/&quot;&gt;Equinox&lt;/a&gt;, without all the powerful features that
library entails. In particular, we will:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Use
&lt;a href=&quot;https://docs.python.org/3/library/collections.html#collections.namedtuple&quot;&gt;NamedTuples&lt;/a&gt;
for parameters and layers (i.e., where you’d normally expect to use a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nn.Module&lt;/code&gt; in
Pytorch), enabling good static type checking with e.g. Pyright.
NamedTuples are &lt;a href=&quot;https://docs.jax.dev/en/latest/pytrees.html#internal-pytree-handling&quot;&gt;automatically registered as
pytrees&lt;/a&gt; in JAX!&lt;/li&gt;
  &lt;li&gt;Write a top-level function for model initialization, and layer-level functions for
each separate layer’s operation (which, as above, take parameters and inputs as
arguments). This is a ‘bottom-up’ version of the previous design, where we’ll end up
with many different functions, one for each ‘layer’, which are combined bottom-up to
implement the overall model’s forward pass (more like in a typical neural network
library).&lt;/li&gt;
&lt;/ul&gt;

&lt;h1 id=&quot;building-the-model-first-cut&quot;&gt;Building the Model: First Cut&lt;/h1&gt;

&lt;p&gt;When writing this post, I wrote and debugged the initial model/training code iteratively
in a notebook. Rather than just jumping to the ‘final answer’, we’ll walk through the
development process below.&lt;/p&gt;

&lt;p&gt;In this part of the post, we’ll build up to the transformer through a few simple
sub-modules of the overall model. Our targets will be:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;We’ll work with a very simple data distribution, described below, which requires
minimal dataloading code (and no tokenization).&lt;/li&gt;
  &lt;li&gt;We’ll focus on the training loss, leaving sampling for later. Although simple sampling
gives us an important correctness test (“did we accidentally give the model access to
future tokens?”, for example…), doing it ‘fast’ requires a bit of code, so we’ll
save it for later.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;On the way to these targets, we’ll build up the following:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Just the MLP, giving us a position-wise model (bigrams with a nonlinearity!).&lt;/li&gt;
  &lt;li&gt;A training loop with SGD and cross-entropy (next token prediction) loss.&lt;/li&gt;
  &lt;li&gt;Just causal attention, letting us build a transformer that can solve our toy task
(discussed in the next section) with two layers.&lt;/li&gt;
  &lt;li&gt;Rotary positional encodings, so that we can solve our task with a one-layer model.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;We’ll then step back and refactor things with an eye towards training larger models on
actual language data in a distributed setting.&lt;/p&gt;

&lt;p&gt;If you’re following along and you notice anything I’ve implemented that seems
suboptimal, please let me know! Drop me an email or a DM on Twitter (links at the bottom
of the page).&lt;/p&gt;

&lt;h2 id=&quot;a-quick-note-on-configuration&quot;&gt;A Quick Note on Configuration&lt;/h2&gt;

&lt;p&gt;For overall convenience, we eventually want to combine all our different configuration
options in a large dataclass, then pass our runtime instantiation of the config (with
defaults, overrides, etc.) to all our functions, which can extract just the parameters
they need.&lt;/p&gt;

&lt;p&gt;When writing this code, I tend to define the config options I need locally, then go back
later and turn them into attributes of the overall config dataclass. To avoid
frontloading with all the complexity of the config, and to avoid too much unnecessary
description of the refactoring process, we’ll reference different config options below
as attributes of an (as of yet undefined) Config class. Code outputs will also have this
class defined. You can Ctrl+F for type definitions and defaults, which will appear
later, if it feels necessary.&lt;/p&gt;

&lt;h2 id=&quot;brief-setup&quot;&gt;Brief Setup&lt;/h2&gt;

&lt;p&gt;We are going to make reference to different sources of randomness below, which we seed
from our Config class.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Test it out
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;key_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_sampling&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;data&quot;&gt;Data&lt;/h2&gt;

&lt;p&gt;The data will consist of an (infinite) sequence of consecutive digits (from 0 to 9),
which ‘reflect’ when they reach 0 or 9. I.e.,&lt;/p&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;01234567898765432101...
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We’ll generate a dataset of fixed-length subsequences of this infinite sequence for
training.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;text&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;012345678987654321&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;seqs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;int32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;key_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;permutation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.8&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Xdev&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.8&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Xte&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
[8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8
 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9
 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8
 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7
 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6
 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5
 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6 5 4 3 2 1 0 1 2 3 4 5 6 7 8 9 8 7 6]
(14540, 257)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We shuffle the sequences above and perform a simple train-dev-test split (although we’ll
just use train for now). We take one more character than the sequence length, so that we
can easily turn these raw sequences into inputs and targets. We do this with a simple
dataloader:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;itertools&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dataloader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;it&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;count&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;fold_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;yield &lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Xtr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;key_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;device_put&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;nf&quot;&gt;dataloader&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Here &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch&lt;/code&gt; is a map iterator which wraps the dataloader with a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;device_put&lt;/code&gt; statement to enforce
our desired sharding (e.g., data parallel). When we create our config class, we create a
global mesh and set it as the default mesh, so we can pass a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PartitionSpec&lt;/code&gt; object to
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;device_put&lt;/code&gt; in our single-host setting (see the &lt;a href=&quot;/blog/jax-1&quot;&gt;previous post&lt;/a&gt;).&lt;/p&gt;

&lt;p&gt;The batches being loaded have shape
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(config.global_batch_size, config.seq_len)&lt;/code&gt;, so expect &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.sharding_data&lt;/code&gt; to be
for example &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.P('dp')&lt;/code&gt; for a “data-parallel” mesh axis &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'dp'&lt;/code&gt;, which leads the batch
to be sharded across accelerators.&lt;/p&gt;

&lt;h2 id=&quot;defining-the-model-mlp&quot;&gt;Defining the Model: MLP&lt;/h2&gt;

&lt;p&gt;We’ll implement the overarching structure of our model here, described above in the
“Design Choices” section, but keep the low-level implementations restricted to just the
MLP.&lt;/p&gt;

&lt;p&gt;We’ll follow the following conventions:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;“Layers” correspond to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedTuples&lt;/code&gt; (which are pytrees). The elements of such a
“layer” can be either &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.Array&lt;/code&gt;s (parameters/pytree leaves), or other NamedTuples
(for composing layers).&lt;/li&gt;
  &lt;li&gt;We’ll name our layers with capital letters, and the functions that implement them
(i.e., their “forward” method in Pytorch lingo) in lower-case, with a leading
underscore.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Here’s our MLP layer following these conventions:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;typing&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;preact&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_mlp_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_bias_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;preact&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;act&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;gelu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;preact&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;approximate&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;act&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_bias_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We would create parameters for such a layer and apply it as follows:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;my_mlp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)),&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;my_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Array([0., 0.], dtype=float32)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;It’s important to note here that we are &lt;strong&gt;not&lt;/strong&gt; able to enforce shape information at
this level of implementation, since abstractly speaking an MLP can be implemented for
any compatible weight matrix dimensions. These shapes get enforced later, when we
initialize an &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mlp&lt;/code&gt; with settings specified in our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config&lt;/code&gt; instance.&lt;/p&gt;

&lt;p&gt;More importantly, there is some ambiguity in the types of the arguments &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;params&lt;/code&gt;
that &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt; will apply to – e.g. the above function could apply to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; being a vector
of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(d_model,)&lt;/code&gt;, or a matrix of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(seq_len, d_model)&lt;/code&gt;, as long as the
parameters of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mlp&lt;/code&gt; are initialized correctly. In other words, it’s up to us to make
sure we apply this implementation correctly.&lt;/p&gt;

&lt;p&gt;In our subsequent implementation, we’re going to follow a helpful principle that JAX
affords – we’ll operate at the finest level of granularity possible at every stage of
the model, and ‘move up’ to the next level of granularity as necessary, using JAX’s
powerful primitives like &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scan&lt;/code&gt;. In particular, we’ll see soon that we’ll
only use the above &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt; function in cases where &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; is a vector of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(d_model,)&lt;/code&gt;.
Since the MLPs in transformers operate the same on every position in the sequence,
and every sequence in the batch, we’ll just &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt; function when we need to
use it at the next level of granularity! This is a pretty empowering design strategy
that JAX enables: it lets us write our code’s core numerical functionality without
having to plan ahead for exactly how it will be used later, making implementation
simpler and separating complexity nicely.&lt;/p&gt;

&lt;p&gt;A few other notes (some drawbacks of this bare metal approach, which
seem unvaoidable):&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;There’s no ‘state’ in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt; function – it just performs the operation of a MLP
given parameters/inputs. This makes it easy to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Initialization is &lt;strong&gt;not&lt;/strong&gt; handled here – we’ll have to do it later,
since we want to initialize the entire transformer (top-down) in one shot.&lt;/li&gt;
  &lt;li&gt;We enforce shardings for the hidden activations/output of the MLP in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt;
function. We enforce shardings for the &lt;em&gt;parameters&lt;/em&gt; of the MLP when we initialize them
(see later).&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;defining-the-model-embeddings-and-placeholders&quot;&gt;Defining the Model: Embeddings and Placeholders&lt;/h2&gt;

&lt;p&gt;To have a basic working model, we need token embeddings (mapping the raw integer tokens
we generated above to vectors the transformer can operate on), unembeddings (mapping the
vector outputs of the transformer to logits, which we can map to a probability
distribution over tokens to predict with), and normalization layers. We’ll have
attention layers too, eventually; for now, we’ll make a placeholder function for the
attention layer.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;LayerNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;gamma&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LayerNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;standardize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epsilon&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eps_ln&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gamma&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_bias_ln&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;These are standard unoptimized implementations for these different building blocks.
Looking ahead slightly, we make sure to provide the ability to compute the layer
normalization operation in a higher precision (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.compute_dtype&lt;/code&gt;), since the
variance operation in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;standardize&lt;/code&gt; can lead to underflow and other losses of precision.
The API for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; involves arguments for utilizing a cache for faster inference (we’d
normally leave this out to begin with, and add it only later).&lt;/p&gt;

&lt;h2 id=&quot;defining-the-model-building-blocks-to-transformer&quot;&gt;Defining the Model: Building Blocks to Transformer&lt;/h2&gt;

&lt;p&gt;A transformer consists of an embedding, then a sequence of transformer blocks (which
involve normalization + attention and a residual connection, then normalization + MLP
and a residual connection), then the unembedding.
Accordingly, we make a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Block&lt;/code&gt; class which combines our low-level building blocks.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;functools&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;partial&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;norm_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LayerNorm&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;norm_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LayerNorm&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Mlp&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;att_skip&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;att_skip&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;mlp_skip&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mlp_skip&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Notice that we’re following the design principle we mentioned above with the MLP: here
we expect &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x_seq&lt;/code&gt; to have two dimensions, the first for the sequence length and the
second for the model dimension, and we &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_layernorm&lt;/code&gt; functions to
apply to it correctly.&lt;/p&gt;

&lt;p&gt;We can now define the transformer. We’ll do this in a slightly refined way versus just
using a for loop: we’ll expect the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; parameter below, of type &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Block&lt;/code&gt;, to have a
mapped leading axis corresponding to the &lt;em&gt;layer dimension&lt;/em&gt;, and we’ll consume this
leading axis using a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scan&lt;/code&gt;. That means we expect every parameter in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Block&lt;/code&gt; pytree
we pass to not just have its usual shape, say &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;param.shape&lt;/code&gt; (expected by the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_mlp&lt;/code&gt;,
etc. APIs above), but additionally be of shape &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(num_layers,) + param.shape&lt;/code&gt;. This turns
out to be very easy to guarantee with a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; at initialization (see below).&lt;/p&gt;

&lt;p&gt;If you aren’t familiar with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scan&lt;/code&gt;, you can read
about the API &lt;a href=&quot;https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html&quot;&gt;here&lt;/a&gt;.
It takes a function that accepts two arguments – a ‘carry’ and an ‘input’ – and
produces one output, as well as an ‘initial’ carry, and an initial input, which &lt;strong&gt;must&lt;/strong&gt;
have an additional mapped axis somewhere relative to the input shape the function
argument expects. It will then iteratively apply the function to the current carry and
the current input; the output carry is passed to the next call of the function. In this
way, it’s like a vectorized for loop – you can also draw a direct analogy to a
recurrent neural network, or other dynamical systems.&lt;/p&gt;

&lt;p&gt;In our case, the ‘carry’ for scan is our network’s input and activations (we pass these
from layer to layer sequentially), and the ‘input’ for scan is the mapped &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Block&lt;/code&gt; pytree
(as well as the mapped KV cache for inference – the updated cache becomes the output –
but we can ignore that for now).&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Embedding&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# vmapped at init
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;unemb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Unembedding&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_block_fun&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params__cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params__cache_in&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;scan&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_fun&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unemb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Using &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scan&lt;/code&gt; leads to a very concise forward pass definition for the transformer. Note
that our transformer is still operating on sequences of embeddings – we’ll apply it to
a batch of sequences with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; later!&lt;/p&gt;

&lt;h2 id=&quot;defining-the-model-initialization&quot;&gt;Defining the Model: Initialization&lt;/h2&gt;

&lt;p&gt;The last step before testing is to initialize the model. This leads to a long
definition, but the intrinsic complexity is not high.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_model_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;unemb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unemb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;k_w_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_w_down&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;k_w_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mlp_factor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_wup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k_w_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mlp_factor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_wdown&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mlp_factor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_mlp_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias_up&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bias_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;k_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_o&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;k_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_wqkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;w_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_wo&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LayerNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;gamma&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;ones&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;LayerNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gamma&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gamma&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;beta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;key_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_mlp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;norm_attn&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;attn&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;norm_mlp&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_layernorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;mlp&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Make the full network
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;key_emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_unemb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;keys_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;emb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_embedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;keys_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;unemb&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_unembedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_unemb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Notice above how we produce the mapped axis needed for the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;scan&lt;/code&gt; in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_transformer&lt;/code&gt;
function above: we just split our PRNGKey into an array of keys, one for each layer, and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vmap&lt;/code&gt; the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;init_block&lt;/code&gt; function over the array of keys. Easy!&lt;/p&gt;

&lt;h2 id=&quot;testing-the-basic-model&quot;&gt;Testing the Basic Model&lt;/h2&gt;

&lt;p&gt;We can perform a quick test to make sure everything above is at least syntactically
correct. There are two things we want to see:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Can we run the model in eager mode?&lt;/li&gt;
  &lt;li&gt;Can we &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; the model?&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;We really only care about 2 (it implies 1), so let’s test it.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;key_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_model_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Number of layers in model: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Model dimension: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Number of layers in model:  2
Model dimension:  768
Array([[0.894531, -0.482422, 1.41406, 0.194336, -2.76562, -1.27344,
        1.21094, 0.417969, -0.460938, -0.304688],
       [0.0610352, 2.8125, 1.08594, -0.351562, -0.326172, -1.46094,
        0.0354004, 0.283203, 0.796875, -1.42969],
       [0.186523, 1.21875, 0.429688, -1.10156, 0.667969, -2.07812,
        -0.078125, 0.410156, 0.0273438, -2.60938],
       [0.201172, 1.46094, 0.124512, -1.27344, -1.05469, -0.667969,
        0.453125, 1.78906, -1.00781, 1.67188],
       [-0.925781, 0.988281, -1.14844, 0.0942383, 0.0898438, 1.70312,
        1.125, 1.03125, -1.17969, -1.69531],
       [-0.710938, -0.679688, -1.53125, 0.800781, 0.443359, -0.71875,
        -2.29688, 0.488281, -0.169922, -1.38281],
       [1.44531, 0.949219, -0.589844, 0.357422, -1.30469, 2.85938,
        -0.388672, 0.46875, -0.695312, -0.425781],
       [-1.85156, 0.679688, 1.25, 0.109863, -0.667969, -0.361328,
        -0.0405273, -0.964844, 3.01562, 0.333984],
       [-0.304688, 1.83594, 0.734375, 1.16406, -0.114746, -0.800781,
        0.486328, 1, -1.00781, -1.01562],
       [-1.66406, 0.925781, -0.0908203, 1, 0.298828, -2.40625, 0.574219,
        0.945312, -2.54688, -0.123047]], dtype=bfloat16)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Looks good! There might be a small issue with the unembedding scaling, as we can see:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Array([[0.163086, 0.0410156, 0.275391, 0.0810547, 0.00418091, 0.0186768,
        0.224609, 0.101562, 0.0419922, 0.0493164],
       [0.0390625, 0.613281, 0.108887, 0.026001, 0.0264893, 0.00848389,
        0.0380859, 0.0488281, 0.081543, 0.00872803],
       [0.100098, 0.279297, 0.126953, 0.02771, 0.161133, 0.010376,
        0.0766602, 0.124512, 0.0849609, 0.006073],
       [0.0583496, 0.204102, 0.0539551, 0.0133057, 0.0164795, 0.0244141,
        0.0751953, 0.285156, 0.017334, 0.253906],
       [0.0227051, 0.15332, 0.0181885, 0.0629883, 0.0629883, 0.314453,
        0.176758, 0.160156, 0.0177002, 0.010437],
       [0.0588379, 0.060791, 0.026123, 0.267578, 0.1875, 0.0588379,
        0.012146, 0.195312, 0.101562, 0.0300293],
       [0.141602, 0.0864258, 0.0184326, 0.0476074, 0.00909424, 0.582031,
        0.022583, 0.0534668, 0.0164795, 0.0218506],
       [0.00500488, 0.0629883, 0.112305, 0.0358887, 0.0164795, 0.0224609,
        0.0307617, 0.012207, 0.65625, 0.0444336],
       [0.0395508, 0.335938, 0.111328, 0.171875, 0.0473633, 0.0239258,
        0.0869141, 0.145508, 0.0194092, 0.0194092],
       [0.0145874, 0.193359, 0.0693359, 0.208008, 0.102539, 0.00689697,
        0.135742, 0.196289, 0.00598145, 0.0673828]], dtype=bfloat16)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In particular, the initial outputs are not very uniform, which we would like to have.
However, we should still be able to learn – our network is not very deep.&lt;/p&gt;

&lt;h2 id=&quot;a-minimal-training-loop&quot;&gt;A Minimal Training Loop&lt;/h2&gt;

&lt;p&gt;Now we have enough to train a model. A simple training loop follows, using SGD.
First, we set up the training step, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;ting it for performance.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;


&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;init_model_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;


&lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;donate_argnums&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;inputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log_softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;take_along_axis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logprobs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[...,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;value_and_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;TrainState&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tree&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_state&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In the training step, after loading a batch of sequences from our previously-written
batch loader, we vmap &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_transformer&lt;/code&gt; over it, then calculate the next-token prediction
loss with the standard bare-metal approach (with possible up-casting for numerical
stability of the softmax). The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;donate_argnums&lt;/code&gt; argument to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; lets us mark that the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;train_state&lt;/code&gt; buffer can be discarded after the step finishes (since we create and
return a new &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;TrainState&lt;/code&gt; with the results of the SGD step), letting the XLA compiler
reuse that memory for the output state.&lt;/p&gt;

&lt;p&gt;The loop is then simple to write:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;prev_metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cur_metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;next&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;log_metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prev_metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prev_metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cur_metrics&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_metrics&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;50&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;log_metrics&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;|=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
        &lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metric&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;val&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metric&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_metrics&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;items&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sep&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\t&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
loss: 1.1560431718826294	step: 50
loss: 0.9151328206062317	step: 100
loss: 1.004440188407898	    step: 150
loss: 0.8609314560890198	step: 200
loss: 0.9039498567581177	step: 250
loss: 0.8058883547782898	step: 300
loss: 0.7688642144203186	step: 350
loss: 0.6995882987976074	step: 400
loss: 0.6380590796470642	step: 450
loss: 0.6304700374603271	step: 500
loss: 0.6276624798774719	step: 550
loss: 0.6255137920379639	step: 600
loss: 0.6255109310150146	step: 650
loss: 0.6271587610244751	step: 700
loss: 0.6273760795593262	step: 750
loss: 0.6262620687484741	step: 800
loss: 0.6236293911933899	step: 850
loss: 0.6241434216499329	step: 900
loss: 0.6261416673660278	step: 950
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;It’s not great, but it’s indeed training! We don’t expect to be able to solve this task
perfectly with our simple model, since a position-wise MLP can only use one token of
context to predict the next token, and our task requires at least two. To see why, just
note that, for example, the character following &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;5&lt;/code&gt; could be either &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;6&lt;/code&gt; or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;4&lt;/code&gt;,
depending on whether we’re at a rising or falling part of the sequence. Two characters
of context is enough to estimate this, meaning adding attention should help here!&lt;/p&gt;

&lt;p&gt;As an aside, above, we make sure to print metrics for the &lt;em&gt;previous&lt;/em&gt; iteration of the
training loop at each step. This helps with correctly pipelining host to data transfers:
see &lt;a class=&quot;citation&quot; href=&quot;#jax-training-cookbook&quot;&gt;(JAX Team, 2025)&lt;/a&gt; for a detailed description.&lt;/p&gt;

&lt;h1 id=&quot;sampling-from-the-model-first-cut&quot;&gt;Sampling from the Model: First Cut&lt;/h1&gt;

&lt;p&gt;To get a sense of the behavior of our learned model, we can write our sampler for it
now. Normally, efficient sampling requires a cache of past outputs to be maintained (we
already have the API for this), but we’ll discuss this later after we implement
attention.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;donate_argnums&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sample_one_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_transformer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_in&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;categorical&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;config_sampling_args&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sharding_data&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()}&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;config_sampling&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config_sampling_args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tokens_to_generate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;63&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tokens_to_generate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sample_one_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Array([1, 2, 3, 2, 3, 4, 3, 2, 3, 2, 3, 5, 6, 5, 4, 3, 4, 5, 2, 1, 2, 3,
       2, 1, 0, 1, 2, 3, 4, 3, 4, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5,
       6, 5, 1, 2, 0, 1, 0, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 3, 2],      dtype=int32)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;We set a relatively low temperature in order to better highlight what the model has
learned. The generated sequence is in line with our hypothesis – the model has lowered
the loss by learning that it needs to predict the number that’s one above or one below
the current input, but it can’t do much better than guessing randomly between these two
possibilities (hence why our final loss is close to $\log(2)$).&lt;/p&gt;

&lt;p&gt;Let’s improve this with attention!&lt;/p&gt;

&lt;h1 id=&quot;building-the-model-attention-and-positional-encodings&quot;&gt;Building the Model: Attention and Positional Encodings&lt;/h1&gt;

&lt;p&gt;An attention implementation that follows &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.nn.dot_product_attention&lt;/code&gt; shape
conventions follows, with both manual implementations and one using this function.
For now, we omit the cache implementation and positional embedding implementation.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# s: sequence length
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# d: embedding dim (config.d_model)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# n: attention heads (config.num_heads)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# h: head dim (config.d_head)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# x_seq: s x d
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sd,d3nh-&amp;gt;3snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Attention computation
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# broadcast over heads
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_fa&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot_product_attention&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,tnh-&amp;gt;nst&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Scale and causal mask
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;where&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# type: ignore[reportArgumentType]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nst,tnh-&amp;gt;snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,hnd-&amp;gt;sd&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Attention takes the embeddings of the input sequences $\vX \in \bbR^{S \times D}$ as
input. It computes $N$ “attention heads”, then combines them to form the output.
Each head, indexed by $n \in [N]$, is generated from different projections of the input
embedding sequences: we calculate&lt;/p&gt;

\[\begin{equation*}
\vQ_n = \vX \vW_{Q, n}, \quad \vK_n = \vX \vW_{K, n}, \quad \vV_n = \vX \vW_{V, n},
\end{equation*}\]

&lt;p&gt;where the output dimension of the projections is $H$, then generate the head output as&lt;/p&gt;

\[\begin{equation*}
\vA_n = \mathrm{softmax}(\vQ_n \vK_n^\top + \vM) \vV_n,
\end{equation*}\]

&lt;p&gt;with the softmax being taken along rows, and $\vM$ denoting a causal mask:&lt;/p&gt;

\[M_{ij} = \begin{cases}
0 &amp;amp; i \leq j \\
-\infty &amp;amp; \ow.
\end{cases}\]

&lt;p&gt;The heads are then concatenated together along the embedding dimension, producing an $S
\times NH$ dimensional result (usually $NH = D$), and then an output projection $\vW_O$
is applied to produce the attention output, which has the same shape as the input $\vX$.
The code above implements these operations concisely with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jnp.einsum&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;We can replace our previous placeholder attention implementation with a complete one and
then rerun our training loop, since we already added the proper initialization code
above.&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
loss: 0.6685549020767212	step: 50
loss: 0.58647620677948	step: 100
loss: 0.5948932766914368	step: 150
loss: 0.5258455872535706	step: 200
loss: 0.5094199776649475	step: 250
loss: 0.40264642238616943	step: 300
loss: 0.42422279715538025	step: 350
loss: 0.48532211780548096	step: 400
loss: 0.45087528228759766	step: 450
loss: 0.3302219808101654	step: 500
loss: 0.39742058515548706	step: 550
loss: 0.3917175531387329	step: 600
loss: 0.3619387745857239	step: 650
loss: 0.34305423498153687	step: 700
loss: 0.2451237142086029	step: 750
loss: 0.27099400758743286	step: 800
loss: 0.4500095248222351	step: 850
loss: 0.26931577920913696	step: 900
loss: 0.2725009322166443	step: 950
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
loss: 0.15929250419139862	step: 50
loss: 0.14883831143379211	step: 100
loss: 0.165802463889122	step: 150
loss: 0.14946813881397247	step: 200
loss: 0.15420110523700714	step: 250
loss: 0.14391759037971497	step: 300
loss: 0.2680056691169739	step: 350
loss: 0.14069046080112457	step: 400
loss: 0.27090030908584595	step: 450
loss: 0.14207065105438232	step: 500
loss: 0.13676664233207703	step: 550
loss: 0.13658639788627625	step: 600
loss: 0.13079527020454407	step: 650
loss: 0.39731210470199585	step: 700
loss: 0.1367562860250473	step: 750
loss: 0.13199977576732635	step: 800
loss: 0.1323813498020172	step: 850
loss: 0.12948714196681976	step: 900
loss: 0.13398176431655884	step: 950
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Two back-to-back output traces are posted above. We can get a sense that this model
&lt;em&gt;should&lt;/em&gt; be able to learn to solve the task, but the poor optimizer (constant learning
rate SGD) is holding it back. If we rerun the sampling code above with the new model at
low temperature (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;temperature=0.1&lt;/code&gt;), we do get correct outputs:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
       9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8],      dtype=int32)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Yet the final loss is suggestive that the model we’ve learned is not very good.
We can improve the model further by incorporating positional encodings. We’ll use the
modern choice, rotary positional encodings (RoPE). There are already many great blogs
explaining the intuition behind these encodings, so I’ll just focus here on the
algebraic ideas.&lt;/p&gt;

&lt;p&gt;The basic idea is to consider the dot product of a query-key pair forming one head’s
attention matrix $\vA$, as we defined it above (dropping the subscript $n$ for
concision). Ignoring the mask, for the $ij$-th entry of this attention matrix, we
consider the dot product (“attention logits”) $\vq_i \vk_j^\top$.
A notable property of attention is that the only dependence of its output on the
position $i$ is through the causal mask $\vM$: if we drop the mask, then permuting the
input positions produces a corresponding permutation of the output positions,
because every position receives the same projections $\vW_{QKV}$.&lt;sup id=&quot;fnref:2&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;This is arguably undesirable: the model benefits in many cases from the ability to learn
a sequence-to-sequence mapping that is more position sensitive. For example, in our toy
data setting, the model would be able to learn a good mapping faster if it were able to
selectively focus on only a small number of past tokens (since this is enough to
determine whether we’re in a “rising” or “falling” part of the sequence).&lt;/p&gt;

&lt;p&gt;A nice way to incorporate this positional information is through a &lt;em&gt;relative&lt;/em&gt; positional
embedding. In such a method, we modify the attention operation so that the attention
logits incorporate the magnitude of the difference in positions $\abs{i - j}$. In RoPE,
this is done by inducing a quadratic form&lt;/p&gt;

\[\begin{equation*}
    \vq_i \vR_{\abs{i-j}} \vk_j^\top,
\end{equation*}\]

&lt;p&gt;where $\vR_{\abs{i-j}}$ is a specific rotation matrix.&lt;sup id=&quot;fnref:3&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:3&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt; RoPE’s choice of this matrix has
the useful property that for fixed-norm keys and queries, the magnitude of the encoded
attention logits &lt;em&gt;decays&lt;/em&gt; as a function of $\abs{i-j}$, giving &lt;em&gt;a priori&lt;/em&gt; less weight to
logits of positions that are further apart.&lt;/p&gt;

&lt;p&gt;It also has the essential property that the encoded logits can be computed efficiently.
It turns out that there are (real-valued) sequences $\vc_i$, $\vs_i$ corresponding to each
position, such that the following holds. Let $\vPi \in \bbR^{H \times H}$ denote the
permutation matrix that reverses the top $H/2$ coordinates of a vector with the bottom
$H/2$ coordinates (assume $H$ is even).&lt;sup id=&quot;fnref:4&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:4&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;4&lt;/a&gt;&lt;/sup&gt; Then one has&lt;/p&gt;

\[\begin{equation*}
    \vq_i \vR_{\abs{i-j}} \vk_j^\top
    =
    (\vq_i \circ \vc_i + \vPi \vq_i \circ \vs_i)
    (\vk_j \circ \vc_j + \vPi \vk_j \circ \vs_j)^\top,
\end{equation*}\]

&lt;p&gt;where matrix multiplication binds before elementwise multiplication $\circ$. In
particular, we can compute the encoded logits via a small number of elementwise
multiplications, adds, and concatenations involving the raw keys and queries, which
leads to a minimal-complexity implementation. The sequences $\vc_i, \vs_i$ depend only
on position, and can be precomputed and reused.&lt;/p&gt;

&lt;p&gt;We give a simple implementation of RoPE below, and its integration into our attention
layer we implemented previously.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_precompute_rope_sincos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;freqs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;exp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rope_theta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cycles&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;freqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cycles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cycles&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_apply_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# S x H
&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;at&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# s: sequence length
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# d: embedding dim (config.d_model)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# n: attention heads (config.num_heads)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# h: head dim (config.d_head)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# x_seq: s x d
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sd,d3nh-&amp;gt;3snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Apply RoPE
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# ignore passed value
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;ensure_compile_time_eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;rope_cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_sin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_precompute_rope_sincos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_apply_rope_one_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;_apply_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_apply_rope_all_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_apply_rope_one_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;in_axes&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_axes&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_apply_rope_all_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_apply_rope_all_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Attention computation
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# broadcast over heads
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_fa&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot_product_attention&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,tnh-&amp;gt;nst&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Scale and causal mask
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;where&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# type: ignore[reportArgumentType]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nst,tnh-&amp;gt;snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,hnd-&amp;gt;sd&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The changes to the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; function are all within the code block marked with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;# Apply
RoPE&lt;/code&gt; (everything else is unchanged); the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_apply_rope&lt;/code&gt; function takes care of the
necessary permutation/concatenation operation. We use JAX’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ensure_compile_time_eval&lt;/code&gt;
context manager to ensure the RoPE multiplying sequences are precomputed. The
multiplying sequences themselves are chosen to have geometrically-varying time values,
with a rate parameter &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;rope_theta&lt;/code&gt; set to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;1e4&lt;/code&gt; by default (in general, set based on the
target context length, here 1024).&lt;/p&gt;

&lt;p&gt;With RoPE, we find it much easier to learn a strong model for our toy data.
After rerunning training with these modifications, we observe the following loss
progression:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
loss: 0.6326758861541748	step: 50
loss: 0.4087381660938263	step: 100
loss: 0.14004486799240112	step: 150
loss: 0.08325601369142532	step: 200
loss: 0.06802305579185486	step: 250
loss: 0.05945558100938797	step: 300
loss: 0.05321500450372696	step: 350
loss: 0.052829645574092865	step: 400
loss: 0.048602353781461716	step: 450
loss: 0.04696694016456604	step: 500
loss: 0.04576549679040909	step: 550
loss: 0.04564153775572777	step: 600
loss: 0.04341501370072365	step: 650
loss: 0.043622132390737534	step: 700
loss: 0.04205687716603279	step: 750
loss: 0.04401983320713043	step: 800
loss: 0.04256521165370941	step: 850
loss: 0.041673485189676285	step: 900
loss: 0.0421551875770092	step: 950
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Much better! For a completely overconfident model, we might expect to have loss around
$\log(2) / 256 \approx 0.002$, corresponding to randomly guessing for the first token in
each sequence of the batch (we train with sequence length &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;256&lt;/code&gt;). This is within about
an order of magnitude.&lt;/p&gt;

&lt;h1 id=&quot;building-the-model-sampling-with-a-cache&quot;&gt;Building the Model: Sampling with a Cache&lt;/h1&gt;

&lt;p&gt;The implementation of sampling we gave above is extremely slow: although we &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sample_one_token&lt;/code&gt; function, we are re-forwarding the entire updated sequence through
the model after every generation, which is very wasteful. A quick test on our
previously-implemented sampling function (on 1x TPU v4 host) takes about 50 seconds to
sample 63 tokens, including the time to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; compile the sampling function.&lt;/p&gt;

&lt;p&gt;This makes sense, given our naive implementation. Indeed, if we recall the attention
definition, at the $i$-th token each head computes&lt;/p&gt;

\[\begin{equation*}
    \mathrm{softmax}( \vq_i \vK^T ) \vV,
\end{equation*}\]

&lt;p&gt;and the output projection just accumulates these independent heads and ‘lifts’ the head
dimension $H$ to the output dimension $D$.
In other words, if we were to save the previous key and value tokens $(\vk_j,
\vv_j)_{j=1}^{i-1}$ in a cache, we could reuse these to compute each head’s attention
output with no wasted operations: we’d just need to compute $\vk_i$ and $\vv_i$ for the
current position, then load the previous positions’ keys and values from the cache to
compute the full attention output.&lt;sup id=&quot;fnref:7&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:7&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;5&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;The data structure that stores these past keys and values is called a &lt;em&gt;KV cache&lt;/em&gt;, and
it’s essential for fast inference, since autoregressive operation implies we can only
compute one token at a time (in contrast to
training, where we forward an entire batch of sequences at once).
A back-of-the-envelope analysis suggests that the naive sampling strategy (re-forward
the entire sequence on every new token generation) takes $O(S^2D^2 + DS^3)$ FLOPs to
generate an $S$-length sequence starting from an empty prompt, whereas the KV caching
strategy takes only $O(SD^2 + DS^2)$ FLOPs. This is a significant savings!&lt;/p&gt;

&lt;p&gt;The catch is that the KV caching strategy requires significantly more memory accesses
and usage compared to the naive approach. In JAX, we also need to be careful to
implement the KV cache ‘correctly’ – since the cache grows after every generation,
naive implementations can entail inputs of different sizes to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; (triggering a
re-compilation every time we call it, when it’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;ted!), or dynamically-sized arrays
(not allowed within a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; context!).&lt;sup id=&quot;fnref:6&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:6&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;6&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;Our simple implementation of the KV cache is below. Since we left the API for the cache
in the code we’ve written already, we just need to implement its functionality in the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; function, and add a function to initialize an empty cache. To play nice with
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;, we will implement a static-sized cache with a fixed size &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;config.max_seq_len&lt;/code&gt;
(the same parameter we used for RoPE), ensuring we can &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; function only
once, and maintain an auxilary &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cache_size&lt;/code&gt; variable to keep track of how much we’ve
written to the cache. To avoid having dynamically-sized arrays in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_attn&lt;/code&gt; function,
we are a bit wasteful and zero-initialize the cache, then just multiply the current
query $\vq_i$ against the entire cache on every attention operation.&lt;sup id=&quot;fnref:5&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:5&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;7&lt;/a&gt;&lt;/sup&gt; To avoid having
these extra keys influence the attention computation (since softmaxing them produces a
nonzero output!), we augment the attention mask to mask them out.&lt;/p&gt;

&lt;p&gt;The implementation is below.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Attn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 2 x config.max_seq_len x n x h (see below)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# s: sequence length
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# d: embedding dim (config.d_model)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# n: attention heads (config.num_heads)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# h: head dim (config.d_head)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# x_seq: s x d
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sd,d3nh-&amp;gt;3snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# we save k shape later, after possibly prepending cache
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Apply RoPE
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# ignore passed value
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;ensure_compile_time_eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;rope_cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_sin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_precompute_rope_sincos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_apply_rope_one_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;_apply_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;rope_sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_apply_rope_all_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;vmap&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_apply_rope_one_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;in_axes&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_axes&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_apply_rope_all_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_apply_rope_all_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Read/update/concatenate the cache
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dynamic_update_slice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dynamic_update_slice&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# ignore passed value
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Attention computation
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))[:,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# with static kv cache, must mask unused memory!
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# broadcast over heads
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_fa&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot_product_attention&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,tnh-&amp;gt;nst&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Scale and causal mask
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;where&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# type: ignore[reportArgumentType]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;nst,tnh-&amp;gt;snh&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;snh,hnd-&amp;gt;sd&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_cache_out&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Running this block and then rerunning our training script shows that our changes haven’t
broken anything. We can test how much it speeds up sampling, as well. We modify the
sampler above slightly – we enable the cache in the config, and forward one token per
sampling step to the model, rather than the entire output, as the cache will store all
the context.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;config_sampling_args&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;sharding_data&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;update_cache&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;config_sampling&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config_sampling_args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;init_kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tokens_to_generate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;63&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;step&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tokens_to_generate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;key_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sample_one_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;config_sampling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;train_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cache_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;concatenate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8,
       9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8],      dtype=int32)
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Correct outputs!
The runtime is about 2 seconds including the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt; compilation time on the first run,
and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.5&lt;/code&gt; seconds on subsequent runs (since the static cache and
one-token-input-per-sampling-step means the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;ted sampler sees the same input shapes
on every call). Quite a speedup!&lt;/p&gt;

&lt;h1 id=&quot;conclusions&quot;&gt;Conclusions&lt;/h1&gt;

&lt;p&gt;In this blog post, we’ve built up:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;A simple &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;table transformer model with RoPE;&lt;/li&gt;
  &lt;li&gt;A basic training loop with SGD and toy data (digit alphabet);&lt;/li&gt;
  &lt;li&gt;A minimal &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit&lt;/code&gt;table sampler with a static KV cache.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;All in bare-metal JAX.&lt;/p&gt;

&lt;p&gt;Since the length is already nontrivial, we’ll save further extensions – refactoring
into a python package, improving the config scaffolding for running experiments and
adding distributed training support, improving the optimizer, and training on actual
language data – to the next post.
I hope this exposition is helpful to learners of JAX and transformers – writing it has
been very helpful to me!&lt;/p&gt;

&lt;h1 id=&quot;accumulated-configs&quot;&gt;Accumulated Configs&lt;/h1&gt;

&lt;p&gt;For reference, here’s the implementation of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Config&lt;/code&gt; class we’ve used throughout the
discussion above.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# Architecture: config
&lt;/span&gt;&lt;span class=&quot;nd&quot;&gt;@jax.tree_util.register_static&lt;/span&gt;
&lt;span class=&quot;nd&quot;&gt;@dataclass&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kw_only&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;frozen&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Experiment orchestration params
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mesh_axis_names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;dp&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mesh_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;seed&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1337&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Data and training params
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;global_batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-2&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Model architecture params
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;num_vocab&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;d_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;768&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;12&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mlp_factor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;param_std&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.02&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;rope_theta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;10000.0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Model dtypes
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;param_dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bfloat16&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# weights, activations
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;compute_dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# layernorm, attn logits, rope
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;optimizer_dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# optimizer state
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Model call-time params
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;eps_ln&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-6&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;use_bias_ln&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;use_fa&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;use_bias_mlp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;use_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;update_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# default training
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Model sharding params
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;sharding_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;dp&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_wqkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_wo&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_wup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_wdown&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_mlp_hidden&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_res_stream&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sharding_att_qkv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PartitionSpec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__post_init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Set up and register mesh
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_axis_names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;nf&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AxisType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Explicit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Checks
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;d_head&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Head dimension needs to be divisible by 2 for RoPE&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h1 id=&quot;acknowledgments&quot;&gt;Acknowledgments&lt;/h1&gt;

&lt;p&gt;Thanks to the &lt;a href=&quot;https://sites.research.google/trc/about/&quot;&gt;TRC program&lt;/a&gt; for
compute support.&lt;/p&gt;

&lt;div class=&quot;footnotes&quot; role=&quot;doc-endnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Namely, because we won’t have a clean way to combine our model
parameter definitions with the actual application of the model. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;The causal mask actually &lt;em&gt;is&lt;/em&gt; enough to enable sequence-to-sequence mappings that
depend on position to be learned in multi-layer transformers. The recent trend of
“NoPE” (no positional encoding) in open-source large models is evidence of this. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:3&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;It’s actually a block-diagonal rotation matrix, with $2 \times 2$ blocks. This is
what leads to the efficient approach to computing its induced quadratic form below! &lt;a href=&quot;#fnref:3&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:4&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;In code, for a list &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v&lt;/code&gt;, this is just &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v[H//2:] + v[:H//2]&lt;/code&gt;. Moreover, the choice
of the permutation $\vPi$ is not essential (because the coordinates of the embedded
keys and queries follow a learnable linear transformation), and can be changed for
either mathematical clarity or code clarity (we choose based on the latter; the
original RoPE paper emphasized the former). &lt;a href=&quot;#fnref:4&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:7&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Using some sort of a cache turns out to be essential to avoid using as many FLOPs
as the naive strategy (calculated below). If we try to adaptively recompute the keys and
values necessary for generating the current token’s output, it can be seen that we
need the &lt;em&gt;previous queries&lt;/em&gt; in order to do so. Caching instead the past keys and
values avoids any recomputation. &lt;a href=&quot;#fnref:7&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:6&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Actually, our naive sampling implementation above, which re-forwards the entire
sequence after concatenating, runs so slow because the &lt;em&gt;input&lt;/em&gt; to the transformer
keeps growing in size – which triggers a recompilation after every generation!
Avoiding these pitfalls is a general challenge in JAX. &lt;a href=&quot;#fnref:6&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:5&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Although this &lt;em&gt;feels&lt;/em&gt; very wasteful, it actually only loses (asymptotically) a
factor of $2$ in FLOPs versus the optimal implementation. &lt;a href=&quot;#fnref:5&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Sam Buchanan</name>
          
            <email>sdbuchanan@berkeley.edu</email>
          
          
        </author>
      

      
        <category term="jax" />
      
        <category term="transformers" />
      

      

      
        <summary type="html">In this second post in a series on programming with JAX for training models like transformers, we write and train a transformer with a decent set of bells and whistles, then benchmark and scale it as much as we can on our TPU v4 half-cube!</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">SPMD in JAX #1: Sharding</title>
      
      
      <link href="http://sdbuchanan.com/blog/jax-1/" rel="alternate" type="text/html" title="SPMD in JAX #1: Sharding" />
      
      <published>2025-08-20T00:00:00+00:00</published>
      <updated>2025-08-20T00:00:00+00:00</updated>
      <id>http://sdbuchanan.com/blog/jax-1</id>
      <content type="html" xml:base="http://sdbuchanan.com/blog/jax-1/">&lt;p&gt;This post will be the first in a series on programming with JAX, for training
models like transformers. I’m experimenting with these as an alternative to my
usual scratch paper or LaTeX notes while learning, in the hope that it will help
me with recall and perhaps be useful to others learning this material.&lt;/p&gt;

&lt;p&gt;The focus in this post is on building a mental model (mathematical) for
sharding and communication based on linear algebra, in particular block
matrices, and on studying some low-level code for different communication
primitives.
The notes are based on two tutorials on sharding: &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;
and &lt;a class=&quot;citation&quot; href=&quot;#jax_sharded_computation&quot;&gt;(JAX Team, 2025)&lt;/a&gt;, as well as some ‘original research’.&lt;/p&gt;

&lt;ul id=&quot;markdown-toc&quot;&gt;
  &lt;li&gt;&lt;a href=&quot;#setup&quot; id=&quot;markdown-toc-setup&quot;&gt;Setup&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#sharding&quot; id=&quot;markdown-toc-sharding&quot;&gt;Sharding&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#valid-shardings&quot; id=&quot;markdown-toc-valid-shardings&quot;&gt;Valid Shardings&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#example-matrix-multiplication&quot; id=&quot;markdown-toc-example-matrix-multiplication&quot;&gt;Example: Matrix Multiplication&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#types-of-inter-device-communication-with-sharded-computation&quot; id=&quot;markdown-toc-types-of-inter-device-communication-with-sharded-computation&quot;&gt;Types of Inter-Device Communication with Sharded Computation&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#communication-primitives-through-matrix-multiplication&quot; id=&quot;markdown-toc-communication-primitives-through-matrix-multiplication&quot;&gt;Communication Primitives through Matrix Multiplication&lt;/a&gt;        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#case-1-unsharded-contracting-dimension-valid-output&quot; id=&quot;markdown-toc-case-1-unsharded-contracting-dimension-valid-output&quot;&gt;Case 1: Unsharded Contracting Dimension, Valid Output&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#case-2-one-multiplicand-sharded-along-contracting-dimension-allgather&quot; id=&quot;markdown-toc-case-2-one-multiplicand-sharded-along-contracting-dimension-allgather&quot;&gt;Case 2: One Multiplicand Sharded Along Contracting Dimension (AllGather)&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#case-3-both-multiplicands-sharded-along-contracting-dimension-allreduce&quot; id=&quot;markdown-toc-case-3-both-multiplicands-sharded-along-contracting-dimension-allreduce&quot;&gt;Case 3: Both Multiplicands Sharded Along Contracting Dimension (AllReduce)&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#case-4-sharded-non-contracting-dimensions-same-axis&quot; id=&quot;markdown-toc-case-4-sharded-non-contracting-dimensions-same-axis&quot;&gt;Case 4: Sharded Non-Contracting Dimensions, Same Axis&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#alltoall-communication&quot; id=&quot;markdown-toc-alltoall-communication&quot;&gt;AllToAll Communication&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#example-communication-primitives-in-jax&quot; id=&quot;markdown-toc-example-communication-primitives-in-jax&quot;&gt;Example: Communication Primitives in JAX&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#allgather&quot; id=&quot;markdown-toc-allgather&quot;&gt;AllGather&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#alltoall&quot; id=&quot;markdown-toc-alltoall&quot;&gt;AllToAll&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#allreduce-and-reducescatter&quot; id=&quot;markdown-toc-allreduce-and-reducescatter&quot;&gt;AllReduce and ReduceScatter&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#conclusion&quot; id=&quot;markdown-toc-conclusion&quot;&gt;Conclusion&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#acknowledgments&quot; id=&quot;markdown-toc-acknowledgments&quot;&gt;Acknowledgments&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;setup&quot;&gt;Setup&lt;/h3&gt;

&lt;p&gt;Here’s the environment we will be working in—using Python 3.13 and JAX 0.7.1,
on a single TPU v4 host.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;functools&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;partial&lt;/span&gt;

&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax.numpy&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;devices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h1 id=&quot;sharding&quot;&gt;Sharding&lt;/h1&gt;

&lt;p&gt;When we want to perform a high-throughput computation involving some data and
some mathematical operations (e.g., passing data through layers of a neural
network) using multiple hardware accelerators (e.g., TPUs), we need to, at
a minimum:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Have a scheme for indexing hardware accelerators relative to their physical
(spatial) layout.&lt;/li&gt;
  &lt;li&gt;Specify what data and what parameters (inputs) go on which accelerators.&lt;/li&gt;
  &lt;li&gt;Specify where the outputs of the computation should go.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;In JAX, these three steps are abstracted into creation of a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mesh&lt;/code&gt; specifying
device layout and a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PartitionSpec&lt;/code&gt; specifying how data is split across devices
relative to this mesh. Such a splitting is called a sharding.&lt;/p&gt;

&lt;p&gt;Here’s how the first two steps look in code:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1x v4 tpu host
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
┌──────────┬──────────┐
│          │          │
│  TPU 0   │  TPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  TPU 2   │  TPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Above,&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mesh&lt;/code&gt; indexes devices based on spatial axes. JAX will optimize this
indexing based on the actual device topology.&lt;/li&gt;
  &lt;li&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedSharding&lt;/code&gt; creates a sharding specification relative to the mesh,
which is defined by the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PartitionSpec&lt;/code&gt; we provide. Here, we specify that a 2D
array has its first dimension sharded along the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; dimension of the mesh and
its second dimension along &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;For the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PartitionSpec&lt;/code&gt;, we give &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt; for dimensions that should not be
sharded. That means this dimension will be present on all devices.&lt;/li&gt;
  &lt;li&gt;If a mesh axis does not appear in the sharding, then the sharded array is
fully replicated over that axis. In the extreme case, a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedSharding&lt;/code&gt; of all
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt;s puts the full array on every device.&lt;/li&gt;
  &lt;li&gt;The compiler will complain about invalid shardings.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;We can also move data around after creation on, say, one host.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;42&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;B_sharded&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;device_put&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘
┌──────────┬──────────┐
│          │          │
│  TPU 0   │  TPU 1   │
│          │          │
│          │          │
├──────────┼──────────┤
│          │          │
│  TPU 2   │  TPU 3   │
│          │          │
│          │          │
└──────────┴──────────┘
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In addition, we can ‘combine’ physical device axes into one or more aggregated
axes. This is a common operation, for example in parallelizing a batch axis
across all devices. Since the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mesh&lt;/code&gt; is our link between partition indices and
physical devices, we do this through the mesh:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;mesh_flat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sharding_flat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh_flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding_flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 3         │
└───────────────────────┘
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;When we perform computations on data, we can either let the JAX compiler infer
the output sharding, or we can specify it ourselves.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f_contract&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;nd&quot;&gt;@partial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;out_shardings&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f_contract_replicated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;result&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f_contract&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;result_replicated&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;f_contract_replicated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B_sharded&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Inferred sharding:&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;result&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Manual sharding (replicated):&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;visualize_array_sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;result_replicated&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Inferred sharding:
┌───────┬───────┐
│TPU 0,2│TPU 1,3│
└───────┴───────┘
Manual sharding (replicated):
┌───────────┐
│TPU 0,1,2,3│
└───────────┘
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Passing a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedSharding&lt;/code&gt; to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jax.jit&lt;/code&gt;’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;out_shardings&lt;/code&gt; keyword argument lets
us specify here that the result of the computation should be propagated to all
devices. This entails some communication!&lt;/p&gt;

&lt;p&gt;Given a valid input-output sharding for a more complex computation, the JAX
compiler can heuristically optimize the inner bits of the computation
(intermediate computations’ shardings, and necessary communication) to try to
make it run as efficiently as possible. In cases where this doesn’t work, JAX
has ways to hint at the compiler, as well as more advanced programming
paradigms, to explicitly specify intermediate sharding and communications.&lt;/p&gt;

&lt;h2 id=&quot;valid-shardings&quot;&gt;Valid Shardings&lt;/h2&gt;

&lt;p&gt;Given a mesh with $M$ axes, and an array with $N$ axes, we can enumerate all
valid shardings as follows. First note:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;A sharding corresponds to, for each of the $N$ array axes, the selection of
one of the $M$ mesh axes, or &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;None&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Each mesh axis can appear at most once in a valid sharding.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Then for each $k = 0, \dots, \min \set{M, N}$, pick $\binom{M}{k}$ of the $M$
mesh axes and $\binom{N}{k}$ of the $N$ array axes, then consider all possible
permutations of those $k$ selected axes. This is a total of&lt;/p&gt;

\[\sum_{k=0}^{\min \set{M, N}} \binom{M}{k} \binom{N}{k} k!\]

&lt;p&gt;possible shardings. This is actually not too large—roughly on the order of
$N^M$ when $N$ is large.&lt;sup id=&quot;fnref:1&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:1&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;1&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;h1 id=&quot;example-matrix-multiplication&quot;&gt;Example: Matrix Multiplication&lt;/h1&gt;

&lt;p&gt;The most common operations in neural networks are related to matrix
multiplications, and so thinking in terms of sharding data with two axes goes
a long way. Say we have matrices $\vA \in \bbR^{m \times n}$ and $\vB \in
\bbR^{n \times p}$. A sharding of a 2D array (matrix) corresponds to two things:
a partitioning of the matrix into sub-matrices, and a splitting of those
sub-matrices across devices. In JAX, these two operations are coupled through
the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mesh&lt;/code&gt;/&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NamedSharding&lt;/code&gt; interface, as we described above:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;If an array dimension is sharded across a mesh axis, that dimension is
partitioned, and the partitioned components will be split across that
mesh axis.&lt;/li&gt;
  &lt;li&gt;If an array dimension is not sharded, that dimension is not partitioned, and
it appears across all devices.&lt;/li&gt;
  &lt;li&gt;If a mesh axis does not appear in a sharding specification, the sharded data
is copied across that mesh axis.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Let’s consider the special case of a 2D mesh, in particular $2 \times 2$. For
simplicity, let’s place the first mesh axis in correspondence with the first
array dimension, and the second mesh axis in correspondence with the second
array dimension.&lt;sup id=&quot;fnref:2&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:2&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;2&lt;/a&gt;&lt;/sup&gt; Then every partitioning of $\vA$ or $\vB$, for example&lt;/p&gt;

\[\begin{equation}\label{eq:partitioning-0}
\begin{bmatrix}
  \vA_{00} &amp;amp; \vA_{01} \\
  \vA_{10} &amp;amp; \vA_{11}
\end{bmatrix},
\end{equation}\]

&lt;p&gt;corresponds to an arrangement of the matrix across devices. In the above example,
we have $\vA_{ij} \in \bbR^{(m/2) \times (n/2)}$, and the coordinates $(i, j)$
of the submatrix correspond to the coordinates of the device it is stored on.
For a different partitioning, for example&lt;/p&gt;

\[\begin{bmatrix}
  \vA_{0} \\
  \vA_{1}
\end{bmatrix},\]

&lt;p&gt;we need to take copying into account: here we have $\vA_{i} \in \bbR^{(m/2)
\times n}$, and the physical device layout is&lt;/p&gt;

\[\begin{equation}\label{eq:partitioning-1-dev}
\begin{bmatrix}
  \vA_{0} &amp;amp; \vA_{0} \\
  \vA_{1} &amp;amp; \vA_{1}
\end{bmatrix}.
\end{equation}\]

&lt;p&gt;Notice that $\eqref{eq:partitioning-0}$ and $\eqref{eq:partitioning-1-dev}$ are
incomparable: the second one uses twice as much memory! But there is a simple
‘algorithm’ to go from a partitioning to a device layout: if an axis is not
partitioned, simply copy the array along that axis to reach the full mesh size
(and remember that this leads to additional memory usage).&lt;/p&gt;

&lt;p&gt;Now, what happens if we want to compute the product $\vA \vB$ when $\vA$ is
sharded in some way, for example as in $\eqref{eq:partitioning-0}$?
First, thinking abstractly: by properties of block matrix multiplication, we
have&lt;/p&gt;

\[\vA \vB
=
\begin{bmatrix}
  \vA_{11} &amp;amp; \vA_{12} \\
  \vA_{21} &amp;amp; \vA_{22}
\end{bmatrix}
\begin{bmatrix}
  \vB_{11} &amp;amp; \vB_{12} \\
  \vB_{21} &amp;amp; \vB_{22}
\end{bmatrix}
=
\begin{bmatrix}
  \vA_{11}\vB_{11} + \vA_{12} \vB_{21} &amp;amp; \vA_{11}\vB_{12} + \vA_{12} \vB_{22} \\
  \vA_{21}\vB_{11} + \vA_{22} \vB_{21} &amp;amp; \vA_{21}\vB_{12} + \vA_{22} \vB_{22} \\
\end{bmatrix}.\]

&lt;p&gt;So no matter how $\vB$ is sharded, we need to perform communication such that we
compute each of the matrix products appearing above. We can think about two
cases:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;No communication involving $\vA$&lt;/strong&gt;: Here, we can use the RHS of the above
display to tell us what needs to be computed. Processor $(0, 0)$ needs the
first (block) row of $\vB$, processor $(0, 1)$ needs the second (block) row of
$\vB$, etc.—so the sharding of $\vB$ had better support this, or
communication will be necessary. &lt;em&gt;Regardless&lt;/em&gt;, some communication between
devices is necessary to accumulate the block matrix products: e.g., we need to
add results from processor $(0, 0)$ and $(0, 1)$ to get the top-left block of
the product.&lt;/li&gt;
  &lt;li&gt;
    &lt;p&gt;&lt;strong&gt;We can re-structure how $\vA$ is sharded&lt;/strong&gt;: It might be possible, given
a specific sharding for $\vB$, to have a more efficient computation by
sharding $\vA$ differently (and vice versa). Thinking in terms of
block partitions, this is the same as observing that there are many ways of
partitioning a matrix into blocks, each of which gives a different
decomposition for the matrix product. For example, this is also a compatible
partitioning (with suitable definitions of the blocks):&lt;/p&gt;

\[\begin{equation}\label{eq:nice-sharding}
\vA \vB
=
\begin{bmatrix}
  \vA_{1} \\
  \vA_{2}
\end{bmatrix}
\begin{bmatrix}
  \vB_{1} &amp;amp; \vB_{2}
\end{bmatrix}
=
\begin{bmatrix}
  \vA_{1}\vB_{1} &amp;amp; \vA_{1}\vB_{2} \\
  \vA_{2}\vB_{1} &amp;amp; \vA_{2}\vB_{2}
\end{bmatrix},
\end{equation}\]

    &lt;p&gt;and so is&lt;/p&gt;

\[\vA \vB
=
\begin{bmatrix}
  \vA_{1} &amp;amp; \vA_{2}
\end{bmatrix}
\begin{bmatrix}
  \vB_{1} \\
  \vB_{2}
\end{bmatrix}
=
\vA_1 \vB_1 + \vA_2 \vB_2.\]

    &lt;p&gt;Re-sharding requires communication, though, and so this may or may not be more
efficient.&lt;/p&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;As an exercise, verify that $\eqref{eq:nice-sharding}$ corresponds to
a zero-communication sharding for a $2 \times 2$ output sharding.&lt;sup id=&quot;fnref:4&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:4&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;3&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

&lt;p&gt;Now, if $\vA$ and $\vB$ both have their own sharding, in general these
correspond to (possibly incompatible) block matrix partitions of each matrix.
For each individual partition, we can, as above, write out the blockwise matrix products
and accumulation operations needed to compute the product, and try to intuit
what computation/communication should happen—in particular involving
re-sharding one or both arrays. In JAX, the XLA compiler will be recruited to
perform this optimization; as algorithm designers, we can try to provide
sensible shardings up front in order to make sure we end up with an efficient
system.&lt;/p&gt;

&lt;h1 id=&quot;types-of-inter-device-communication-with-sharded-computation&quot;&gt;Types of Inter-Device Communication with Sharded Computation&lt;/h1&gt;

&lt;p&gt;The matrix multiplication example suggests that different forms of communication
naturally arise in distributed computation. We will describe the ones that are
implemented in XLA here. The JAX scaling book &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt; gives
a nice way to think about matrix multiplication of sharded arrays using
a variant of named-axis notation. Since this is closer to code, we’ll describe
this briefly below.&lt;/p&gt;

&lt;h2 id=&quot;communication-primitives-through-matrix-multiplication&quot;&gt;Communication Primitives through Matrix Multiplication&lt;/h2&gt;

&lt;p&gt;We saw in the previous section how to think about sharding of 2D arrays in terms
of block matrices:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Sharded array axes are partitioned; non-sharded array axes aren’t.&lt;/li&gt;
  &lt;li&gt;Any mesh axis that isn’t used means the partitioned array is copied across
that axis.
This gives us a scheme to map a block matrix to the corresponding arrangement of
that matrix across physical devices.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;With named index notation, things are more compact. Consider a 2D mesh with axes
$x$ and $y$, and let us write $A[I, J]$ for the 2D array $\vA$ (likewise for
$\vB$) we looked at in the previous section. We write $A[I_x, J]$ (etc.) to
denote the sharding of $A$ along its first dimension with respect to the mesh
axis $x$. The use of $I$, $J$, etc.\ for ‘dummy’ indices is in analogy to
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;einsum&lt;/code&gt;-type notation, which is very useful for tensor operations.&lt;/p&gt;

&lt;p&gt;Now consider a matrix product of sharded matrices $A[I_x, J] \cdot B[J, K_y]$.
Suppose we want the output, say $\vC$, to be sharded as $C[I_x, K_y]$. Then this
matrix product can be computed &lt;em&gt;with no communication&lt;/em&gt;: to see why, think in
terms of the correspondence to block matrix partitions and their associated map
to device arrangements in the previous section. However, for an unsharded output
$C[I, K]$, it would be necessary to perform some communication—either
calculating $C[I_x, K_y]$ and then sharing the local results among all devices,
or something else.&lt;/p&gt;

&lt;p&gt;The following is a convenient way to think about this, following the previous
example:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Given a sharding of matrix multiplicands $A$, $B$, there is a ‘natural’ output
sharding whenever the contracting dimension is unsharded. (As above, this
comes by ‘removing’ the contracting axis, just like in einsum notation.) If
this sharding is invalid, communication is necessary.&lt;/li&gt;
  &lt;li&gt;If the contracting dimension is sharded, different types of communication are
necessary to turn the multiplication into one with an unsharded contracting
dimension. We can identify typical ‘best practices’ for what communication
type to use based on how the contracting dimension is sharded.&lt;/li&gt;
  &lt;li&gt;If an output sharding is specified and it is not the ‘natural’ one, some
communication is required to achieve it. We can use the techniques from either
of the two previous bullets to do this.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;We’ll describe these different cases below.&lt;/p&gt;

&lt;h3 id=&quot;case-1-unsharded-contracting-dimension-valid-output&quot;&gt;Case 1: Unsharded Contracting Dimension, Valid Output&lt;/h3&gt;

&lt;p&gt;As above, the ‘natural’ output sharding is just an einsum-style replacement:&lt;/p&gt;

\[A[I_x, J] \cdot B[J, K] \to C[I_x, K].\]

&lt;p&gt;If this natural output sharding is invalid, we need to do something else (see
below).&lt;/p&gt;

&lt;h3 id=&quot;case-2-one-multiplicand-sharded-along-contracting-dimension-allgather&quot;&gt;Case 2: One Multiplicand Sharded Along Contracting Dimension (AllGather)&lt;/h3&gt;

&lt;p&gt;For a product like&lt;/p&gt;

\[A[I, J_x] \cdot B[J, K] \to C[I, K],\]

&lt;p&gt;it isn’t possible to directly perform the multiplication, because in order to
contract, we either need to perform a series of block-matrix multiplies, then
accumulate the results (see below), or have the entire contracting dimension of
both arrays available on all devices (see above).&lt;/p&gt;

&lt;p&gt;We use &lt;strong&gt;AllGather&lt;/strong&gt; for this, which removes a specified sharded axis:&lt;/p&gt;

\[\mathrm{AllGather}_{x}(A[I, J_x]) \to A[I, J].\]

&lt;p&gt;For a compound mesh, AllGather can also just remove one sub-axis:&lt;/p&gt;

\[\mathrm{AllGather}_{x}(A[I, J_{xy}]) \to A[I, J_y].\]

&lt;p&gt;AllGather can also be used to satisfy certain output shardings that aren’t
naturally compatible with the input shardings. For example, in the product&lt;/p&gt;

\[A[I_x, J] \cdot B[J, K] \to C[I, K],\]

&lt;p&gt;one removes the sharded axis either before or after the computation in order to
achieve the desired output sharding.&lt;/p&gt;

&lt;h4 id=&quot;allgather-cost&quot;&gt;AllGather Cost&lt;/h4&gt;

&lt;p&gt;&lt;strong&gt;On TPUs&lt;/strong&gt; with enough chips to form a cube, AllGather can be performed using
algorithms that exploit the toroidal interconnectivity of TPU devices. For
example, to AllGather a single axis, the following procedure can be used:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Initialize a local buffer on each device with its local shard.&lt;/li&gt;
  &lt;li&gt;Send the buffer contents to the next device; overwrite it with the contents of
  the previous device.&lt;/li&gt;
  &lt;li&gt;Store the received shard in memory.&lt;/li&gt;
  &lt;li&gt;Repeat.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;This takes at most $N/2$ rounds of communication if the mesh axis has size $N$.
For an array of size $V$ bytes, each round sends $V/N$ bytes over each link.
Given an ICI bandwidth of $W$ (unidirectional),&lt;sup id=&quot;fnref:6&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:6&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;4&lt;/a&gt;&lt;/sup&gt; the total time this takes is
no more than&lt;/p&gt;

\[\frac{N}{2} \cdot \frac{V}{NW} = \frac{V}{2W} \text{ s},\]

&lt;p&gt;which is independent of the size of the mesh.
However, one must also take into account the intrinsic overhead of ICI
communication. Each operation takes around $T_{\min} = 10^{-6} \text{ s}$ &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;, so the total time is actually only bounded by&lt;/p&gt;

\[\max \set*{\frac{V}{2W}, \frac{NT_{\min}}{2}}.\]

&lt;p&gt;In particular, there are &lt;em&gt;latency-bound&lt;/em&gt; (small array size) and
&lt;em&gt;throughput-bound&lt;/em&gt; (large array size) regimes of communication.&lt;/p&gt;

&lt;p&gt;AllGathering a compound axis (e.g., $xy$) is similar: one can utilize some
per-device state to efficiently communicate data across the entire
multidimensional mesh. Since each device can then communicate along each mesh
axis simultaneously, the bandwidth increases by a factor proportional to the
number of mesh axes, but the maximum possible latency also increases, since
there are more devices. The total time is no larger than&lt;/p&gt;

\[\max \set*{\frac{V}{2n_{\mathrm{axes}}W}, \frac{N_{\mathrm{total}}T_{\min}}{2}},\]

&lt;p&gt;where $N_{\mathrm{total}}$ is the product of the axes lengths.&lt;/p&gt;

&lt;h3 id=&quot;case-3-both-multiplicands-sharded-along-contracting-dimension-allreduce&quot;&gt;Case 3: Both Multiplicands Sharded Along Contracting Dimension (AllReduce)&lt;/h3&gt;

&lt;p&gt;In this case, we want to compute a matrix product like&lt;/p&gt;

\[A[I, J_x] \cdot B[J_x, K] \to C[I, K].\]

&lt;p&gt;Thinking in terms of block matrices, we know that computing this matrix product
entails a sum of the block matrix multiplications corresponding to the shards of
both $A$ and $B$, and since the contracting axis is sharded, we can:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Compute per-device multiplications of the sharded matrices.&lt;/li&gt;
  &lt;li&gt;Add these up.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The communication primitive that performs this reduction is called
&lt;strong&gt;AllReduce&lt;/strong&gt;. It is similar to AllGather, but it sums up the components being
communicated, and distributes the resulting sum to all devices.&lt;/p&gt;

&lt;p&gt;We will use the notation from &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;: we write&lt;/p&gt;

\[C[I, K]\{ U_x \}\]

&lt;p&gt;to denote an array $C$ which is “unreduced” over the mesh axis $x$, i.e. the
partial products are sharded over the $x$ mesh axis, which gives&lt;/p&gt;

\[\mathrm{AllReduce}_y\left( A[I_x, J]\{ U_y \} \right) \to A[I_x, J]\]

&lt;p&gt;as the signature for AllReduce.&lt;/p&gt;

&lt;h4 id=&quot;decomposition-in-terms-of-reducescatter&quot;&gt;Decomposition in Terms of ReduceScatter&lt;/h4&gt;

&lt;p&gt;Generally, an AllReduce is about &lt;em&gt;two times as expensive as an AllGather&lt;/em&gt;. One way
to see why is to decompose it into another useful communication primitive: the
&lt;strong&gt;ReduceScatter&lt;/strong&gt; operation.&lt;/p&gt;

&lt;p&gt;ReduceScatter has signature&lt;/p&gt;

\[\mathrm{ReduceScatter}_{y, J}\left( A[I_x, J]\{ U_y \} \right) \to A[I_x, J_y].\]

&lt;p&gt;It takes an array unreduced along a mesh axis, then shards the array along
a specified array axis with respect to that mesh axis. To compute an AllReduce,
we can compose a ReduceScatter with an AllGather:&lt;/p&gt;

\[\mathrm{AllReduce}_y\left( A[I_x, J]\{ U_y \} \right)
=
\mathrm{AllGather}_{y}\left(
\mathrm{ReduceScatter}_{y, J}\left( A[I_x, J]\{ U_y \} \right)
\right).\]

&lt;p&gt;For multidimensional arrays, there is a possibility to choose the array axis
that is ReduceScattered to optimize performance.&lt;/p&gt;

&lt;p&gt;ReduceScatter can be performed by an algorithm similar to the approach we
discussed for AllGather. The process is somewhat difficult to describe
algorithmically, but there is a very good pictorial representation in &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/assets/blog/reduce-scatter.gif&quot; alt=&quot;Reduce Scatter&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The upshot is that this similarity persists to the latency and
throughput analysis of the algorithm, making it identical to that of AllGather.
This demonstrates the cost of AllReduce as twice that of AllGather.&lt;/p&gt;

&lt;h4 id=&quot;aside-the-backward-pass&quot;&gt;Aside: The Backward Pass&lt;/h4&gt;

&lt;p&gt;There is a more fundamental relationship between the AllGather and ReduceScatter
operations, suggested by the representation of AllReduce above: they can be
represented as adjoints of one another, which means that an AllGather in the
forward pass will trigger a ReduceScatter in the backward pass!&lt;/p&gt;

&lt;p&gt;To see this, let’s go back to our more visual block matrix example. As above,
consider a simplified setting where a matrix is sharded along a 1D mesh
corresponding to the first array dimension. With mesh size $3$ for simplicity,
we represent this as the block matrix&lt;/p&gt;

\[\vA = \begin{bmatrix}
\vA_0 \\
\vA_1 \\
\vA_2
\end{bmatrix}.\]

&lt;p&gt;Remember that the partitioning of the block matrix corresponds to the device
layout in this model.
Now, if we perform an AllGather, we have the ($3\times$ larger) block matrix
output&lt;/p&gt;

\[\mathrm{AllGather}(\vA) = \begin{bmatrix}
\vA \\
\vA \\
\vA
\end{bmatrix}
=
\begin{bmatrix}
\begin{bmatrix}
\vA_0 \\
\vA_1 \\
\vA_2
\end{bmatrix}
\\
\begin{bmatrix}
\vA_0 \\
\vA_1 \\
\vA_2
\end{bmatrix}
\\
\begin{bmatrix}
\vA_0 \\
\vA_1 \\
\vA_2
\end{bmatrix}
\end{bmatrix}.\]

&lt;p&gt;Let’s imagine we perform some subsequent computations on the AllGathered output,
and we want to compute a backward pass. In general, each individual device will
have different gradients, which we can write in partitioned form as&lt;/p&gt;

\[\begin{bmatrix}
\Delta \vA^0 \\
\Delta \vA^1 \\
\Delta \vA^2
\end{bmatrix}
=
\begin{bmatrix}
\begin{bmatrix}
\Delta \vA_0^0 \\
\Delta \vA_1^0 \\
\Delta \vA_2^0
\end{bmatrix}
\\
\begin{bmatrix}
\Delta \vA_0^1 \\
\Delta \vA_1^1 \\
\Delta \vA_2^1
\end{bmatrix}
\\
\begin{bmatrix}
\Delta \vA_0^2 \\
\Delta \vA_1^2 \\
\Delta \vA_2^2
\end{bmatrix}
\end{bmatrix}.\]

&lt;p&gt;Then because AllGather behaves in the sense of linear algebra as a copying
operation, its backward pass involves its adjoint operation, which is a sum.
We end up with&lt;sup id=&quot;fnref:7&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:7&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;5&lt;/a&gt;&lt;/sup&gt;&lt;/p&gt;

\[\mathrm{dAllGather}^* \left(
\begin{bmatrix}
\Delta \vA^0 \\
\Delta \vA^1 \\
\Delta \vA^2
\end{bmatrix}
\right)
=
\begin{bmatrix}
\sum_{i=1}^3\Delta \vA^i_0 \\
\sum_{i=1}^3 \Delta \vA^i_1 \\
\sum_{i=1}^3 \Delta \vA^i_2
\end{bmatrix}.\]

&lt;p&gt;This is exactly what we’d get if we:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;Interpret the per-device gradients as unreduced partial sums with respect to
the mesh axis of the AllGather;&lt;/li&gt;
  &lt;li&gt;Perform a ReduceScatter with respect to this mesh axis and the array axis
that was sharded before the AllGather.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;So, an AllGather in the forward pass produces a corresponding ReduceScatter in
the backward pass!&lt;/p&gt;

&lt;p&gt;Similarly, if there is a ReduceScatter in the forward pass that produces an
array sharded along one of its axes, then in the backward pass, there will be an
AllGather operation along that axis in the backward pass.
And finally, because an AllReduce can be expressed as the composition of
an AllGather with a ReduceScatter, &lt;em&gt;the backward pass of an AllReduce is an
AllReduce&lt;/em&gt;!&lt;/p&gt;

&lt;h3 id=&quot;case-4-sharded-non-contracting-dimensions-same-axis&quot;&gt;Case 4: Sharded Non-Contracting Dimensions, Same Axis&lt;/h3&gt;

&lt;p&gt;A direct reduction to the natural sharding for this case would lead to an
invalid sharding:&lt;/p&gt;

\[A[I_x, J] \cdot B[J, K_x] \to C[I_x, K_x]\]

&lt;p&gt;A mesh axis cannot appear multiple times in a sharding specification.&lt;/p&gt;

&lt;p&gt;To resolve this, we just AllGather one of the axes in the multiplicands:&lt;/p&gt;

\[\begin{split}
\mathrm{AllGather}_x(A[I_x, J]) \to A[I, J] \\
A[I, J] \cdot B[J, K_x] \to C[I, K_x].
\end{split}\]

&lt;p&gt;We can select which depending on the downstream shardings needed.&lt;/p&gt;

&lt;h2 id=&quot;alltoall-communication&quot;&gt;AllToAll Communication&lt;/h2&gt;

&lt;p&gt;This primitive can be thought of as a resharding operation. It has signature&lt;/p&gt;

\[\mathrm{AllToAll}_{x, J}( A[I_x, J] ) \to A[I, J_x].\]

&lt;p&gt;AllToAll is actually significantly cheaper than an AllGather, by a factor of 4.
This is another one where the excellent picture from &lt;a class=&quot;citation&quot; href=&quot;#scaling-book&quot;&gt;(Austin et al., 2025)&lt;/a&gt;
describes where this efficiency is coming from well.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;/assets/blog/all-to-all.gif&quot; alt=&quot;All To All&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Notice that the message
that is being sent over each link is composed of pieces of different lengths
from each shard. The communication is balanced, but later messages are shorter,
which may cause issues with latency-boundedness of the communication.&lt;/p&gt;

&lt;h1 id=&quot;example-communication-primitives-in-jax&quot;&gt;Example: Communication Primitives in JAX&lt;/h1&gt;

&lt;p&gt;Let’s test out what we learned above in the JAX sandbox we started this post
with. For slightly more granular control of sharding, we’ll switch to Explicit
mode sharding (see e.g. &lt;a class=&quot;citation&quot; href=&quot;#jax-explicit-sharding&quot;&gt;(JAX Team, 2025)&lt;/a&gt;).&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;make_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis_types&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AxisType&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Explicit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 1x v4 tpu host
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;set_mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2048&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8192&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bfloat16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
ShapedArray(bfloat16[2048@x,8192@y])
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;allgather&quot;&gt;AllGather&lt;/h2&gt;

&lt;p&gt;First, a simple AllGather implementation along the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt; mesh axis—we compile
the resharding operation and inspect the “high-level operations” (HLO) to see
what actually gets implemented on the device.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;all_gather&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;reshard&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;all_gather&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Output type: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;StableHLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;XLA HLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Output type:  bfloat16[2048@x,8192]

StableHLO:
 module @jit_all_gather attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = &amp;lt;[&quot;x&quot;=2, &quot;y&quot;=2]&amp;gt;
  func.func public @main(%arg0: tensor&amp;lt;2048x8192xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {&quot;y&quot;}]&amp;gt;}) -&amp;gt; (tensor&amp;lt;2048x8192xbf16&amp;gt; {jax.result_info = &quot;result&quot;, sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt;}) {
    %0 = sdy.sharding_constraint %arg0 &amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt; : tensor&amp;lt;2048x8192xbf16&amp;gt;
    return %0 : tensor&amp;lt;2048x8192xbf16&amp;gt;
  }
}

XLA HLO:
 HloModule jit_all_gather, is_scheduled=true, entry_computation_layout={(bf16[1024,4096]{1,0:T(8,128)(2,1)})-&amp;gt;bf16[1024,8192]{1,0:T(8,128)(2,1)}}, num_partitions=4
ENTRY %main.5_spmd (param: bf16[1024,4096]) -&amp;gt; bf16[1024,8192] {
  %param = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,2]&amp;lt;=[4]}, metadata={op_name=&quot;x&quot;}
  %all-gather = bf16[1024,8192]{1,0:T(8,128)(2,1)S(3)} all-gather(%param), channel_id=1, replica_groups=[2,2]&amp;lt;=[4], dimensions={1}, use_global_device_ids=true, metadata={op_name=&quot;jit(all_gather)/reshard&quot; source_file=&quot;/tmp/ipykernel_794398/1925679859.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;barrier_config&quot;:{&quot;barrier_type&quot;:&quot;CUSTOM&quot;,&quot;id&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[],&quot;collective_algorithm_config&quot;:{&quot;emitter&quot;:&quot;1DAllGatherNonMajorDim&quot;,&quot;debug&quot;:&quot;\ngroup_size = 2 \nhas_reordering_map: false \nper_stride_size = 65536 bytes \nshard_size = 8388608 bytes &quot;},&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;0&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
  ROOT %copy.3 = bf16[1024,8192]{1,0:T(8,128)(2,1)} copy(%all-gather), backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[],&quot;output_window_bounds&quot;:[&quot;16&quot;,&quot;64&quot;],&quot;input_window_bounds&quot;:[],&quot;estimated_cycles&quot;:&quot;58232&quot;,&quot;iteration_bounds&quot;:[&quot;8&quot;,&quot;1&quot;]},&quot;megacore_config&quot;:{&quot;megacore_split_dim&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;8388608&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The first representation above (output of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;lower&lt;/code&gt;) is in XLA’s StableHLO
language. It is slightly imposing, but there is a clearly written
&lt;a href=&quot;https://github.com/openxla/stablehlo/blob/main/docs/spec.md&quot;&gt;spec&lt;/a&gt; that helps
with parsing it after a bit of studying. It’s still too high level to clearly
show what communication primitives are being generated, though.&lt;/p&gt;

&lt;p&gt;The second representation (output of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compile&lt;/code&gt;) is the result of the XLA
compiler’s first device-independent compilation stage; it’s in its own dialect
of HLO, which is different from StableHLO and doesn’t have as easily-accessible
of a public spec.&lt;sup id=&quot;fnref:8&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:8&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;6&lt;/a&gt;&lt;/sup&gt;
Here’s how to parse it above:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Types are relatively self-explanatory, aside from the text in braces—the
first sequence of integers here is most important, and it specifies the memory
layout of the array (see &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;layout.h&lt;/code&gt; in the XLA source). Axis indices are
listed in minor-to-major order (i.e., &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;1, 0&lt;/code&gt; denotes row-major layout).&lt;/li&gt;
  &lt;li&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;all-gather&lt;/code&gt; semantics are clear at a high level: we gather along
dimension $1$ of our array. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;replica_groups&lt;/code&gt; argument specifies which
devices communicate in the gather (see &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tile_assignment.h&lt;/code&gt; in the XLA
source)—compare to our input &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;%param&lt;/code&gt;’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sharding&lt;/code&gt; specified above to see
that these agree.&lt;/li&gt;
  &lt;li&gt;Independent of these lower-level details, we can see that the input and output
shapes are as we expect.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;alltoall&quot;&gt;AllToAll&lt;/h2&gt;

&lt;p&gt;We can test out AllToAll with a similar high-level API. Let’s start by
transposing the sharding axis of the output of the previous AllGather.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;all_to_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;reshard&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;all_to_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;B&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Output type: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;StableHLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;XLA HLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;The HLO that we get is more complex:&lt;/p&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Output type:  bfloat16[2048,8192@x]
StableHLO:
 module @jit_all_to_all attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = &amp;lt;[&quot;x&quot;=2, &quot;y&quot;=2]&amp;gt;
  func.func public @main(%arg0: tensor&amp;lt;2048x8192xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt;}) -&amp;gt; (tensor&amp;lt;2048x8192xbf16&amp;gt; {jax.result_info = &quot;result&quot;, sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{}, {&quot;x&quot;}]&amp;gt;}) {
    %0 = sdy.sharding_constraint %arg0 &amp;lt;@mesh, [{}, {&quot;x&quot;}]&amp;gt; : tensor&amp;lt;2048x8192xbf16&amp;gt;
    return %0 : tensor&amp;lt;2048x8192xbf16&amp;gt;
  }
}
XLA HLO:
 HloModule jit_all_to_all, is_scheduled=true, entry_computation_layout={(bf16[1024,8192]{1,0:T(8,128)(2,1)})-&amp;gt;bf16[2048,4096]{1,0:T(8,128)(2,1)}}, num_partitions=4
%fused_computation (param_0.1: bf16[128,2,8,4096]) -&amp;gt; bf16[1024,2,4096] {
  %param_0.1 = bf16[128,2,8,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0)
  %copy.4 = bf16[128,2,8,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1), metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/2971478236.py&quot; source_line=3}
  ROOT %bitcast.5 = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} bitcast(%copy.4), metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/2971478236.py&quot; source_line=3}
}
ENTRY %main.5_spmd (param: bf16[1024,8192]) -&amp;gt; bf16[2048,4096] {
  %param = bf16[1024,8192]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,1,2]&amp;lt;=[4] last_tile_dim_replicate}, metadata={op_name=&quot;x&quot;}
  %bitcast.7 = bf16[128,2,8,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%param)
  %copy_bitcast_fusion = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} fusion(%bitcast.7), kind=kLoop, calls=%fused_computation, metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/2971478236.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[],&quot;output_window_bounds&quot;:[&quot;1&quot;,&quot;32&quot;,&quot;32&quot;],&quot;input_window_bounds&quot;:[],&quot;estimated_cycles&quot;:&quot;78212&quot;,&quot;iteration_bounds&quot;:[&quot;2&quot;,&quot;4&quot;,&quot;1&quot;]},&quot;megacore_config&quot;:{&quot;megacore_split_dim&quot;:&quot;1&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;8388608&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
  %all-to-all = bf16[1024,2,4096]{2,0,1:T(8,128)(2,1)S(3)} all-to-all(%copy_bitcast_fusion), channel_id=1, replica_groups=[2,2]&amp;lt;=[2,2]T(1,0), dimensions={1}, metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/2971478236.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;barrier_config&quot;:{&quot;barrier_type&quot;:&quot;GLOBAL&quot;,&quot;id&quot;:&quot;-1&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;262144&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
  %bitcast.6 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} bitcast(%all-to-all)
  ROOT %copy.6 = bf16[2048,4096]{1,0:T(8,128)(2,1)} copy(%bitcast.6), backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[],&quot;output_window_bounds&quot;:[&quot;32&quot;,&quot;32&quot;],&quot;input_window_bounds&quot;:[],&quot;estimated_cycles&quot;:&quot;58232&quot;,&quot;iteration_bounds&quot;:[&quot;8&quot;,&quot;1&quot;]},&quot;megacore_config&quot;:{&quot;megacore_split_dim&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;8388608&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Some amount of the above amounts to memory optimizations that are necessary to
actually implement the communication in physical memory.&lt;sup id=&quot;fnref:9&quot; role=&quot;doc-noteref&quot;&gt;&lt;a href=&quot;#fn:9&quot; class=&quot;footnote&quot; rel=&quot;footnote&quot;&gt;7&lt;/a&gt;&lt;/sup&gt; We can more easily
interpret what’s going on by thinking in terms of block matrices again:
recall that we started with a $2048 \times 8192$ array $\vA$, sharded two ways:&lt;/p&gt;

\[\vA = \begin{bmatrix}
\vA_{00} &amp;amp; \vA_{01} \\
\vA_{10} &amp;amp; \vA_{11}
\end{bmatrix},\]

&lt;p&gt;and we performed an AllGather along the second axis, leading to (we’ll use
$\oplus_i$ for concatenation along dimension $i$)&lt;/p&gt;

\[\vB = \begin{bmatrix}
\vA_{00} \oplus_1 \vA_{01} &amp;amp; \vA_{00} \oplus_1 \vA_{01} \\
\vA_{10} \oplus_1 \vA_{11} &amp;amp; \vA_{10} \oplus_1 \vA_{11}
\end{bmatrix}.\]

&lt;p&gt;The output of the AllToAll reshards along the trailing axis, &lt;em&gt;but with respect
to the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'x'&lt;/code&gt; mesh axis&lt;/em&gt;, which gives&lt;/p&gt;

\[\vC = \begin{bmatrix}
\vA_{00} \oplus_0 \vA_{10} &amp;amp; \vA_{00} \oplus_0 \vA_{10}  \\
\vA_{01} \oplus_0 \vA_{11} &amp;amp; \vA_{01} \oplus_0 \vA_{11}
\end{bmatrix}.\]

&lt;p&gt;To go from $\vB$ to $\vC$, we can do the following, which is essentially what
the HLO above is doing (modulo memory operations): unbind the concatenated
arrays, to give (in a rough approximation of tensor notation)&lt;/p&gt;

\[\begin{bmatrix}
(\vA_{00}, \vA_{01}) &amp;amp; (\vA_{00}, \vA_{01}) \\
(\vA_{10}, \vA_{11}) &amp;amp; (\vA_{10}, \vA_{11})
\end{bmatrix},\]

&lt;p&gt;then communicate ‘vertically’, swapping $01$ and $10$ on the left and right, to
get&lt;/p&gt;

\[\begin{bmatrix}
(\vA_{00}, \vA_{10}) &amp;amp; (\vA_{00}, \vA_{10}) \\
(\vA_{01}, \vA_{11}) &amp;amp; (\vA_{01}, \vA_{11})
\end{bmatrix}.\]

&lt;p&gt;If we concatenate these in the right order, we end up with $\vC$!
We can see that this is what the HLO is doing:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;After some memory manipulations, an extra axis of size $2$ is split off from
the second dimension, making the local array shaped like $1024 \times 2 \times
4096$.&lt;/li&gt;
  &lt;li&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;all-to-all&lt;/code&gt; function has to be interpreted relative to the input
sharding. Here &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;{devices=[2,1,2]&amp;lt;=[4] last_tile_dim_replicate}&lt;/code&gt; denotes a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[2,
1]&lt;/code&gt; tile size, replicated twice; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;&amp;lt;=[4]&lt;/code&gt; is equivalent to a device list &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;{0,
1, 2, 3}&lt;/code&gt;. This is as we expect—the input is sharded over &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'x'&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;Now, in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;all-to-all&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;replica_groups=[2,2]&amp;lt;=[2,2]T(1,0)&lt;/code&gt; specifies how the
communication is done across devices (relative to the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sharding&lt;/code&gt; previously
specified). This corresponds to a specification of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;{{0, 2}, {1,
3}}&lt;/code&gt; (see &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;hlo_sharding.h&lt;/code&gt; in the XLA source), which means that
each replica’s array gets split into two parts along dimension $1$ (the axis
of length $2$), then communication occurs across the specified groups. This
gets us what we wanted above.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;It’s slightly interesting to note that different input sharding and an AllToAll
leads to different memory operations in the HLO:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;all_to_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;reshard&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;all_to_all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Output type: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;StableHLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;XLA HLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Output type:  bfloat16[2048@x,8192]
StableHLO:
 module @jit_all_to_all attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = &amp;lt;[&quot;x&quot;=2, &quot;y&quot;=2]&amp;gt;
  func.func public @main(%arg0: tensor&amp;lt;2048x8192xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{}, {&quot;x&quot;}]&amp;gt;}) -&amp;gt; (tensor&amp;lt;2048x8192xbf16&amp;gt; {jax.result_info = &quot;result&quot;, sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt;}) {
    %0 = sdy.sharding_constraint %arg0 &amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt; : tensor&amp;lt;2048x8192xbf16&amp;gt;
    return %0 : tensor&amp;lt;2048x8192xbf16&amp;gt;
  }
}
XLA HLO:
 HloModule jit_all_to_all, is_scheduled=true, entry_computation_layout={(bf16[2048,4096]{1,0:T(8,128)(2,1)})-&amp;gt;bf16[1024,8192]{1,0:T(8,128)(2,1)}}, num_partitions=4
ENTRY %main.5_spmd (param: bf16[2048,4096]) -&amp;gt; bf16[1024,8192] {
  %param = bf16[2048,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[1,2,2]&amp;lt;=[4] last_tile_dim_replicate}, metadata={op_name=&quot;x&quot;}
  %bitcast.5 = bf16[2,1024,4096]{2,1,0:T(8,128)(2,1)} bitcast(%param)
  %all-to-all = bf16[2,1024,4096]{2,1,0:T(8,128)(2,1)S(3)} all-to-all(%bitcast.5), channel_id=1, replica_groups=[2,2]&amp;lt;=[2,2]T(1,0), dimensions={0}, metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/814736672.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;barrier_config&quot;:{&quot;barrier_type&quot;:&quot;GLOBAL&quot;,&quot;id&quot;:&quot;-1&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;0&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
  %copy.3 = bf16[2,1024,4096]{1,2,0:T(8,128)(2,1)S(3)} copy(%all-to-all), metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/814736672.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[],&quot;output_window_bounds&quot;:[&quot;1&quot;,&quot;48&quot;,&quot;8&quot;],&quot;input_window_bounds&quot;:[&quot;1&quot;,&quot;128&quot;,&quot;3&quot;],&quot;estimated_cycles&quot;:&quot;47696&quot;,&quot;iteration_bounds&quot;:[&quot;1&quot;,&quot;11&quot;,&quot;1&quot;]},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;3145728&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
  %bitcast.4 = bf16[1024,8192]{0,1:T(8,128)(2,1)S(3)} bitcast(%copy.3)
  ROOT %copy.4 = bf16[1024,8192]{1,0:T(8,128)(2,1)} copy(%bitcast.4), metadata={op_name=&quot;jit(all_to_all)/reshard&quot; source_file=&quot;/tmp/ipykernel_874047/814736672.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[],&quot;output_window_bounds&quot;:[&quot;64&quot;,&quot;6&quot;],&quot;input_window_bounds&quot;:[&quot;96&quot;,&quot;4&quot;],&quot;estimated_cycles&quot;:&quot;67782&quot;,&quot;iteration_bounds&quot;:[&quot;1&quot;,&quot;11&quot;]},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;3182592&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The HLO for this operation is more concise, but neither is able to avoid
a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;copy&lt;/code&gt; operation.&lt;/p&gt;

&lt;h2 id=&quot;allreduce-and-reducescatter&quot;&gt;AllReduce and ReduceScatter&lt;/h2&gt;

&lt;p&gt;To trigger an AllReduce/ReduceScatter, let’s follow our matrix multiplication example
above—we’ll multiply two arrays that are both sharded along the contracting
dimension.
In the AllToAll section, we’ve generated &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;C&lt;/code&gt;, which is sharded along &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'x'&lt;/code&gt; in
its 1st dimension. We can try computing a Gram matrix associated to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;C&lt;/code&gt;, and
specifying the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;out_sharding&lt;/code&gt; (singular!) in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dot&lt;/code&gt; operation to avoid the
compiler complaining:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;gram&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mT&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()))&lt;/span&gt;


&lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gram&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;E&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;C&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Output type: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;E&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;StableHLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;XLA HLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Output type:  bfloat16[2048,2048]
StableHLO:
 module @jit_gram attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = &amp;lt;[&quot;x&quot;=2, &quot;y&quot;=2]&amp;gt;
  func.func public @main(%arg0: tensor&amp;lt;2048x8192xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{}, {&quot;x&quot;}]&amp;gt;}) -&amp;gt; (tensor&amp;lt;2048x2048xbf16&amp;gt; {jax.result_info = &quot;result&quot;, sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{}, {}]&amp;gt;}) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor&amp;lt;2048x8192xbf16&amp;gt;) -&amp;gt; tensor&amp;lt;8192x2048xbf16&amp;gt;
    %1 = sdy.sharding_constraint %0 &amp;lt;@mesh, [{&quot;x&quot;}, {}]&amp;gt; : tensor&amp;lt;8192x2048xbf16&amp;gt;
    %2 = stablehlo.dot_general %arg0, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor&amp;lt;2048x8192xbf16&amp;gt;, tensor&amp;lt;8192x2048xbf16&amp;gt;) -&amp;gt; tensor&amp;lt;2048x2048xbf16&amp;gt;
    %3 = sdy.sharding_constraint %2 &amp;lt;@mesh, [{}, {}]&amp;gt; : tensor&amp;lt;2048x2048xbf16&amp;gt;
    return %3 : tensor&amp;lt;2048x2048xbf16&amp;gt;
  }
}
XLA HLO:
 HloModule jit_gram, is_scheduled=true, entry_computation_layout={(bf16[2048,4096]{1,0:T(8,128)(2,1)})-&amp;gt;bf16[2048,2048]{1,0:T(8,128)(2,1)}}, num_partitions=4
%add.clone (x.1: bf16[], y.1: bf16[]) -&amp;gt; bf16[] {
  %y.1 = bf16[]{:T(256)} parameter(1)
  %x.1 = bf16[]{:T(256)} parameter(0)
  ROOT %add.1 = bf16[]{:T(256)} add(%x.1, %y.1)
}
%bitcast_fusion (bitcast_input: bf16[2048,4096]) -&amp;gt; bf16[2048,4096] {
  %bitcast_input = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast = bf16[2048,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input)
}
%bitcast_fusion.1 (bitcast_input.1: bf16[2048,4096]) -&amp;gt; bf16[2048,4096] {
  %bitcast_input.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input.1)
}
%fused_computation (param_0: bf16[2048,4096]) -&amp;gt; bf16[2048,2048] {
  %param_0 = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  %fusion.1 = bf16[2048,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion
  %fusion.2 = bf16[2048,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion.1
  ROOT %convolution.1 = bf16[2048,2048]{1,0:T(8,128)(2,1)S(3)} convolution(%fusion.1, %fusion.2), dim_labels=bf_oi-&amp;gt;bf, metadata={op_name=&quot;jit(gram)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/670048905.py&quot; source_line=3}
}
ENTRY %main.8_spmd (param: bf16[2048,4096]) -&amp;gt; bf16[2048,2048] {
  %param = bf16[2048,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[1,2,2]&amp;lt;=[4] last_tile_dim_replicate}, metadata={op_name=&quot;x&quot;}
  %copy-start = (bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)}, bf16[2048,4096]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%param), cross_program_prefetch_index=0
  %copy-done = bf16[2048,4096]{1,0:T(8,128)(2,1)S(3)} copy-done(%copy-start)
  %fusion = bf16[2048,2048]{1,0:T(8,128)(2,1)S(3)} fusion(%copy-done), kind=kOutput, calls=%fused_computation, metadata={op_name=&quot;jit(gram)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/670048905.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[&quot;64&quot;,&quot;32&quot;],&quot;output_window_bounds&quot;:[&quot;64&quot;,&quot;4&quot;],&quot;input_window_bounds&quot;:[&quot;64&quot;,&quot;32&quot;],&quot;estimated_cycles&quot;:&quot;149352&quot;,&quot;iteration_bounds&quot;:[&quot;4&quot;,&quot;4&quot;,&quot;1&quot;]},&quot;megacore_config&quot;:{&quot;megacore_split_dim&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;5734400&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;},&quot;convolution_algorithm_config&quot;:{&quot;emitter&quot;:&quot;EmitAllBatchInSublanes&quot;}}
  ROOT %all-reduce = bf16[2048,2048]{1,0:T(8,128)(2,1)} all-reduce(%fusion), channel_id=1, replica_groups=[2,2]&amp;lt;=[2,2]T(1,0), use_global_device_ids=true, to_apply=%add.clone, metadata={op_name=&quot;jit(gram)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/670048905.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;barrier_config&quot;:{&quot;barrier_type&quot;:&quot;CUSTOM&quot;,&quot;id&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;0&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;67108864&quot;}],&quot;collective_algorithm_config&quot;:{&quot;emitter&quot;:&quot;RotatedPincerEmitter&quot;,&quot;strategy&quot;:&quot;UniDirection1DRingStrategy&quot;,&quot;debug&quot;:&quot;\nUniDirection1DRingStrategy{colors:2 phases:1 cores:{2},{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}&quot;},&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;12582912&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The above is relatively clear! A &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;copy&lt;/code&gt; operation accounts for the fact that
we’re multiplying the matrix with itself to calculate the Gram matrix, then the
actual multiplication is performed with the XLA &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;convolution&lt;/code&gt; operation. It’s
performed on each local shard, and then the local shards are accumulated and
shared with an AllReduce.&lt;/p&gt;

&lt;p&gt;We expect to be able to trigger a ReduceScatter instead of an AllReduce by
asking the computation output to be sharded in a particular way (rather than
fully replicated, as above). Let’s try to verify this:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nd&quot;&gt;@jax.jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;dot_custom&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;F&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8192&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4096&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jnp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bfloat16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out_sharding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;NamedSharding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mesh&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nc&quot;&gt;P&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dot_custom&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;lower&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;G&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;


&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Output type: &lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;typeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;G&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;StableHLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lowered&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;nf&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;XLA HLO:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;sh&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compiled&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nf&quot;&gt;as_text&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;div class=&quot;language-plaintext highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;Output:
Output type:  bfloat16[2048@x,4096@y]
StableHLO:
 module @jit_dot_custom attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = &amp;lt;[&quot;x&quot;=2, &quot;y&quot;=2]&amp;gt;
  func.func public @main(%arg0: tensor&amp;lt;2048x8192xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {&quot;y&quot;}]&amp;gt;}, %arg1: tensor&amp;lt;8192x4096xbf16&amp;gt; {sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;y&quot;}, {}]&amp;gt;}) -&amp;gt; (tensor&amp;lt;2048x4096xbf16&amp;gt; {jax.result_info = &quot;result&quot;, sdy.sharding = #sdy.sharding&amp;lt;@mesh, [{&quot;x&quot;}, {&quot;y&quot;}]&amp;gt;}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor&amp;lt;2048x8192xbf16&amp;gt;, tensor&amp;lt;8192x4096xbf16&amp;gt;) -&amp;gt; tensor&amp;lt;2048x4096xbf16&amp;gt;
    %1 = sdy.sharding_constraint %0 &amp;lt;@mesh, [{&quot;x&quot;}, {&quot;y&quot;}]&amp;gt; : tensor&amp;lt;2048x4096xbf16&amp;gt;
    return %1 : tensor&amp;lt;2048x4096xbf16&amp;gt;
  }
}
XLA HLO:
 HloModule jit_dot_custom, is_scheduled=true, entry_computation_layout={(bf16[1024,4096]{1,0:T(8,128)(2,1)}, bf16[4096,4096]{1,0:T(8,128)(2,1)})-&amp;gt;bf16[1024,2048]{1,0:T(8,128)(2,1)}}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=4
%add.1.clone (x.3: bf16[], y.3: bf16[]) -&amp;gt; bf16[] {
  %y.3 = bf16[]{:T(256)} parameter(1)
  %x.3 = bf16[]{:T(256)} parameter(0)
  ROOT %add.3 = bf16[]{:T(256)} add(%x.3, %y.3)
}
%all-reduce-scatter (input: bf16[1024,4096]) -&amp;gt; bf16[1024,2048] {
  %input = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  %all-reduce.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} all-reduce(%input), channel_id=3, replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=%add.1.clone, frontend_attributes={from-cross-replica-sharding=&quot;true&quot;}, backend_config={&quot;flag_configs&quot;:[],&quot;barrier_config&quot;:{&quot;barrier_type&quot;:&quot;CUSTOM&quot;,&quot;id&quot;:&quot;0&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[]}
  %constant.13 = u32[] constant(0)
  %constant.shard_id_table = u32[4]{0:T(128)} constant({0, 1, 0, 1})
  %partition-id.1 = u32[] partition-id()
  %dynamic-slice.3 = u32[1]{0:T(128)} dynamic-slice(%constant.shard_id_table, %partition-id.1), dynamic_slice_sizes={1}, backend_config={&quot;flag_configs&quot;:[],&quot;scoped_memory_configs&quot;:[],&quot;indices_config&quot;:{&quot;index_known_bits&quot;:[{&quot;zeroes&quot;:&quot;0&quot;,&quot;ones&quot;:&quot;0&quot;,&quot;bitwidth&quot;:&quot;32&quot;}],&quot;is_index_aligned&quot;:[false]},&quot;used_scoped_memory_configs&quot;:[]}
  %bitcast.1 = u32[]{:T(128)} bitcast(%dynamic-slice.3)
  %constant.14 = u32[] constant(2048)
  %multiply.3 = u32[]{:T(128)} multiply(%bitcast.1, %constant.14)
  ROOT %dynamic-slice.4 = bf16[1024,2048]{1,0:T(8,128)(2,1)} dynamic-slice(%all-reduce.2, %constant.13, %multiply.3), dynamic_slice_sizes={1024,2048}, backend_config={&quot;flag_configs&quot;:[],&quot;scoped_memory_configs&quot;:[],&quot;indices_config&quot;:{&quot;index_known_bits&quot;:[{&quot;zeroes&quot;:&quot;4294967295&quot;,&quot;ones&quot;:&quot;0&quot;,&quot;bitwidth&quot;:&quot;32&quot;},{&quot;zeroes&quot;:&quot;4294961151&quot;,&quot;ones&quot;:&quot;0&quot;,&quot;bitwidth&quot;:&quot;32&quot;}],&quot;is_index_aligned&quot;:[true,false]},&quot;used_scoped_memory_configs&quot;:[]}
}
%bitcast_fusion (bitcast_input: bf16[1024,4096]) -&amp;gt; bf16[1024,4096] {
  %bitcast_input = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0)
  ROOT %bitcast.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input)
}
%bitcast_fusion.1 (bitcast_input.1: bf16[4096,4096]) -&amp;gt; bf16[4096,4096] {
  %bitcast_input.1 = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} parameter(0)
  ROOT %bitcast.3 = bf16[4096,4096]{1,0:T(8,128)(2,1)} bitcast(%bitcast_input.1)
}
%fused_computation (param_0: bf16[1024,4096], param_1: bf16[4096,4096]) -&amp;gt; bf16[1024,4096] {
  %param_0 = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0)
  %fusion.2 = bf16[1024,4096]{1,0:T(8,128)(2,1)} fusion(%param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} parameter(1)
  %fusion.3 = bf16[4096,4096]{1,0:T(8,128)(2,1)} fusion(%param_1), kind=kLoop, calls=%bitcast_fusion.1
  ROOT %convolution.1 = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} convolution(%fusion.2, %fusion.3), dim_labels=bf_io-&amp;gt;bf, metadata={op_name=&quot;jit(dot_custom)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/871715024.py&quot; source_line=3}
}
ENTRY %main.7_spmd (param: bf16[1024,4096], param.1: bf16[4096,4096]) -&amp;gt; bf16[1024,2048] {
  %param.1 = bf16[4096,4096]{1,0:T(8,128)(2,1)} parameter(1), sharding={devices=[2,1,2]&amp;lt;=[2,2]T(1,0) last_tile_dim_replicate}, metadata={op_name=&quot;y&quot;}
  %copy-start = (bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)}, bf16[4096,4096]{1,0:T(8,128)(2,1)}, u32[]{:S(2)}) copy-start(%param.1), cross_program_prefetch_index=0
  %param = bf16[1024,4096]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[2,2]&amp;lt;=[4]}, metadata={op_name=&quot;x&quot;}
  %copy-done = bf16[4096,4096]{1,0:T(8,128)(2,1)S(3)} copy-done(%copy-start)
  %fusion.1 = bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)} fusion(%param, %copy-done), kind=kOutput, calls=%fused_computation, metadata={op_name=&quot;jit(dot_custom)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/871715024.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;window_config&quot;:{&quot;kernel_window_bounds&quot;:[&quot;512&quot;,&quot;4&quot;],&quot;output_window_bounds&quot;:[&quot;32&quot;,&quot;4&quot;],&quot;input_window_bounds&quot;:[&quot;32&quot;,&quot;32&quot;],&quot;estimated_cycles&quot;:&quot;144544&quot;,&quot;iteration_bounds&quot;:[&quot;8&quot;,&quot;4&quot;,&quot;1&quot;]},&quot;megacore_config&quot;:{&quot;megacore_split_dim&quot;:&quot;1&quot;},&quot;scoped_memory_configs&quot;:[],&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;5591040&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;},&quot;convolution_algorithm_config&quot;:{&quot;emitter&quot;:&quot;EmitAllBatchInSublanes&quot;}}
  ROOT %fusion = bf16[1024,2048]{1,0:T(8,128)(2,1)} fusion(%fusion.1), kind=kCustom, calls=%all-reduce-scatter, metadata={op_name=&quot;jit(dot_custom)/dot_general&quot; source_file=&quot;/tmp/ipykernel_874047/871715024.py&quot; source_line=3}, backend_config={&quot;flag_configs&quot;:[],&quot;scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;0&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;67108864&quot;}],&quot;collective_algorithm_config&quot;:{&quot;emitter&quot;:&quot;SingleInputAllReduceScatterFusion&quot;,&quot;strategy&quot;:&quot;StrategyRing&quot;,&quot;debug&quot;:&quot;\nStrategyRing{colors:1 phases:1 cores:{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}\nStrategyRing{colors:1 phases:1 cores:{2} nophase0:0 reserved_sflags:0 cross_module_on_2d_plane:0 has_reordering_map:0 use_routing_table_indices:0}\nType: 1D phase_count: 1; color_count: 1; sharded_partitions: 4; original_shape: bf16[1024,4096]{1,0:T(8,128)(2,1)S(3)}; per_color_shard_counts: 2; color_dim: -1; sharding_dim: 1; sharding_type: minor; convert_all_gather_output_to_bf16: 0; formatting steps: ()\nall-reduce-scatter fusion ND: span_size:8192*k512Byte, shard count:2, span_count:2, total_size:16384*k512Byte, valid_granules:16384*k512Byte&quot;},&quot;used_scoped_memory_configs&quot;:[{&quot;memory_space&quot;:&quot;1&quot;,&quot;offset&quot;:&quot;0&quot;,&quot;size&quot;:&quot;12582912&quot;}],&quot;retry_config&quot;:{&quot;retry_count&quot;:&quot;0&quot;}}
}
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;I actually haven’t found a way to generate a ReduceScatter operation in the
HLO—if you know how to do this, please let me know! Above, I’ve picked
different shardings and arrays, and we end up triggering an operation similar in
effect to a ReduceScatter; but the HLO implementation instead uses &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;all-reduce&lt;/code&gt;
and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dynamic-slice&lt;/code&gt; to split up the reduced array.&lt;/p&gt;

&lt;h1 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h1&gt;

&lt;p&gt;This turned out to be on the longer side! But we have two valuable takeaways:&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;A linear-algebraic mental model for sharding in JAX—sharded dimensions
partition data, and unsharded dimensions lead to replication in the
underlying mesh! All of this can be understood in terms of block matrices and
their algebra.&lt;/li&gt;
  &lt;li&gt;The ability to read HLO and parse the communications primitives that arise
therein when we manipulate sharded arrays.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;We also recapped information from the TPU Scaling book about how to
interpret performance tradeoffs between these different forms of communication.&lt;/p&gt;

&lt;h1 id=&quot;acknowledgments&quot;&gt;Acknowledgments&lt;/h1&gt;

&lt;p&gt;Thanks to the &lt;a href=&quot;https://sites.research.google/trc/about/&quot;&gt;TRC program&lt;/a&gt; for
compute support.&lt;/p&gt;

&lt;div class=&quot;footnotes&quot; role=&quot;doc-endnotes&quot;&gt;
  &lt;ol&gt;
    &lt;li id=&quot;fn:1&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;One way to generalize this to simultaneously considering different meshes
is to consider that the mesh simply indexes different physical devices axes,
or ‘merges’ of them. Hence if there are $M$ physical device axes (e.g.,
devices with 3D interconnects), one can count shardings across meshes by
allowing the $k$ selected axes to be merged together. This is
a noncommutative merging, as the traversal order for the devices will be
different based on the ordering of the merged axes. &lt;a href=&quot;#fnref:1&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:2&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;To go beyond this restriction, it suffices to think of the mapping between
block indices and devices as being performed by a pre-configured mesh.
Algebraically, the mesh axis corresponds to a sort of ‘block permutation
matrix’ (i.e., a tensor product of a permutation matrix with the identity),
which left-multiplies for sharding the first dimension of the matrix, and
right-multiplies for the second. &lt;a href=&quot;#fnref:2&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:4&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Note that, algebraically, if we perform copying to represent a partitioned
matrix in terms of devices, then &lt;em&gt;elementwise&lt;/em&gt; matrix multiplication
corresponds to the within-device matrix products!) &lt;a href=&quot;#fnref:4&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:6&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;E.g., about 42 GB/s for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v4&lt;/code&gt; TPUs. &lt;a href=&quot;#fnref:6&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:7&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Since AllGather is linear, its derivative doesn’t involve any cached
activations from the forward pass, so we omit these from the notation. &lt;a href=&quot;#fnref:7&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:8&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;But the translator is open source, and the admissible operations can be
gleaned from the source code (use Google’s LLM for help with this!).
A high-level specification of operations can be found &lt;a href=&quot;https://openxla.org/xla/operation_semantics&quot;&gt;here&lt;/a&gt;. &lt;a href=&quot;#fnref:8&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
    &lt;li id=&quot;fn:9&quot; role=&quot;doc-endnote&quot;&gt;
      &lt;p&gt;Some high-level takeaways about these operations: think of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bitcast&lt;/code&gt; like
a view of the array, at least in this context, whereas &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;copy&lt;/code&gt; doesn’t change
the underlying tensor, but does change the memory layout. (Notice that the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;copy&lt;/code&gt; operation involves a transposition of the layout axes.) &lt;a href=&quot;#fnref:9&quot; class=&quot;reversefootnote&quot; role=&quot;doc-backlink&quot;&gt;&amp;#8617;&lt;/a&gt;&lt;/p&gt;
    &lt;/li&gt;
  &lt;/ol&gt;
&lt;/div&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Sam Buchanan</name>
          
            <email>sdbuchanan@berkeley.edu</email>
          
          
        </author>
      

      
        <category term="jax" />
      

      

      
        <summary type="html">This post will be the first in a series on programming with JAX, for training models like transformers. I’m experimenting with these as an alternative to my usual scratch paper or LaTeX notes while learning, in the hope that it will help me with recall and perhaps be useful to others learning this material.</summary>
      

      
      
    </entry>
  
  
</feed>
