<?xml version="1.0" encoding="UTF-8"?><rss xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom" version="2.0" xmlns:cc="http://cyber.law.harvard.edu/rss/creativeCommonsRssModule.html">
    <channel>
        <title><![CDATA[Stories by Ali Nawaz on Medium]]></title>
        <description><![CDATA[Stories by Ali Nawaz on Medium]]></description>
        <link>https://medium.com/@AliPythonDev?source=rss-72f6eb2992e8------2</link>
        <image>
            <url>https://cdn-images-1.medium.com/fit/c/150/150/1*Gae34nBd4Q739F-6v1H7_g.jpeg</url>
            <title>Stories by Ali Nawaz on Medium</title>
            <link>https://medium.com/@AliPythonDev?source=rss-72f6eb2992e8------2</link>
        </image>
        <generator>Medium</generator>
        <lastBuildDate>Mon, 01 Jun 2026 05:17:04 GMT</lastBuildDate>
        <atom:link href="https://medium.com/@AliPythonDev/feed" rel="self" type="application/rss+xml"/>
        <webMaster><![CDATA[yourfriends@medium.com]]></webMaster>
        <atom:link href="http://medium.superfeedr.com" rel="hub"/>
        <item>
            <title><![CDATA[Making Financial Rules Portable (Without Losing Your Mind)]]></title>
            <link>https://medium.com/@AliPythonDev/making-financial-rules-portable-without-losing-your-mind-f69de90d39ef?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/f69de90d39ef</guid>
            <category><![CDATA[morphir]]></category>
            <category><![CDATA[open-source]]></category>
            <category><![CDATA[fino]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Wed, 06 May 2026 18:30:34 GMT</pubDate>
            <atom:updated>2026-05-06T18:36:52.441Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*Eewtt25uj9G_cN6E6oMitg.png" /></figure><p><strong>Author:</strong> <a href="https://www.linkedin.com/in/AlleyNawaz">Ali Nawaz</a>, <a href="https://www.finos.org/ambassador-program">FINOS Ambassador</a></p><blockquote><strong>Abstract:</strong><em> </em>Banks waste millions rewriting the same business rules over and over. A trading rule starts in Excel, gets rebuilt in Java, then Python, then SQL. When something changes, teams update each version separately and often miss one. The result? Systems that don’t match, audits that fail, and nobody knowing which version is right.</blockquote><blockquote>Morphir fixes this. Write your financial rules once in a simple language called Elm. Morphir then creates versions for every system you need automatically. Same rule, different formats, zero manual copying. Morgan Stanley built it to solve their own mess and shared it with everyone through FINOS. This post explains what it does, why it matters, and whether you should care.</blockquote><p>If you work in finance, you’ve probably seen this movie before: a business analyst writes the logic for a new trading rule in Excel. A developer rewrites it in Java. Another team needs the same rule in Python for their risk system. Six months later, nobody’s sure which version is “correct” anymore.</p><p>This isn’t just annoying. It’s expensive, error-prone, and sometimes dangerous when millions of dollars ride on calculations that should match but don’t.</p><p>Enter Morphir, a FINOS project that’s trying to solve a problem most people don’t even realize exists: making business logic portable across systems, languages, and teams without breaking everything in the process.</p><h3>The Problem: Lost in Translation</h3><p>Here’s a scenario that plays out daily in banks and financial institutions:</p><p>Your compliance team defines a new regulation rule. 
They understand the business side perfectly but can’t code. So they write it in plain English, maybe throw it in a Word doc or Excel spreadsheet.</p><p>The development team picks it up, interprets it, and codes it in Scala for the trading system. Then another team needs the same rule for reporting, so they rewrite it in SQL. The data science team wants it for modeling, so they build it again in Python.</p><p>Now you have four versions of the “same” rule. When the regulation changes, you update all four. Except someone misses the SQL version, and now your reports don’t match your trades. Auditors show up. Everyone panics.</p><p>This isn’t a hypothetical. This is Tuesday.</p><h3>What Morphir Actually Does</h3><p>Morphir takes a different approach. Instead of writing your business logic directly in Java, Python, or whatever language your system uses, you write it once in a language called Elm. Then Morphir translates it into whatever you need.</p><p>Think of it like writing a recipe once, then having it automatically translated into French, Spanish, and Japanese. Same recipe, different languages, no manual rewriting.</p><p>But here’s the clever bit: Elm isn’t just any language. It’s designed to be really hard to mess up. No null pointer exceptions. No runtime errors if you do it right. This matters when you’re dealing with financial calculations where a misplaced decimal can cost millions.</p><p>So the workflow looks like this:</p><ol><li>Business analysts and developers collaborate to write the logic in Elm</li><li>Morphir takes that Elm code and generates an intermediate representation (IR)</li><li>From that IR, Morphir can generate code in Scala, TypeScript, JSON schemas, or even documentation</li><li>Everyone uses the generated code, and there’s only one source of truth</li></ol><h3>Why This Matters More Than You Think</h3><p>Let’s talk about what this unlocks.</p><p><strong>Same logic, different systems.</strong> You define a pricing model once. 
Morphir generates the Scala version for your backend, the TypeScript version for your web app, and the documentation for your compliance team. They all match because they came from the same source.</p><p><strong>Version control for business rules.</strong> Since your logic lives in code (Elm), you can use Git. You can see who changed what, when, and why. You can roll back if something breaks. Try doing that with an Excel spreadsheet passed around in email.</p><p><strong>Testing becomes possible.</strong> When your business logic is in Elm, you can write tests. Real tests. Not “let me manually check this in production” tests. You can verify that your interest calculation works correctly before deploying it to a system managing billions of dollars.</p><p><strong>Talking the same language.</strong> Ever tried explaining a technical concept to a business analyst, or vice versa? Morphir creates a middle ground. The Elm code is readable enough that business folks can follow along, but rigorous enough that it compiles into production code.</p><h3>The Real-World Use Case: Morgan Stanley</h3><p>Morgan Stanley didn’t build Morphir because they thought it would be cool. They built it because they had a genuine problem.</p><p>They had business logic scattered across multiple systems and languages. Every time a rule changed, they had to hunt down every implementation and update them all. They missed things. Systems fell out of sync. It was a mess.</p><p>Morphir came out of that pain. They open-sourced it through FINOS because they realized other banks had the same problem. And they were right.</p><h3>How It Actually Works (The Less Boring Technical Part)</h3><p>Morphir uses something called an Intermediate Representation. 
You don’t need to understand the deep computer science here, but the idea is simple:</p><p>Instead of translating directly from Elm to Java, then Elm to Python, then Elm to whatever else you need (which means building a translator for every combination), Morphir translates Elm into a neutral format first. Then it translates that neutral format into your target language.</p><p>It’s like how Google Translate doesn’t have a direct translator for every language pair. It uses an intermediate step to make the whole thing manageable.</p><p>This IR is also queryable. You can ask it questions like “which functions use this data type?” or “where is this business rule used?” Try doing that with regular code.</p><h3>The Catch (Because There’s Always a Catch)</h3><p>Morphir isn’t a magic wand. You can’t just point it at your existing codebase and expect everything to work.</p><p>First, you need to rewrite your logic in Elm. That takes time. It takes learning a new language (though Elm is pretty friendly as languages go). For small rules, this might feel like overkill.</p><p>Second, Morphir works best for pure business logic — calculations, transformations, rules. If your code is tangled up with database calls, API requests, and UI rendering, you’ll need to separate the business logic first. That’s good practice anyway, but it’s still work.</p><p>Third, you’re betting on Elm. It’s not a mainstream language. If your team leaves and you need to hire replacements, finding Elm developers isn’t as easy as finding Java developers. Though the FINOS community helps here.</p><h3>Where This Is Heading</h3><p>FINOS is pushing Morphir toward some interesting places:</p><p><strong>Regulation as Code.</strong> Imagine regulators publishing rules in Morphir format. Banks could directly import those rules instead of interpreting 200-page PDFs and hoping they got it right. 
The UK’s Financial Conduct Authority is already experimenting with this concept.</p><p><strong>Cross-institution collaboration.</strong> If multiple banks agree on a standard way to calculate something (say, portfolio risk), they could share the Morphir definition. Everyone implements it the same way, reducing errors and making collaboration easier.</p><p><strong>Better tooling.</strong> The Morphir team is building visual tools so business analysts can actually work with the models without writing code. Imagine dragging and dropping to build a financial rule, then having production-ready code generated automatically.</p><h3>Should You Care?</h3><p>If you’re building financial systems and you’ve ever had to maintain the same business logic in multiple places, yes.</p><p>If you’re tired of “wait, which version of this calculation is the right one?” conversations, yes.</p><p>If you think business analysts and developers should speak a common language instead of playing telephone, yes.</p><p>Morphir won’t solve every problem in financial software. But for the specific problem of keeping business logic consistent across systems, it’s one of the smarter approaches out there.</p><p>And unlike a lot of open source projects that feel like academic exercises, Morphir came from real pain at a real bank. 
That matters.</p><h3>Resources to Dig Deeper</h3><p><strong>Official FINOS Morphir Resources:</strong></p><ul><li><a href="https://github.com/finos/morphir">Morphir GitHub Repository</a>: Main codebase and documentation</li><li><a href="https://github.com/finos/morphir-examples">Morphir Examples</a>: Sample projects to understand how it works</li><li><a href="https://www.finos.org/morphir">FINOS Morphir Project Page</a>: Overview and community info</li></ul><p><strong>Getting Started:</strong></p><ul><li><a href="https://github.com/finos/morphir-elm">Morphir Elm SDK</a>: The Elm tooling for building Morphir models</li><li><a href="https://morphir.finos.org/">Morphir Documentation</a>: Full docs and guides</li><li><a href="https://guide.elm-lang.org/">Intro to Elm</a>: Since you’ll need Elm basics to use Morphir</li></ul><p><strong>Related Reading:</strong></p><ul><li><a href="https://www.finos.org/common-domain-model">Common Domain Model (CDM)</a>: Another FINOS project for standardizing financial data</li></ul><p><em>This is part of a series on FINOS projects. <br>Previously: </em><a href="https://medium.com/@AliPythonDev/building-the-future-of-open-source-in-finance-with-finos-af2118e14718"><em>Building the Future of Open Source in Finance with FINOS</em></a><em> and </em><a href="https://medium.com/@AliPythonDev/connecting-the-dots-how-fdc3-is-making-the-financial-desktop-smarter-ade5e6d919c6"><em>Connecting the Dots: How FDC3 is Making the Financial Desktop Smarter</em></a></p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[jit, grad, and vmap in JAX — One simple idea behind all three: transforming functions]]></title>
            <link>https://medium.com/@AliPythonDev/jit-grad-and-vmap-in-jax-one-simple-idea-behind-all-three-transforming-functions-ed4734f0eb16?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/ed4734f0eb16</guid>
            <category><![CDATA[artificial-intelligence]]></category>
            <category><![CDATA[python]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[google]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Mon, 26 Jan 2026 12:54:43 GMT</pubDate>
            <atom:updated>2026-01-26T13:01:46.846Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*K9heBCZ2zIDvEttyx51mCA.png" /></figure><p>If you’re new to JAX, you’ll keep seeing three words everywhere:</p><blockquote><em>jit, grad, and vmap</em></blockquote><p>People use them casually, like:</p><blockquote><em>“Just jit it”</em></blockquote><blockquote><em>“Wrap it with vmap”</em></blockquote><blockquote><em>“Grad takes care of that”</em></blockquote><p>And you’re left thinking:</p><blockquote>“I know Python. I know NumPy. Why does JAX feel… different?”</blockquote><p>This blog is here to clear that confusion <strong>without heavy math</strong> and <strong>without magic</strong>.</p><h3>The problem JAX is designed to solve</h3><p>Let’s start simple.</p><p>Python is <strong>easy to write</strong>, but it’s <strong>slow</strong> at runtime.</p><p>Accelerators like <strong>GPUs and TPUs</strong> are <strong>extremely fast</strong>, but they hate Python-style loops and dynamic behavior.</p><p>JAX sits in the middle and asks:</p><blockquote><em>“What if we write normal Python functions…</em></blockquote><blockquote><em>but transform them into something fast, differentiable, and parallel?”</em></blockquote><p>That’s where jit, grad, and vmap come in.</p><p>They don’t add features.</p><p>They <strong>transform your function</strong>.</p><h3>Functions are the main abstraction in JAX</h3><p>In JAX, you don’t write classes or graphs.</p><p>You write <strong>plain Python functions</strong>:</p><pre>import jax.numpy as jnp<br><br>def loss(w, x):<br>    return jnp.sum(w * x)</pre><p>Then JAX says:</p><ul><li>Want gradients? Use grad</li><li>Want speed? Use jit</li><li>Want batching? Use vmap</li></ul><p>Each of these <strong>takes a function and returns a new function</strong>.</p><p>That’s it.</p><p>No hidden state. 
No magic objects.</p><h3>Understanding grad through simple examples</h3><p>Think of grad as:</p><blockquote><em>“Given a function, create another function that computes its derivative.”</em></blockquote><p>Example:</p><pre>from jax import grad<br><br>def f(x):<br>    return x ** 2<br><br>df_dx = grad(f)<br>df_dx(3.0)   # → 6.0</pre><p>What happened?</p><ul><li>The original function stays the same</li><li>grad(f) returns a new function</li><li>That new function computes how f changes with respect to x</li></ul><p>Important thing to understand:</p><p>JAX traces your function, builds a computation graph internally, and applies automatic differentiation.</p><p>You <strong>never</strong> write derivative code yourself.</p><h3>What actually happens when you use jit</h3><p>jit stands for <strong>Just-In-Time compilation</strong>.</p><p>But practically, it means:</p><blockquote><em>“Run this function once, understand it fully, then compile it to fast machine code.”</em></blockquote><p>Example:</p><pre>from jax import jit<br><br>@jit<br>def compute(x):<br>    return x * x + 2</pre><p>The <strong>first run</strong> is slower:</p><ul><li>JAX traces the function</li><li>Sends it to XLA</li><li>Compiles it</li></ul><p>After that?</p><ul><li>Runs at <strong>C / CUDA / TPU speed</strong></li></ul><p>Key idea:</p><ul><li>jit doesn’t change <em>what</em> your function does</li><li>It changes <em>how</em> it runs</li></ul><h3>How vmap removes the need for manual batching</h3><p>Normally, you write loops like this:</p><pre>results = []<br>for x in batch:<br>    results.append(f(x))</pre><p>Loops are slow and messy.</p><p>vmap says:</p><blockquote><em>“This function works for one input.</em></blockquote><blockquote><em>I’ll automatically make it work for a batch.”</em></blockquote><p>Example:</p><pre>import jax.numpy as jnp<br>from jax import vmap<br><br>def f(x):<br>    return x ** 2<br><br>batched_f = vmap(f)<br>batched_f(jnp.array([1, 2, 3]))  # → [1, 4, 9]</pre><ul><li>No loops.</li><li>No manual 
batching.</li></ul><p>Internally, JAX turns this into a <strong>single vectorized computation</strong>.</p><h3>Using multiple transformations together</h3><p>Here’s where JAX becomes special.</p><p>You can <strong>stack transformations</strong>:</p><pre>fast_grad = jit(grad(f))</pre><p>This means:<br>Take f, create its gradient, then compile the gradient itself.</p><p>Or:</p><pre>batched_grad = vmap(grad(f))</pre><p>Now you have gradients for a whole batch without writing a loop.</p><p>This composability is the <strong>core design win of JAX</strong>.</p><h3>Common mistakes when using jit, grad, and vmap</h3><p>Many issues people face with JAX come from incorrect expectations rather than bugs in code.</p><p>Assuming jit will speed up Python logic<br>jit does not compile Python if statements or loops. They run only while JAX traces the function: loops are unrolled into the compiled program, and an if that depends on a traced value raises an error. For control flow inside compiled code, use jax.lax.cond or jax.lax.fori_loop.</p><p>Frequently changing input shapes<br>When shapes change, JAX has to recompile the function. This removes most of the performance benefit of using jit.</p><p>Using randomness without managing PRNG keys<br>JAX does not hide randomness. You must pass and split PRNG keys explicitly; reusing the same key gives you the same numbers again.</p><p>Treating jit, grad, and vmap as simple utilities<br>These are not helper functions. 
Each one rewrites how your function behaves internally.</p><h3>Key ideas to remember</h3><ul><li><strong>grad</strong> gives you automatic differentiation</li><li><strong>jit</strong> gives you speed through compilation</li><li><strong>vmap</strong> gives you batching without loops</li></ul><p>But the most important idea is this:</p><blockquote>JAX is not a traditional framework.</blockquote><blockquote>It’s a <strong>function transformation system</strong>.</blockquote><p>Once this clicks, everything else (Flax, Optax, and even TPUs) starts making sense.</p><h3>What’s next</h3><p>In the next post, we’ll focus on how randomness works in JAX and why it looks so different from NumPy or Python.</p><p>We’ll break down what PRNG keys are, why you have to pass them around explicitly, and how this design helps JAX stay reproducible, parallel-friendly, and pure.</p><p>By the end, you’ll understand how JAX generates random numbers without hidden state, and why this approach matters once you start using jit, vmap, and accelerators.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[What the Compact API in Flax Really Does: The Design Choice That Keeps Parameter Creation…]]></title>
            <link>https://medium.com/@AliPythonDev/what-the-compact-api-in-flax-really-does-the-design-choice-that-keeps-parameter-creation-c9747388324e?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/c9747388324e</guid>
            <category><![CDATA[flax]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[jax]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Sat, 27 Dec 2025 09:38:32 GMT</pubDate>
            <atom:updated>2025-12-27T09:38:32.238Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*VoBYIshnnVWqTRGC_SetIQ.png" /></figure><h3><strong>What the Compact API in Flax Really Does: </strong>The Design Choice That Keeps Parameter Creation Functional</h3><h3>Why this topic matters</h3><p>When you first see the compact API in Flax, it feels confusing.</p><p>You see layers being created inside the forward function.<br>You see parameters appearing where computation should happen.</p><p>If you care about functional programming, this feels wrong.</p><p>JAX teaches us that functions should be pure.<br>So how can Flax allow parameter creation inside the forward pass and still stay functional?</p><p>This blog exists to answer that question slowly and clearly.</p><h3>The core confusion beginners face</h3><p>In most deep learning libraries, parameter creation and computation are separated.</p><p>You define layers first.<br>Then you run data through them.</p><p>But with the compact API in <strong>Flax</strong>, you often see code like this.</p><p>Inside the forward function, layers appear directly.<br>Dense layers.<br>Convolutions.<br>Normalization.</p><p>It looks like parameters are being created every time the function runs.</p><p>But they are not.</p><p>Understanding why is the key to understanding the compact API.</p><h3>Step 1: Compact is about permission, not behavior</h3><p>The most important idea is this.</p><p>The compact API does not change what your function does.<br>It changes <strong>when</strong> Flax is allowed to create parameters.</p><p>By marking a function as compact, you are telling Flax:</p><p>You may create parameters here, but only during initialization.</p><p>That permission is tightly controlled.</p><p>Outside initialization, parameter creation is completely disabled.</p><h3>Step 2: Two runs of the same function</h3><p>Every Flax model runs in two very different modes.</p><p>Initialization mode<br>Application mode</p><p>During 
initialization:</p><p>The forward function runs once.<br>Flax watches it carefully.<br>Whenever it sees a layer that needs parameters, it creates them.<br>All parameters are stored outside the function in a parameter tree.</p><p>During application:</p><p>The same function runs again.<br>But now parameters already exist.<br>Nothing new is created.<br>Everything is reused.</p><p>The code looks the same.<br>The behavior is not.</p><h3>Step 3: Why this does not break functional programming</h3><p>This is the subtle but powerful design choice.</p><p>The function itself never stores parameters.<br>It never mutates state.<br>It never remembers anything.</p><p>All parameters live outside the function and are passed in explicitly.</p><p>So from JAX’s point of view, the model is still just a function.</p><p>Give it parameters and inputs.<br>Get outputs back.</p><p>That is why Flax works perfectly with <strong>JAX</strong> transformations like jit and grad.</p><h3>Step 4: A simple mental model</h3><p>Think of a factory inspection.</p><p>The first time, inspectors walk through the building.<br>They write down where machines should exist.<br>They note their shapes and requirements.</p><p>After inspection, the factory is fixed.</p><p>Every production run after that uses the same machines.<br>No new machines are added.</p><p>The compact API is that inspection step.</p><p>Once it is done, everything becomes stable and predictable.</p><h3>Step 5: Why Flax introduced the compact API</h3><p>The original setup style in Flax is explicit and clear.</p><p>But it can feel split.</p><p>Structure lives in one place.<br>Computation lives in another.</p><p>The compact API solves this by allowing you to describe structure exactly where computation happens.</p><p>You read the model top to bottom.<br>You see data flow clearly.<br>You understand the model without jumping between methods.</p><p>It improves readability without sacrificing correctness.</p><h3>Step 6: What compact is not 
doing</h3><p>It is important to say what the compact API does not do.</p><p>It does not recreate parameters every forward pass.<br>It does not store state inside the module.<br>It does not make the model object oriented in a traditional sense.</p><p>Everything still follows the functional rules strictly.</p><h3>Common beginner mistakes</h3><p><strong>Mistake one:</strong> Thinking compact layers run differently<br>They do not.</p><p><strong>Mistake two:</strong> Mixing setup and compact styles in the same module<br>This leads to confusion.</p><p><strong>Mistake three:</strong> Forgetting that init and apply behave differently<br>Most bugs come from this misunderstanding.</p><p><strong>Mistake four:</strong> Assuming compact is a shortcut or hack<br>It is a deliberate design choice.</p><h3>Final takeaway</h3><p>The compact API is not magic.</p><p>It is a carefully designed permission system.</p><p>It allows parameters to be created during initialization.<br>It forbids parameter creation during execution.<br>It keeps the model functional at all times.</p><p>Once you understand that separation, the compact API stops feeling strange and starts feeling natural.</p><h3>What’s Next</h3><p>In the next post, we will look at how Flax tracks parameters across nested modules.</p><p>We will explore how names are assigned, how parameter trees are built, and how Flax ensures the right parameters are reused at the right time without relying on hidden state.</p><p>This will complete the picture of how Flax manages complexity while staying fully functional.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[How Flax Modules Work Internally: The Design Choice That Keeps JAX Functional]]></title>
            <link>https://medium.com/@AliPythonDev/how-flax-modules-work-internally-9e499890ed24?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/9e499890ed24</guid>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[tensorflow]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Thu, 25 Dec 2025 06:24:08 GMT</pubDate>
            <atom:updated>2025-12-27T09:40:25.440Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*M__SGXl8tx3aIBhvAm57LQ.png" /></figure><h3>Why this topic matters</h3><p>When people start using Flax with JAX, one thing feels strange very quickly.</p><p>You write a neural network class, but there is no state inside it.<br>There is no self.weights being updated.<br>There is no hidden magic happening.</p><p>Yet everything works.</p><p>This can feel confusing if you come from PyTorch or TensorFlow.</p><p>Understanding how Flax modules work internally removes this confusion.<br>It helps you debug better, write cleaner models, and stop treating Flax like black magic.</p><p>Let us break it down slowly.</p><h4>The big idea in one sentence</h4><p>A Flax module is just a <strong>pure function description</strong> that knows how to create parameters and how to use them, but it never owns them.</p><p>That sentence will make sense by the end.</p><h3>Step 1: Forget what a class usually means</h3><p>In most deep learning frameworks, a model class behaves like this:</p><p>The class stores weights<br>The class updates weights<br>The class remembers state</p><p>Flax does not work like that.</p><p><strong>In Flax:</strong></p><p>The module describes structure<br>The parameters live outside the module<br>The data flows in and out</p><p>Think of a Flax module as a <strong>recipe</strong>, not a container.</p><p>A recipe explains how to cook.<br>It does not store the ingredients.</p><h3>Step 2: What a Flax Module really is</h3><p>A Flax module answers two questions.</p><ol><li>What parameters do I need</li><li>How do I compute outputs using those parameters</li></ol><p><strong>That is all.</strong></p><p>Internally, a module has:<br>A setup phase where submodules are defined<br>A call function that describes computation</p><p>But no actual parameter values live inside it.</p><p>This is the key mental shift.</p><h3>Step 3: Where do parameters come from then?</h3><p>Parameters 
are created during <strong>initialization</strong>.</p><p>When you call something like:</p><p>model.init(random_key, input)</p><p>Flax does this internally:</p><ol><li>Walk through the module tree</li><li>Look at each layer</li><li>Ask each layer what parameters it needs</li><li>Create arrays using the random key</li><li>Store everything in a separate dictionary</li></ol><p>The result is a <strong>parameter tree</strong>, not a model object with weights.</p><p>You now have two things:</p><p>The module which is just structure<br>The parameters which are just data</p><p>They are separate on purpose.</p><h3>Step 4: What happens during a forward pass</h3><p>When you call:</p><p>model.apply(params, input)</p><p>Here is what really happens:</p><p>The module code runs like a normal function<br>Whenever it needs a weight, it looks it up in params<br>It performs computation<br>It returns output</p><p>Nothing is stored.<br>Nothing is mutated.</p><p>This is why Flax works so well with JAX transformations like jit and grad.</p><p>Pure functions are easy to optimize.</p><h3>Step 5: A simple mental model</h3><p>Imagine a calculator.</p><p>The calculator body has buttons and logic.<br>The numbers you type are inputs.</p><p>Now imagine the calculator does not remember any numbers after the operation.</p><p>Each time:</p><p>You give it numbers<br>It computes<br>It gives output<br>Then it forgets everything</p><p>That is how Flax modules behave.</p><p>The calculator is the module.<br>The numbers are parameters and inputs.</p><h3>Step 6: Why setup exists</h3><p>You may wonder why Flax has a setup function.</p><p>Setup is just a place to define structure.</p><p>For example:</p><p>This module contains two dense layers<br>This module contains a convolution and normalization</p><p>No weights are created here.<br>Only relationships are defined.</p><p>During init, Flax reads this structure and creates parameters accordingly.</p><h3>Step 7: Why this design is powerful</h3><p>This design 
gives you three big advantages.</p><p>First, clarity.<br>You always know where state lives.</p><p>Second, safety.<br>No accidental hidden mutations.</p><p>Third, compatibility.<br>JAX transformations work naturally.</p><p>Once you accept that modules do not own parameters, everything becomes simpler.</p><h3>Common beginner mistakes</h3><p><strong>Mistake 1:</strong> Expecting parameters inside self<br>They are never there.</p><p><strong>Mistake 2:</strong> Thinking init runs the model normally<br>Init is about discovering parameters, not training.</p><p><strong>Mistake 3:</strong> Treating apply as stateful<br>Apply is just a function call.</p><p><strong>Mistake 4:</strong> Trying to modify parameters inside the module<br>Always return new parameters instead.</p><h3>Final takeaway</h3><p>Flax modules are not containers.<br>They are blueprints.</p><p>They describe how to create parameters.<br>They describe how to use parameters.<br>But they never store them.</p><p>Once you see a Flax module as a pure function plus structure, the confusion disappears.</p><p>That mental model will save you hours of frustration as you go deeper into JAX and Flax.</p><h3>What’s Next</h3><p>In the next post, we will focus on the <a href="https://medium.com/@AliPythonDev/what-the-compact-api-in-flax-really-does-the-design-choice-that-keeps-parameter-creation-c9747388324e"><strong>Compact API in Flax</strong></a><strong>.</strong></p><p>We will explain what @compact means, why it exists, and how Flax allows parameter creation inside the forward pass without turning your model into a stateful object.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[How JAX Compiles Your Code: The Secret Relationship Between JAX and XLA]]></title>
            <link>https://medium.com/@AliPythonDev/how-jax-compiles-your-code-the-secret-relationship-between-jax-and-xla-77df4e50e444?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/77df4e50e444</guid>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[xla]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[tpu]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Sun, 07 Dec 2025 08:02:15 GMT</pubDate>
            <atom:updated>2025-12-25T06:29:30.904Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*irdmhjBf3KeWHTyjaa4v7A.png" /></figure><p>In the previous post, we explored why <a href="https://medium.com/@AliPythonDev/why-jax-thinks-differently-the-functional-mindset-behind-its-power-08bf947a9de8">JAX thinks differently</a>, why it separates parameters from computation, and how its functional mindset gives you more control than object-oriented frameworks. But there is still one major question left unanswered.</p><p>If JAX is just NumPy with magical transformations, how does it suddenly become so fast on GPUs and TPUs?</p><p>To understand that, we need to uncover how JAX compiles your Python code. And the moment you see this clearly, the entire design of JAX makes perfect sense.</p><p>Let’s break this down with simple analogies, clear intuition, and real code.</p><h3>The Big Analogy: JAX as an Architect, XLA as the Construction Team</h3><p>Imagine you design a building on paper. You draw the blueprint. 
That blueprint is your Python function.</p><p>But the blueprint alone cannot build a skyscraper.<br> You need a construction team that knows how to pour concrete, lift steel, and assemble everything using heavy machinery.</p><p>In JAX, the architect is JAX itself.<br> The construction team is XLA.</p><p>You write the instructions at a high level.<br> JAX analyzes your blueprint.<br> XLA takes that blueprint and builds an optimized version of it for your hardware.</p><p>This partnership is the reason JAX feels different from most other frameworks.</p><h3>What Actually Happens When You Use jit</h3><p>Let’s take a simple function.</p><pre>import jax<br>import jax.numpy as jnp<br><br>def compute(x):<br>    return jnp.sin(x) + jnp.cos(x)</pre><p>If you call the function normally, JAX executes it operation by operation.</p><p>But when you wrap it with jit:</p><pre>fast_compute = jax.jit(compute)</pre><p>JAX does something entirely different.<br> Nothing is compiled at this point.<br> The first time you call fast_compute, JAX does not execute your code operation by operation.<br> Instead, it begins <strong>tracing</strong> your function.</p><p>Tracing is like JAX walking through your function slowly, collecting the mathematical steps you wrote.<br> It does not compute real values.<br> It records the operations, using only the shape and type of the input.</p><p>Once JAX collects those steps, it hands them over to XLA, and XLA starts building.</p><p>It fuses operations together.<br> It rearranges them for parallelism.<br> It removes unnecessary steps.<br> It compiles everything into one optimized accelerator program.</p><p>So when you call:</p><pre>y = fast_compute(jnp.ones(1000000))</pre><p>The first call pays the one-time cost of tracing and compiling.<br> Every later call with the same input shape and type skips your Python code and runs the cached, optimized program on the CPU, GPU, or TPU.</p><h3>Why GPUs Love XLA’s Style of Execution</h3><p>GPUs do not like receiving small, separate tasks.<br> They want one big package of work to run in parallel.</p><p>In many frameworks, every operation becomes a separate GPU call. This causes overhead. 
It is like asking a construction team to build your house brick by brick instead of giving them full walls.</p><p>XLA avoids this problem entirely.<br> It merges your operations into a single fused kernel.</p><p>So instead of sending:<br>Compute sin<br>Compute cos<br>Add sin and cos</p><p>XLA sends:<br>Compute sin(x) + cos(x) in one fused operation<br>This fusion is the secret behind JAX’s speed.</p><h3>TPU Compilation: Why JAX Fits TPUs Naturally</h3><p>TPUs cannot interpret Python at all.<br> Everything must be compiled into a TPU-compatible program before execution.</p><p>This is where JAX and XLA shine.<br> Since JAX functions are pure and stateless, they are perfectly suited for compilation.</p><p>A JAX function that works on CPU will work on GPU.<br> The same function will work on TPU.<br> Nothing needs to be rewritten.</p><p>Let’s test this idea in code.</p><pre>def forward(x):<br>    return jnp.tanh(x * 3.0)<br>    <br>compiled = jax.jit(forward)</pre><p>Whether forward runs on CPU, GPU, or TPU depends only on the device. The function itself never changes.</p><p>This is why Google researchers often use JAX on TPUs. 
The compilation pipeline is stable, predictable, and incredibly efficient.</p><h3>Device Placement: Moving Work to a GPU or TPU</h3><p>To explicitly send work to a GPU, place the input on that device. The compiled function then runs where its input lives.</p><pre>gpu = jax.devices(&quot;gpu&quot;)[0]  # fails if no GPU backend is available<br><br>def func(x):<br>    return jnp.sqrt(x)<br><br>compiled = jax.jit(func)<br>x = jax.device_put(jnp.ones((1000, 1000)), gpu)  # commit the input to the GPU<br>y = compiled(x).block_until_ready()</pre><p>Even without device_put, a GPU-enabled installation runs the compiled version on the GPU.<br> JAX uses its default backend, which is a GPU or TPU when one is available and the CPU otherwise.</p><p>You can confirm the device by printing:</p><pre>print(y.devices())</pre><p>It will show something like:</p><p>{CudaDevice(id=0)}</p><p>The exact device name depends on your JAX version and hardware. Either way, this confirms that the function was compiled for and executed on the GPU.</p><h3>Why This Compilation Approach Makes JAX Unique</h3><p>Most frameworks interpret operations eagerly.<br> JAX compiles entire functions.<br> Most frameworks keep internal state hidden inside objects.<br> JAX keeps everything explicit and pure.<br> Most frameworks send many small ops to accelerators.<br> XLA fuses them into large kernels.</p><p>This is why JAX feels different.<br> This is why JAX feels fast.<br> This is why JAX scales to extremely large models with fewer surprises.</p><p>The functional mindset is not just a programming style.<br> It is the key that unlocks compilation.<br> And compilation is the key that unlocks speed.</p><h3>What’s Next</h3><p>In the next post, we will explore <a href="https://medium.com/@AliPythonDev/how-flax-modules-work-internally-9e499890ed24"><strong>how Flax modules work internally</strong></a>, why they look object-oriented even though JAX is functional, how the compact API actually works, and how Flax manages parameters under the hood without breaking the functional model.</p><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=77df4e50e444" width="1" height="1" alt="">]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Why JAX Thinks Differently: The Functional Mindset Behind its Power]]></title>
            <link>https://medium.com/@AliPythonDev/why-jax-thinks-differently-the-functional-mindset-behind-its-power-08bf947a9de8?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/08bf947a9de8</guid>
            <category><![CDATA[ml-engineering]]></category>
            <category><![CDATA[ai-research]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Fri, 05 Dec 2025 17:42:41 GMT</pubDate>
            <atom:updated>2025-12-05T17:52:49.026Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*5on1S_G2EAobxqJVvlJQoQ.png" /></figure><p>In the previous post, we explored how <a href="https://medium.com/@AliPythonDev/jax-vs-tensorflow-vs-pytorch-a-deeper-look-for-beginners-0b26c948f284">JAX compares to TensorFlow and PyTorch</a>, and why its speed-focused design feels so different. But one question still remains unanswered:</p><blockquote>Why does JAX feel so unusual when you first write code?</blockquote><p>The reason is simple. JAX follows a <strong>functional programming mindset</strong>, not an object-oriented one. Understanding this mindset changes everything. It explains why model parameters must be passed separately, why randomness requires explicit keys, and why JAX feels so clean once it “clicks.”</p><p>Let’s break down this mindset with simple analogies, clear intuition, and code.</p><h3>The Big Analogy: Blueprint vs Machine</h3><p>Think of building a machine in two different ways.</p><p>In the first style, you create one big machine object. It holds its internal parts, remembers its state, and changes itself as it works. This is how PyTorch models behave. The object contains its own weights and updates them.</p><p>In the second style, you keep the <strong>blueprint</strong> separate from the <strong>building materials</strong>. The blueprint never changes. Instead, you pass different sets of materials whenever you want to build something.</p><p>This is how JAX thinks.<br> The model is the blueprint.<br> The parameters are separate.<br> The function is pure.</p><p>When these two are separate, something powerful happens. JAX can analyze the function mathematically. It can optimize it. It can differentiate it. It can compile it. 
And it can parallelize it without worrying about hidden state.</p><p>This separation is the heart of the functional mindset.</p><h3>How This Looks in Real Code</h3><p>Let’s recreate a tiny model in PyTorch style and JAX style to see the difference.</p><p>PyTorch keeps everything inside one object.</p><pre>import torch<br>import torch.nn as nn<br><br>class Model(nn.Module):<br>    def __init__(self):<br>        super().__init__()<br>        self.linear = nn.Linear(10, 1)<br><br>    def forward(self, x):<br>        return self.linear(x)<br><br>model = Model()<br>x = torch.randn(5, 10)<br>output = model(x)</pre><p>The model holds its own weights and uses them internally.</p><p>Now look at JAX with Flax.</p><pre>import jax<br>import jax.numpy as jnp<br>from flax import linen as nn<br><br>class Model(nn.Module):<br>    @nn.compact<br>    def __call__(self, x):<br>        return nn.Dense(1)(x)<br><br>model = Model()<br><br>key = jax.random.PRNGKey(0)<br>dummy = jnp.ones((1, 10))<br>params = model.init(key, dummy)[&#39;params&#39;]<br><br>x = jnp.ones((5, 10))<br>output = model.apply({&#39;params&#39;: params}, x)</pre><p>JAX forces you to handle parameters explicitly. It may feel unfamiliar at first, but this gives you more control than any other framework.</p><h3>Why JAX Uses Explicit Random Keys</h3><p>Another thing beginners find strange is how JAX handles randomness. Instead of simply calling a random function, JAX requires you to pass an explicit PRNGKey every time.</p><p>It looks like this:</p><pre>key = jax.random.PRNGKey(42)<br>x = jax.random.normal(key, (3,))</pre><p>Why does JAX do this?</p><p>Let’s understand with an analogy.</p><p>Imagine you have a robot that generates random numbers. In most frameworks, you ask the robot for a random number and it uses some hidden state internally to give you one.</p><p>But in JAX, the robot does not keep any hidden state. You must hand it a “token” (the key) to use for randomness. 
The robot then follows two simple rules:</p><p>The same key always produces the same random number<br> To get a different number, you split the key and hand over a fresh one</p><p>No hidden state. No surprises.</p><p>This is what makes randomness reproducible and traceable, even inside jit, vmap, grad, and across multiple devices.</p><h3>Splitting Keys: The Part People Forget</h3><p>In JAX, using the same key twice does not raise an error by default, but it silently gives you the same random numbers again.<br> So you should always split your key before using it again.</p><p>Example:</p><pre>key = jax.random.PRNGKey(0)<br><br>key, subkey = jax.random.split(key)<br>a = jax.random.normal(subkey)<br><br>key, subkey = jax.random.split(key)<br>b = jax.random.uniform(subkey)</pre><p>Each split creates a fresh, statistically independent source of randomness.</p><p>This is why JAX code stays clean inside large distributed systems: all randomness is explicit, controlled, and reproducible.</p><h3>Why Functional Style + PRNG Makes JAX So Powerful</h3><p>Let’s create a simple example that uses both ideas together.</p><p>We will write one function that initializes random weights and another that processes data.</p><pre>def forward(params, x):<br>    w, b = params<br>    return w * x + b<br><br>def init_params(key):<br>    key_w, key_b = jax.random.split(key)<br>    w = jax.random.normal(key_w)<br>    b = jax.random.normal(key_b)<br>    return w, b<br><br>key = jax.random.PRNGKey(0)<br>params = init_params(key)<br><br>x = jnp.array(2.0)<br>output = forward(params, x)</pre><p>Here is what is happening.</p><p>The function forward is pure. 
It has no hidden state.<br> All randomness is explicit in init_params.<br> The parameters are separate and visible.</p><p>Now JAX can do amazing things with this code.</p><p>It can jit-compile it.<br> It can differentiate it.<br> It can vectorize it.<br> It can run it across multiple GPUs.</p><p>All because the function is pure and stateless.</p><p>No hidden state means complete freedom to optimize.</p><h3>How This Helps You in Real Projects</h3><p>This mindset makes JAX especially good for:</p><p>Research, where you need complete control<br> Physics simulations, where equations must be exact<br> Distributed training, where hidden state becomes messy<br> Writing your own custom layers or optimizers<br> Building very large models where structure matters</p><p>JAX may feel strict at first, but once you adapt, the clarity becomes addictive.<br> You always know where your parameters are.<br> You always know where your randomness comes from.<br> You always know how your function behaves.</p><p>There is no hidden magic. Everything is explicit.</p><p>And this explicitness is what allows JAX to be so powerful.</p><h3>What’s Next</h3><p>In the next post, I will explain <strong>How JAX Compiles Your Code</strong>, explore the secret relationship between JAX and XLA, and show how your Python functions are transformed into highly optimized accelerator programs.</p><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=08bf947a9de8" width="1" height="1" alt="">]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[JAX vs. TensorFlow vs. PyTorch: A Deeper Look for Beginners]]></title>
            <link>https://medium.com/@AliPythonDev/jax-vs-tensorflow-vs-pytorch-a-deeper-look-for-beginners-0b26c948f284?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/0b26c948f284</guid>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[python]]></category>
            <category><![CDATA[google-colab]]></category>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[machine-learning]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Sun, 02 Nov 2025 06:08:50 GMT</pubDate>
            <atom:updated>2025-11-02T06:08:50.439Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*aR2uACptJrfAXEgRcSPVLA.png" /></figure><p>In the previous post, we walked through “<a href="https://medium.com/@AliPythonDev/why-jax-exists-the-need-for-speed-and-simplicity-in-ml-8e3137928caa">Why JAX Exists? The Need for Speed and Simplicity in ML</a>”</p><p>Now that we know <em>why</em> JAX was created, it leads to the most important question for any beginner: “This is all great, but… <strong>which one should I actually learn?</strong>”</p><p>It’s a confusing choice. You see JAX, TensorFlow, and PyTorch everywhere. But here’s the simple truth: there is no single “best” framework. There is only the <strong>best framework for your specific goal.</strong></p><p>Let’s ditch the high-level talk and go in-depth to compare them with simple analogies and code.</p><h3>The Big Analogy: Choosing Your Vehicle</h3><p>Think of these three frameworks as different types of vehicles, each built for a different job.</p><p><strong>TensorFlow is the Industrial Truck:</strong></p><ul><li>It’s a heavy-duty, industrial powerhouse. It’s designed to handle massive, end-to-end projects.</li><li>It has a giant ecosystem to help you not just build a model, but also check your data, “serve” your model in the cloud, and even run it on tiny mobile phones (using TFLite) or in a web browser (using TF.js). It’s built for production and reliability.</li></ul><p><strong>PyTorch is the All-Terrain SUV:</strong></p><ul><li>This is the flexible, popular, all-around choice. It’s famous for being “Pythonic,” which means it feels natural and easy to use for Python programmers.</li><li>It’s fantastic for “off-roading” in research because it’s so easy to build and change new, creative ideas. 
It has a huge community and is now also very powerful in production.</li></ul><p><strong>JAX is the Formula 1 Engine:</strong></p><ul><li>JAX is a specialized machine built for one thing above all else: pure, mind-blowing speed.</li><li>It’s not a complete “vehicle” like the others. It’s more like a revolutionary engine. It doesn’t come with all the “comforts” like built-in data loaders or model-serving tools. It gives you the parts: a speed booster (jit), a slope-finder (grad), and a batch-processor (vmap).</li></ul><p>It’s up to you, the skilled driver, to build the car around it.</p><h3>How They “Feel”: A Simple Code Comparison</h3><p>The best way to understand the difference is to see how they “feel” when you write code. Let’s build the <em>same</em> simple model in all three.</p><h4>PyTorch: The “Pythonic” Object</h4><p>PyTorch is famous for being “object-oriented.” This means you build a model by creating a Python class. It feels very intuitive. You create a &quot;blueprint&quot; (the class), define the &quot;parts&quot; (the layers) in the __init__ setup, and then define how data flows through them in the forward function.</p><pre>import torch<br>import torch.nn as nn<br><br># 1. Define the model&#39;s &quot;blueprint&quot; as a class<br>class SimpleModel(nn.Module):<br>    def __init__(self):<br>        super().__init__()<br>        # 2. Define the parts (layers)<br>        self.layer1 = nn.Linear(in_features=10, out_features=50)<br>        self.activation = nn.ReLU()<br>        self.layer2 = nn.Linear(in_features=50, out_features=1)<br><br>    # 3. Define how data flows through the parts<br>    def forward(self, x):<br>        x = self.layer1(x)<br>        x = self.activation(x)<br>        x = self.layer2(x)<br>        return x<br><br># 4. Create the model and use it!<br>model = SimpleModel()<br>data = torch.randn(64, 10) # A batch of 64 samples<br>output = model(data)</pre><p><strong>The Feel:</strong> This is easy to read. 
The model is a &quot;thing&quot; that holds its own parts (layers) and its own internal state (the weights).</p><h4>TensorFlow: The “Lego Stack” (with Keras)</h4><p>TensorFlow’s high-level API, <strong>Keras</strong>, is famous for its simplicity.</p><p>For most models, you don’t even need to write a class. You can just stack layers together in a list, like building with Legos.</p><pre>import tensorflow as tf<br>from tensorflow import keras<br>from tensorflow.keras import layers<br><br># 1. Define the model by stacking layers in a list<br>model = keras.Sequential([<br>    layers.Input(shape=(10,)),<br>    layers.Dense(units=50),<br>    layers.Activation(&#39;relu&#39;),<br>    layers.Dense(units=1)<br>])<br><br># 2. Use the model!<br>data = tf.random.normal(shape=(64, 10))<br>output = model(data)</pre><p><strong>The Feel:</strong> This is incredibly fast and simple for building standard models. (For more complex models, you can also write it as a class, just like in PyTorch).</p><h4>JAX: The “Pure Math Function”</h4><p>JAX is different. It is <strong>“functional.”</strong> This is the most important concept to understand.</p><p>A JAX model (often built with a library like <strong>Flax</strong>) is not an “object” that holds its own weights. Instead, the model is just a “pure function,” like a math recipe. You must keep the <strong>model’s weights (parameters)</strong> separate from the <strong>model’s logic (the function)</strong>.</p><pre>import jax<br>import jax.numpy as jnp<br>from flax import linen as nn  # Flax is the popular neural net library for JAX<br><br># 1. Define the model&#39;s &quot;blueprint&quot; (looks similar to PyTorch)<br>class SimpleModel(nn.Module):<br>    @nn.compact<br>    def __call__(self, x):<br>        x = nn.Dense(features=50)(x)<br>        x = nn.relu(x)<br>        x = nn.Dense(features=1)(x)<br>        return x<br><br># THIS is where it gets different!<br>model = SimpleModel()<br><br># 2. 
You MUST initialize the model to get the parameters (weights)<br>key = jax.random.PRNGKey(0)  # JAX needs an explicit &quot;key&quot; for randomness<br>init_key, data_key = jax.random.split(key)  # one key per use, never reuse a key<br>dummy_data = jnp.ones(shape=(10,))<br>params = model.init(init_key, dummy_data)[&#39;params&#39;]<br><br># 3. To run the model, you pass the `params` AND the `data`<br>data = jax.random.normal(data_key, shape=(64, 10))<br>output = model.apply({&#39;params&#39;: params}, data)</pre><p><strong>The Feel:</strong> This seems like extra work, right? We had to call model.init to get the params, then pass the params back in with model.apply. But this &quot;functional&quot; design is the secret to JAX&#39;s power. Because the function and its weights are separate, JAX can easily jit (speed up), grad (get the slope of), and vmap (batch) it.</p><h3>In-Depth Breakdown: What Matters for a Beginner</h3><p>Let’s go deeper than just code. What do these differences mean for you?</p><h4>Community &amp; Ecosystem</h4><p>This is maybe the most important factor for a beginner.</p><p><strong>PyTorch:</strong> Has the <strong>strongest research community</strong>. A large share of new AI research code is released in PyTorch. The community is massive, the tutorials are excellent, and libraries like <strong>Hugging Face Transformers</strong> (the king of NLP) are PyTorch-first.</p><p><strong>TensorFlow:</strong> Has the <strong>strongest production ecosystem</strong>. Its community is huge and backed by Google. You will find more tools for the <em>full</em> lifecycle: deploying to a server (TensorFlow Serving), running on a phone (TFLite), and monitoring in production (TFX).</p><p><strong>JAX:</strong> Has a <strong>smaller, more advanced community.</strong> You’ll find brilliant people and amazing tools (like Flax, Optax, and Haiku), but you are expected to be more independent. 
You are the “mechanic” who has to find and assemble the parts (a library for data, a library for optimization, etc.).</p><h4>Debugging (Fixing Your Errors)</h4><p>This is where beginners get the most frustrated.</p><p><strong>PyTorch:</strong> <strong>Easiest to debug.</strong> It runs in “eager mode,” which means it runs just like your normal Python script, one line at a time. If there is an error on line 10, it stops on line 10 and tells you. You can print values anywhere to see what’s happening.</p><p><strong>TensorFlow:</strong> <strong>Also very good.</strong> With Keras and Eager Execution, it’s just as easy to debug as PyTorch for most tasks.</p><p><strong>JAX:</strong> <strong>The hardest to debug.</strong> This is the trade-off for its speed. When you use jax.jit, JAX &quot;compiles&quot; your function into a super-fast version. If there&#39;s an error, it might happen <em>inside</em> that compiled code. The error messages can be very long and confusing until you get used to them.</p><h3>My Advice: Which One Should You Learn?</h3><p>Here is my simple, direct advice based on your goals.</p><h4>Learn PyTorch if…</h4><ul><li>You are an <strong>absolute beginner</strong> and this is your first ML framework.</li><li>You want the <strong>easiest learning curve</strong> and the most tutorials.</li><li>You are interested in <strong>Natural Language Processing (NLP)</strong> or using the latest models from Hugging Face.</li><li>You want to read and understand <strong>new research papers</strong>.</li></ul><blockquote><strong>Bottom Line: </strong>For most beginners today, PyTorch is the best and safest place to start.</blockquote><h4>Learn TensorFlow if…</h4><ul><li>Your main goal is <strong>getting a model into a real-world product.</strong></li><li>You are specifically interested in deploying models on <strong>mobile phones (iOS/Android)</strong> or <strong>in a web browser.</strong></li><li>You work at a company that already has a large, established TensorFlow 
codebase.</li><li>You want an “all-in-one” toolkit that handles everything from data validation to deployment.</li></ul><blockquote><strong>Bottom Line:</strong> Learn TensorFlow if your job or project is focused on production and deployment.</blockquote><h4>Learn JAX if…</h4><ul><li>You are <strong>NOT a beginner.</strong> You should already be comfortable with ML concepts and <em>either</em> PyTorch or TensorFlow.</li><li>You have a <strong>strong love for math</strong> and want to build things from scratch.</li><li>Your number one need is <strong>raw performance</strong> (e.g., training huge models on TPUs).</li><li>You are a researcher in physics, biology, or advanced AI and need to write custom, high-speed math that isn’t just a standard model.</li></ul><blockquote><strong>Bottom Line:</strong> Learn JAX <em>after</em> you are comfortable with another framework and you find yourself needing more speed and flexibility than it can offer.</blockquote><p><strong>What’s Next:</strong></p><p>Up next, I will guide you step by step on how to install and configure JAX on Google Colab and on your own computer. I will share the exact commands, explain CPU and GPU builds, and help you avoid the common setup issues people run into.</p><p>By the end, you will be ready to run JAX code smoothly and start experimenting.</p><img src="https://medium.com/_/stat?event=post.clientViewed&referrerSource=full_rss&postId=0b26c948f284" width="1" height="1" alt="">]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Why JAX Exists? The Need for Speed and Simplicity in ML]]></title>
            <link>https://medium.com/@AliPythonDev/why-jax-exists-the-need-for-speed-and-simplicity-in-ml-8e3137928caa?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/8e3137928caa</guid>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[python]]></category>
            <category><![CDATA[jax]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Sat, 01 Nov 2025 06:14:34 GMT</pubDate>
            <atom:updated>2025-11-01T06:15:33.761Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*1vssWGMMGHizNnFW7Ho8rQ.png" /></figure><p>In the previous post, we walked through “<a href="https://medium.com/@AliPythonDev/your-first-guide-to-jax-numpy-with-superpowers-0878757d0ab0">Your First Guide to JAX: NumPy with Superpowers</a>” and got comfortable with JAX basics.</p><p>This usually leads to a great question: <strong>“But… why? Don’t we already have TensorFlow and PyTorch?”</strong></p><p>It’s a smart question. <strong>TensorFlow</strong> and <strong>PyTorch</strong> are fantastic, powerful, and used by millions. So why did <strong>Google</strong>, the creator of <strong>TensorFlow</strong>, build <em>another</em> library?</p><p>The answer is that JAX was built to solve a very specific problem that researchers were facing. It fills a gap that existed between <strong>easy, familiar code</strong> (like NumPy) and <strong>high-performance ML</strong> (like TensorFlow).</p><p>To understand this, let’s imagine the world of data science is split into two “kingdoms.”</p><h3>Kingdom 1: The “NumPy” Kingdom (Easy &amp; Familiar)</h3><p>This is the world of traditional science and data analysis.</p><p><strong>Who lives here?</strong> Data scientists, financial analysts, physicists, biologists, and <em>most</em> Python programmers.</p><p><strong>What’s it like?</strong> It’s simple and beautiful! The main tool is <strong>NumPy</strong>. You write simple, clean Python code. 
You have an idea, you write a function, and it just works.</p><p><strong>Why people love it:</strong></p><ul><li>clean and readable code</li><li>fast to prototype ideas</li><li>easy to understand and teach</li><li>feels like regular Python</li></ul><pre>import numpy as np<br><br>def square(x):<br>    return x * x<br><br>print(square(5))</pre><p><strong>The Problem:</strong> This kingdom is not built for modern, large-scale AI.</p><ol><li><strong>It’s “Slow”:</strong> NumPy code runs on your CPU. It doesn’t know how to use those powerful, expensive <strong>GPUs</strong> or <strong>TPUs</strong> that make machine learning possible.</li><li><strong>No “Slope Finder”:</strong> It’s missing the most important tool for machine learning: <strong>automatic differentiation</strong> (grad). You can&#39;t &quot;train&quot; a NumPy function.</li></ol><p>This makes it hard to take a mathematical idea and turn it into a scalable ML experiment.</p><h3>Kingdom 2: The “Framework” Kingdom (Fast &amp; Powerful)</h3><p>This is the world of deep learning, ruled by <strong>TensorFlow</strong> and <strong>PyTorch</strong>.</p><p><strong>Who lives here?</strong> Machine Learning Engineers and AI Researchers.</p><p><strong>What’s it like?</strong> It’s incredibly powerful. These frameworks are built from the ground up to be fast. 
They can run on massive clusters of GPUs and TPUs, and they have excellent automatic differentiation.</p><p><strong>PyTorch Example:</strong></p><pre>import torch<br><br>x = torch.tensor(5.0, requires_grad=True)<br>y = x * x  # y = x^2<br>y.backward()  # compute gradient<br>print(x.grad)  # prints 10</pre><p><strong>TensorFlow Example:</strong></p><pre>import tensorflow as tf<br><br>x = tf.Variable(5.0)<br><br>with tf.GradientTape() as tape:<br>    y = x * x  # y = x^2<br><br>dy_dx = tape.gradient(y, x)<br>print(dy_dx.numpy())  # prints 10</pre><p><strong>The Problem:</strong> To live here, you have to learn a whole new set of rules.</p><ol><li><strong>It’s “Heavy”:</strong> You have to learn about Tensors, Sessions (in older TF), Model.compile(), layers, and model.fit().</li><li><strong>It’s “Rigid”:</strong> These frameworks are built to do one job very well: train deep learning models. What if you want to do something <em>weird</em> or new? What if you want to find the gradient of a physics simulation that has complex loops and logic? It can be difficult.</li></ol><h3>The Big Problem: Researchers Were Stuck</h3><p>For years, AI researchers felt stuck between these two kingdoms.</p><ul><li>They loved the <strong>simplicity and flexibility of NumPy</strong> (Kingdom 1).</li><li>But they needed the <strong>speed and </strong><strong>grad function of a framework</strong> (Kingdom 2).</li></ul><p>They found themselves spending more time fighting with a heavy framework than testing their new, creative AI ideas. They had a simple wish:</p><blockquote><em>“I just want to write my simple Python/NumPy code, and I want it to run fast on a GPU and give me gradients.”</em></blockquote><h4>JAX is the Bridge Between the Kingdoms</h4><p>This is exactly why JAX was created.</p><p>JAX is not a new, heavy framework. <strong>JAX is a set of “superpowers” for the NumPy kingdom.</strong></p><p>It was designed to give researchers the best of both worlds. 
Here’s how it solves the two biggest needs: speed and simplicity.</p><h4>1. The Need for SPEED (Solved by jit)</h4><p>Python is slow because the interpreter executes your code one operation at a time. A for loop that runs 1,000,000 times means Python makes 1,000,000 decisions.</p><p>A GPU is fast because it’s a “parallel” processor. It’s like having 10,000 workers ready to go. You can’t give them one instruction at a time; you have to give them the <em>whole plan</em> at once.</p><p><strong>How JAX solves this:</strong> jax.jit (Just-in-Time compilation).</p><p><strong>Analogy:</strong></p><ul><li><strong>Normal Python:</strong> A nervous home cook reading a recipe one line at a time. “1. Get a bowl.” (walks to cupboard). “2. Get a spoon.” (walks to drawer). Very slow.</li><li><strong>@jax.jit:</strong> A master chef in a giant restaurant kitchen. They read the <em>entire</em> recipe <em>once</em>, optimize it for their 10,000-worker (GPU) kitchen, and shout &quot;Go!&quot; The entire meal is prepared in seconds.</li></ul><p>jax.jit uses a powerful compiler called <strong>XLA</strong> (Accelerated Linear Algebra) to translate your simple Python function into fast machine code that runs efficiently on CPUs, GPUs, or TPUs.</p><h4>2. The Need for SIMPLICITY (Solved by grad and vmap)</h4><p>Researchers don’t just want to build standard models. They want to get creative. They want to “compose” functions.</p><p><strong>How JAX solves this:</strong> JAX superpowers are <em>composable</em>. They are just Python functions that transform other Python functions.</p><p>This is the most beautiful part of JAX. Want to get the “slope” of a function? Just wrap it in grad.</p><pre>import jax<br><br>def my_function(x):<br>    return x**2<br><br>find_slope = jax.grad(my_function)</pre><p>Want to make that “slope” function run really, really fast on a GPU? Just wrap it in jit!</p><pre># Let&#39;s stack our superpowers!<br>fast_slope_finder = jax.jit(jax.grad(my_function))</pre><p>This is <em>revolutionary</em>. 
You are not building a “model.” You are not fighting a framework. You are just stacking Lego bricks. You can jit a grad, vmap a jit, grad a vmap... whatever you want.</p><p>This “composability” gives researchers the ultimate simplicity and flexibility to build anything they can imagine.</p><h4>Key JAX Features</h4><ol><li><strong>jit for Speed:</strong></li></ol><pre>import jax<br>import jax.numpy as jnp<br><br>def square(x):<br>    return x * x<br><br>fast_square = jax.jit(square)<br>print(fast_square(5))</pre><p><strong>2. grad for Differentiation:</strong></p><pre>def f(x):<br>    return x**2<br><br>df = jax.grad(f)<br>print(df(3.0))  # 6.0</pre><p><strong>3. vmap for Batch Operations:</strong></p><pre>values = jnp.array([1, 2, 3, 4])<br><br>vmap_f = jax.vmap(f)<br>print(vmap_f(values))</pre><p><strong>Composing Transformations:</strong></p><pre>fast_grad = jax.jit(jax.grad(f))<br>print(fast_grad(5.0))</pre><p>Simple functions → compiled + differentiable + vectorized</p><p>No new syntax. No framework overhead.</p><h3>So, Why Does JAX Exist?</h3><p>JAX exists to make AI research <strong>fast, simple, and flexible.</strong></p><p>It was built for researchers who wanted to feel like they were writing clean, simple <strong>NumPy</strong>, but have it execute with the blinding <strong>speed of TensorFlow</strong> on a GPU.</p><p>It gives you the power to write <em>any</em> math function you can think of and instantly be able to:</p><ul><li><strong>jit it:</strong> Make it run crazy fast.</li><li><strong>grad it:</strong> Find its slope (train it).</li><li><strong>vmap it:</strong> Run it on a huge batch of data.</li></ul><p>For beginners, this makes JAX a fantastic tool. 
It’s not a “black box.” It lets you build from the ground up and <em>really</em> understand the math, all while being as simple and familiar as the NumPy you already know.</p><blockquote>In short: <strong>Write math naturally, run it fast, scale it easily.</strong></blockquote><h3>What’s Next</h3><p>In my next post, I’ll put these tools side by side: <strong>“JAX vs. TensorFlow vs. PyTorch: When to Use Each One.”</strong></p><p>I’ll break down where each framework shines, from fast research prototyping to production-level deployment, and explain when JAX’s simplicity and performance make it the better choice over traditional deep learning libraries.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[Your First Guide to JAX: NumPy with Superpowers]]></title>
            <link>https://medium.com/@AliPythonDev/your-first-guide-to-jax-numpy-with-superpowers-0878757d0ab0?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/0878757d0ab0</guid>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[python]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[numpy]]></category>
            <category><![CDATA[deep-learning]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Sun, 26 Oct 2025 16:57:32 GMT</pubDate>
            <atom:updated>2025-10-26T16:57:32.428Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*udpPITKtaLF9DSnoom0MYQ.png" /></figure><p>Have you ever used NumPy? It’s that amazing Python library everyone uses for working with lists or “arrays” of numbers. It’s fast, powerful, and the foundation of data science.</p><p>Now, what if you could take your NumPy code, make it run <strong>hundreds of times faster</strong> on a GPU or TPU, and give it new “superpowers” for machine learning, all without having to learn a giant new framework?</p><p><strong>That is exactly what JAX is.</strong></p><p>JAX is a high-performance library from Google for math and machine learning. But don’t let “high-performance” scare you. At its heart, JAX is so simple you already know how to use it.</p><p>In short: <strong>JAX is NumPy on steroids.</strong></p><h3>The Secret: JAX is Just NumPy</h3><p>To get started with JAX, you don’t need to learn a new API. You just change your import statement.</p><p>Instead of this:</p><pre>import numpy as np</pre><p>You write this:</p><pre>import jax.numpy as jnp</pre><p>That’s it. You can now use jnp for all your familiar NumPy functions (jnp.array(), jnp.dot(), jnp.sum(), etc.), but JAX will be running things behind the scenes, making it possible to...</p><ol><li>Run your code on powerful <strong>GPUs and TPUs</strong>.</li><li>Use three incredible “superpowers” that are perfect for modern AI.</li></ol><p>Let’s meet those superpowers.</p><h3>Superpower 1: jax.jit (The Speed Booster)</h3><p>jit stands for <strong>J</strong>ust-<strong>I</strong>n-<strong>T</strong>ime compilation.</p><p><strong>What it is:</strong> A “decorator” that you put on top of your Python function to make it incredibly fast.</p><p><strong>The Analogy:</strong></p><ul><li>Imagine a chef following a recipe (your function) one line at a time. <br>“1. Get a bowl.” (walks to cupboard). “2. Get a spoon.” (walks to drawer). “3. Get flour.” (walks to pantry). 
This is slow.</li><li>@jit is like a <em>smart</em> chef. They read the <em>entire</em> recipe first, optimize it (&quot;I&#39;ll get the bowl, spoon, and flour all in one trip&quot;), and then execute the whole plan as fast as possible.</li></ul><h4>How to Use jit:</h4><p>Let’s say you have a simple function that does some math on a big array.</p><pre>import time<br><br>import jax<br>import jax.numpy as jnp<br><br># This is our normal Python function<br>def slow_calculation(x):<br>    # A bunch of math steps<br>    y = jnp.sin(x)<br>    z = jnp.cos(y)<br>    return jnp.sum(z)<br><br># Now, let&#39;s create the &quot;fast&quot; version using jit<br>fast_calculation = jax.jit(slow_calculation)<br><br># You can also just add @jax.jit above the function definition:<br># @jax.jit<br># def fast_calculation(x):<br>#     ...<br><br># Let&#39;s create some data<br>big_array = jnp.arange(1_000_000)<br><br># Run it once to &quot;compile&quot; the function<br>fast_calculation(big_array).block_until_ready()<br><br># Now, let&#39;s time it!<br># JAX runs work asynchronously, so block_until_ready() waits for the result.<br>start = time.perf_counter()<br>fast_calculation(big_array).block_until_ready()<br>print(f&quot;Compiled call: {time.perf_counter() - start:.6f} s&quot;)<br># The speedup depends on the function and the hardware;<br># on a GPU it can be large (100x+ in favorable cases).</pre><p>By adding jax.jit, you&#39;ve told JAX to compile this function, the first time it is called, into efficient machine code for your CPU, GPU, or TPU.</p><h3>Superpower 2: jax.grad (The AI &quot;Guide&quot;)</h3><p>grad stands for <strong>gradient</strong>, which is a core concept in AI.</p><p><strong>What it is:</strong> A function that automatically calculates the “slope” or “derivative” of another function. This is the magic behind how most modern AI models learn.</p><p><strong>The Analogy:</strong></p><ul><li>Imagine you are on a foggy mountain and you want to get to the bottom of the valley (the “lowest error”). You can’t see, but you can feel the slope of the ground under your feet.</li><li>To get down, you just feel which way is “downhill” and take a step. 
You repeat this until you reach the bottom.</li><li>jax.grad is a magic compass that instantly tells you <strong>exactly how steep the ground is and which way is &quot;uphill&quot;</strong> from any point in your function. To go downhill, you step the opposite way.</li></ul><h4>How to Use grad:</h4><p>In AI, we call this slope the <strong>gradient</strong>. It points uphill, so the “downhill” direction is the negative of the gradient. Let’s see how grad finds it.</p><p>Let’s use a simple “valley” function: y = x². We all know the bottom of this valley is at x = 0.</p><pre>import jax<br>import jax.numpy as jnp<br><br># 1. Define our &quot;valley&quot; function<br>def my_function(x):<br>    return x**2<br><br># 2. Create the &quot;guide&quot; function using jax.grad<br>find_slope = jax.grad(my_function)<br><br># 3. Ask the guide for directions at different points<br>slope_at_x_10 = find_slope(10.0)<br>print(f&quot;The slope at x=10 is: {slope_at_x_10}&quot;)<br># Output: 20.0 (positive: uphill is to the right, so step left)<br><br>slope_at_x_neg_5 = find_slope(-5.0)<br>print(f&quot;The slope at x=-5 is: {slope_at_x_neg_5}&quot;)<br># Output: -10.0 (negative: uphill is to the left, so step right)<br><br>slope_at_x_0 = find_slope(0.0)<br>print(f&quot;The slope at x=0 is: {slope_at_x_0}&quot;)<br># Output: 0.0 (we are at the bottom!)</pre><p>jax.grad automatically figured out the math to find the slope. This is the <strong>basis of how most machine learning models are trained</strong>, and JAX does it for you for any function built from differentiable JAX operations.</p><h3>Superpower 3: jax.vmap (The Batch Processor)</h3><p>vmap stands for <strong>v</strong>ectorizing <strong>map</strong>.</p><p><strong>What it is:</strong> A tool that automatically “batches” your code. It lets you run a function designed for <em>one</em> item on <em>thousands</em> of items at the same time, in parallel.</p><p><strong>The Analogy:</strong></p><ul><li>Imagine you have a function that “squares one number.” Now you have a list of 1,000 numbers. 
The slow way is a for loop: &quot;take number, square it, take next number, square it...&quot; 1,000 times.</li><li>jax.vmap is like building a <strong>giant stamp</strong> that squares all 1,000 numbers in a single &quot;press.&quot; It automatically rewrites your function to handle the whole batch at once, efficiently.</li></ul><h4>How to Use vmap</h4><pre>import jax<br>import jax.numpy as jnp<br><br># 1. A simple function that works on ONE number<br>def square(x):<br>    return x * x<br><br># 2. A batch of numbers<br>numbers = jnp.array([1, 2, 3, 4, 5])<br><br># The SLOW way (a Python loop)<br># results = []<br># for x in numbers:<br>#     results.append(square(x))<br><br># 3. The FAST way (using vmap)<br># Create a new function that knows how to handle batches<br>batched_square = jax.vmap(square)<br><br># Run it on all numbers at once<br>results = batched_square(numbers)<br><br>print(results)<br># Output: [ 1  4  9 16 25]</pre><p>This is essential for AI, where you are always processing batches of data (e.g., 64 images at a time, not just one).</p><h3>Putting It All Together</h3><p>Let’s use our new superpowers to solve a real, simple ML problem: <strong>finding the best-fit line</strong> for some data.</p><p>We have some x and y data, and we want to find the best slope (w) and intercept (b) for the line y = w*x + b.</p><pre>import jax<br>import jax.numpy as jnp<br><br># 1. Our Data<br># Let&#39;s try to find the line for y = 3*x + 2<br>true_w = 3.0<br>true_b = 2.0<br>x_data = jnp.array([1.0, 2.0, 3.0, 4.0])<br>y_data = x_data * true_w + true_b<br>print(f&quot;Our data: x={x_data}, y={y_data}&quot;)<br><br><br># 2. 
Our Model and &quot;Loss&quot;<br># Our model&#39;s prediction<br>def predict(params, x):<br>    return params[&#39;w&#39;] * x + params[&#39;b&#39;]<br><br># A &quot;loss&quot; function tells us how &quot;wrong&quot; our model is.<br># We want to make this number as small as possible.<br>def loss_function(params, x, y):<br>    prediction = predict(params, x)<br>    error = prediction - y<br>    return jnp.mean(error**2) # This is our &quot;foggy valley&quot;<br><br># 3. Use Our Superpowers!<br><br># Use jax.grad to create the &quot;guide&quot; that finds the slope of the loss<br># (By default, grad differentiates with respect to the first argument, params)<br>find_loss_gradient = jax.grad(loss_function)<br><br># Use jax.jit to make our guide super fast!<br>@jax.jit<br>def update_step(params, x, y, learning_rate):<br>    # Get the slope (gradient) from our guide; it points uphill<br>    gradients = find_loss_gradient(params, x, y)<br>    <br>    # Take a small step downhill (against the gradient)<br>    new_params = {<br>        &#39;w&#39;: params[&#39;w&#39;] - learning_rate * gradients[&#39;w&#39;],<br>        &#39;b&#39;: params[&#39;b&#39;] - learning_rate * gradients[&#39;b&#39;]<br>    }<br>    return new_params<br><br># 4. &quot;Train&quot; the Model<br># Let&#39;s start with a bad guess<br>params = {&#39;w&#39;: 0.0, &#39;b&#39;: 0.0}<br>learning_rate = 0.01<br><br>print(f&quot;Starting guess: w={params[&#39;w&#39;]}, b={params[&#39;b&#39;]}&quot;)<br><br># Let&#39;s take many small steps &quot;downhill&quot;<br># (with this learning rate the intercept converges slowly,<br># so we use a few thousand steps)<br>for _ in range(5000):<br>    params = update_step(params, x_data, y_data, learning_rate)<br><br>print(f&quot;Final trained guess: w={params[&#39;w&#39;]:.2f}, b={params[&#39;b&#39;]:.2f}&quot;)</pre><p><strong>Output:</strong></p><pre>Our data: x=[1. 2. 3. 4.], y=[ 5.  8. 11. 14.]<br>Starting guess: w=0.0, b=0.0<br>Final trained guess: w=3.00, b=2.00</pre><p>Look at that! 
By using jax.grad to find the &quot;downhill&quot; direction and jax.jit to make it fast, we &quot;trained&quot; a model that recovers the slope and intercept we started from.</p><h3>Why Should a Beginner Care About JAX?</h3><ul><li><strong>It’s Easy to Start:</strong> If you know NumPy, you already know most of jax.numpy. There’s no big new API to learn.</li><li><strong>It’s Powerful:</strong> It gives you the three most important tools for modern AI (jit, grad, vmap) for free.</li><li><strong>It’s Fast:</strong> It lets your code use the full power of your GPU/TPU, which is essential for any serious machine learning.</li><li><strong>It’s the Future:</strong> JAX is used by top research labs like Google DeepMind to build the world’s most advanced AI models.</li></ul><p>Learning JAX is a fantastic way to understand how AI <em>really</em> works under the hood, and it gives you a skill that many employers are looking for.</p><h3>How to Install</h3><p>You can try it right now in a Google Colab notebook.</p><pre># Install JAX<br>pip install jax jaxlib</pre><p>Start by importing jax.numpy as jnp and try running your old NumPy code. Most of it works unchanged, though JAX arrays are immutable, so an in-place update such as x[0] = 1 needs x = x.at[0].set(1) instead. Then, try adding @jax.jit to a function and see the magic for yourself!</p><h3>What’s Next</h3><p>In the next blog, <strong>“Your First Guide to JAX: NumPy with Superpowers,”</strong> we’ll move from understanding <em>why</em> JAX exists to actually using it.</p><p>You’ll learn how to write your first lines of JAX code, explore how it mirrors NumPy, and discover how JAX’s automatic differentiation and just-in-time compilation make machine learning workflows faster and more efficient.</p>]]></content:encoded>
        </item>
        <item>
            <title><![CDATA[What Is JAX? A Friendly Introduction for ML Beginners]]></title>
            <link>https://medium.com/@AliPythonDev/what-is-jax-a-friendly-introduction-for-ml-beginners-fb2c6d0bcbb8?source=rss-72f6eb2992e8------2</link>
            <guid isPermaLink="false">https://medium.com/p/fb2c6d0bcbb8</guid>
            <category><![CDATA[jax]]></category>
            <category><![CDATA[machine-learning]]></category>
            <category><![CDATA[google-ai]]></category>
            <category><![CDATA[deep-learning]]></category>
            <category><![CDATA[python]]></category>
            <dc:creator><![CDATA[Ali Nawaz]]></dc:creator>
            <pubDate>Fri, 10 Oct 2025 17:49:21 GMT</pubDate>
            <atom:updated>2025-10-26T16:53:52.885Z</atom:updated>
            <content:encoded><![CDATA[<figure><img alt="" src="https://cdn-images-1.medium.com/max/1024/1*KvhxcOskAEsqTAKuddX4iA.png" /></figure><p>Have you ever used <strong>TensorFlow</strong> or <strong>PyTorch</strong> and wished your code ran much faster, especially when dealing with massive datasets or complex new AI models?</p><p>The world of Artificial Intelligence is always moving forward, and right now, many top researchers and companies are shifting towards a powerful, yet surprisingly simple tool called <strong>JAX</strong>.</p><p><strong>JAX</strong> is a newer, open-source library developed by <strong>Google</strong> and is quickly becoming the favorite for advanced AI development. Think of it as <strong>Python with Superpowers</strong> for deep learning.</p><p>Just as <strong>TensorFlow</strong> helped make AI development easier for everyone, JAX takes it a step further by bringing unmatched speed, flexibility, and power, allowing your code to run seamlessly on CPUs, GPUs, and even TPUs without changing a single line.</p><p>In <strong>Pakistan</strong>, where innovation in fintech, research, and remote work is booming, mastering a cutting-edge tool like <strong>JAX</strong> can instantly make you stand out.</p><h4>What Exactly is JAX?</h4><p>JAX is a <strong>numerical computing library</strong> that dramatically accelerates your math code, primarily by offering two key capabilities:</p><ol><li><strong>Automatic Differentiation (AutoDiff):</strong> It can instantly and accurately calculate the slope (gradient) of any function you write. This is the secret sauce behind all AI learning.</li><li><strong>XLA (Accelerated Linear Algebra):</strong> It automatically compiles your Python code to run efficiently on <strong>GPUs</strong> (like those used for gaming) and <strong>TPUs</strong> (Google’s custom AI chips).</li></ol><p>Imagine you have a very skilled tailor in Anarkali Bazaar, Lahore. 
He makes perfect dresses but works slowly, one stitch at a time.</p><p>Now imagine that tailor is given a smart, automated laser machine that follows the same cutting plan but finishes the work in seconds.</p><p>That is exactly what JAX does for your Python code.</p><p>It takes your existing NumPy code and runs it much faster, using advanced technology under the hood.</p><p>In short:</p><ul><li>You still write the same familiar code.</li><li>It runs faster on your machine.</li><li>It can even calculate complex gradients automatically, which is a must-have for machine learning.</li></ul><h3>The Three Core Features of JAX</h3><p>Once you understand these three main tools, you understand JAX.</p><h4>1. jax.grad — The Automatic Differentiator</h4><p>In machine learning, models improve by learning from their mistakes. They do this by finding how much change is needed in each step — something called the <strong>gradient</strong>.</p><p>jax.grad helps you calculate this automatically without writing the math yourself.</p><pre>import jax<br>import jax.numpy as jnp<br><br>def simple_loss(x):<br>  return x**2 + 5 * x + 3<br><br>slope_fn = jax.grad(simple_loss)<br><br>x_value = 2.0<br>slope_at_2 = slope_fn(x_value)<br><br>print(f&quot;Value at x={x_value}: {simple_loss(x_value):.1f}&quot;)<br>print(f&quot;Slope (gradient) at x={x_value}: {slope_at_2:.1f}&quot;)</pre><p>Now the model instantly knows which direction to move to reduce the error.</p><h4>2. jax.jit — The Speed Booster</h4><p>Python is simple but sometimes slow, especially for big models.</p><p>That’s where jax.jit (Just-In-Time compilation) helps. 
It takes your normal Python function, optimizes it, and runs it as fast machine code, often on a GPU or TPU.</p><p>You write your code once, add one small line, and repeated calls can run much faster.</p><pre>from jax import jit<br>import jax.numpy as jnp<br><br>@jit<br>def multiply_and_sum(x):<br>    return jnp.sum(x * 2)<br><br># Use float32 here: with JAX&#39;s default int32, this sum would overflow<br>data = jnp.arange(1_000_000, dtype=jnp.float32)<br>print(multiply_and_sum(data))</pre><p>The first call is slower because JAX compiles the function. After that, calls reuse the compiled code, which can be much faster than running the same steps one by one, especially on a GPU or TPU.</p><h4>3. jax.vmap — The Batch Processor</h4><p>In many AI tasks, you need to repeat the same operation on many samples, for example analyzing a set of images or sensor readings.</p><p>Normally, this requires writing loops, but with jax.vmap, you can apply a function to an entire batch automatically.</p><pre>import jax<br>import jax.numpy as jnp<br><br>def square(x):<br>    return x ** 2<br><br># Apply the function to all elements at once<br>batch_square = jax.vmap(square)<br>numbers = jnp.array([1, 2, 3, 4, 5])<br>print(batch_square(numbers))</pre><p>JAX rewrites the function to operate on the whole batch as a single vectorized computation instead of a Python loop.</p><h4>A Simple Real-World Example: Karachi Climate Analysis</h4><p>Let’s imagine you’re studying temperature changes in Karachi. 
You have daily readings and want to see which days caused the biggest overall change.</p><p>With JAX, you can do this analysis quickly and efficiently.</p><pre>import jax<br>import jax.numpy as jnp<br><br>historic_temps = jnp.array([28.1, 28.5, 29.2, 30.1, 30.0, 29.8, 30.5])<br><br>def calculate_temp_change_metric(temp_array, target_temp=28.0):<br>    return jnp.sum((temp_array - target_temp)**2)<br><br>metric_gradient_fn = jax.grad(calculate_temp_change_metric)<br>compiled_gradient_fn = jax.jit(metric_gradient_fn)<br><br>temp_gradients = compiled_gradient_fn(historic_temps)<br><br>print(&quot;Karachi Climate Analysis&quot;)<br>for day, gradient in enumerate(temp_gradients):<br>    print(f&quot;Day {day+1} Contribution: {gradient:.2f}&quot;)</pre><p>Each gradient entry equals 2 × (that day’s temperature − the target), so the days furthest from 28.0 affect the overall metric the most, and JAX calculates all of them in a fraction of a second.</p><h3>Getting Started with JAX</h3><p>If you already know NumPy, you’ll find JAX very easy to use.</p><p><strong>Installation (in Google Colab or locally):</strong></p><pre>pip install jax jaxlib<br># For NVIDIA GPU users (CUDA 12); the quotes keep shells like zsh from expanding the brackets<br>pip install -U &quot;jax[cuda12]&quot;</pre><p>Once installed, just replace numpy with jax.numpy (imported as jnp) in your code.</p><p><strong>Example:</strong></p><pre>import jax.numpy as jnp<br><br>salaries = jnp.array([120000, 85000, 150000, 90000, 210000])<br>total = jnp.sum(salaries)<br>print(f&quot;Total Monthly Payroll: {total:,.0f} PKR&quot;)</pre><p><strong>And if you want it to run faster:</strong></p><pre>from jax import jit<br><br>@jit<br>def calculate_payroll(salaries):<br>    tax_rate = 0.15<br>    net_salary = salaries * (1 - tax_rate)<br>    return jnp.sum(net_salary)<br><br>total_net = calculate_payroll(salaries)<br>print(f&quot;Total Net Payroll after tax: {total_net:,.0f} PKR&quot;)</pre><h4>Why JAX Matters for You</h4><ul><li><strong>It’s fast:</strong> Perfect for large projects, research, and AI 
models.</li><li><strong>It’s simple:</strong> Works just like NumPy.</li><li><strong>It’s modern:</strong> Used by Google Research, DeepMind, and many AI labs.</li><li><strong>It’s valuable:</strong> Few people know it well, which makes it a great skill for your career.</li></ul><p>If you’re someone curious about how modern AI tools are built and want to explore what’s powering the next wave of machine learning innovation, JAX is a great place to begin your journey.</p><h4>What’s Next</h4><p>In the next blog, <strong>“Your First Guide to JAX: NumPy with Superpowers,”</strong> we’ll take our first practical step into JAX.</p><p>You’ll learn how JAX combines the simplicity of NumPy with the power of automatic differentiation and GPU acceleration, making it the perfect starting point for anyone stepping into the world of modern machine learning.</p>]]></content:encoded>
        </item>
    </channel>
</rss>