
Introduction
Multimodal Large Language Models (MLLMs) are increasingly serving as the brains of general-purpose embodied agents: AI that can navigate and act in the physical world as robots. While MLLMs are making rapid progress, they often stumble on a critical hurdle: precise visual perception. They struggle to reliably capture the fine-grained links between low-level visual features and high-level textual semantics.
Today, we are highlighting the work of Prof. Mayur Naik's research team at the University of Pennsylvania. To bridge the gap between high-level language and low-level visual features, they developed ESCA (Embodied and Scene-Graph Contextualized Agent). By porting their neurosymbolic pipeline to JAX, they achieved the real-time performance necessary for high-throughput decision-making. This work also demonstrates that JAX drives performance gains across a wide range of hardware, including standard CPUs and NVIDIA GPUs, and not just on Google TPUs.
In this blog, the UPenn team explains how they combined structured scene graphs with JAX's functional design to reduce perception errors by over 50% and achieve a 25% speedup in inference.
The "Grounding" Problem in Embodied AI
Existing MLLMs are powerful, but they can be surprisingly "blind" when tasked with interacting with the physical world. In our empirical analysis of 60 navigation tasks from EmbodiedBench, we found that 69% of agent failures stemmed from perception errors. See the figure below.
The figure breaks failures into three top-level error types (Perception, Reasoning, and Planning) and seven second-level errors: Hallucination, Wrong Recognition, Spatial Understanding, Spatial Reasoning, Reflection Error, Inaccurate Action, and Collision. For readability, the figure labels each error type with an acronym.
The models struggle to capture fine-grained links between visual features and textual semantics. They might recognize a "kitchen," but fail to identify the specific spatial relationship between a knife and a cutting board required to complete a task.
Enter ESCA: The Anglerfish of AI
To solve this, we introduced ESCA, a framework designed to contextualize MLLMs through open-domain scene graph generation.
Think of ESCA like the bioluminescent lure of a deep-sea anglerfish. Just as the fish illuminates its dark surroundings to reveal prey, ESCA "illuminates" the agent's environment by generating a structured Scene Graph—a map of objects, attributes, and relationships (e.g., Cup [Red] ON Table).
A key innovation here is Selective Grounding. Injecting a massive scene graph of everything in the room can overwhelm the model. Instead, ESCA identifies only the subset of objects and relations pertinent to the current instruction. It performs probabilistic reasoning to construct prompts enriched with exactly the contextual details the agent needs to act.
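To make the idea concrete, here is a minimal, hypothetical sketch of selective grounding over a probabilistic scene graph: triples are kept only if they are confident and mention objects that appear in the instruction. The data structure, relevance test, and threshold are illustrative assumptions, not ESCA's actual implementation.

```python
# Illustrative only: a scene graph as probabilistic (subject, relation, object) triples.
scene_graph = [
    ("knife", "on",   "cutting_board", 0.92),
    ("cup",   "on",   "table",         0.88),
    ("knife", "near", "sink",          0.35),
]

def selective_grounding(triples, instruction, threshold=0.5):
    """Keep confident triples whose subject or object appears in the instruction."""
    text = instruction.lower()
    return [
        (s, r, o, p)
        for (s, r, o, p) in triples
        if p >= threshold and (s.replace("_", " ") in text or o.replace("_", " ") in text)
    ]

print(selective_grounding(scene_graph, "Put the knife on the cutting board"))
# -> [("knife", "on", "cutting_board", 0.92)]
```

Only the grounded triples survive; everything else stays out of the prompt, keeping the MLLM's context focused on what the instruction actually needs.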
The Engine: LASER and Scallop
At the core of ESCA is LASER, a CLIP-based foundation model trained on 87k video-caption pairs. LASER uses Scallop—our neurosymbolic programming language that supports JAX backends—to align predicted scene graphs with logical specifications. This pipeline allows us to train low-level perception models to produce detailed graphs without needing tedious frame-level annotations.
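For flavor, here is a tiny sketch using Scallop's Python bindings (scallopy) to score a probabilistic scene-graph prediction against a simple logical rule. The relation name, object ids, and rule are hypothetical, and LASER's actual specifications and training loop are considerably more involved.

```python
import scallopy

# Hypothetical object ids: 0 = knife, 1 = cutting_board, 2 = table.
ctx = scallopy.ScallopContext(provenance="topkproofs")

# Probabilistic "on" edges predicted by the perception model.
ctx.add_relation("on", (int, int))
ctx.add_facts("on", [(0.92, (0, 1)), (0.40, (0, 2))])

# Caption-derived constraint: some object is on the cutting board (id 1).
ctx.add_rule("on_board(x) = on(x, 1)")

ctx.run()
for prob, (obj,) in ctx.relation("on_board"):
    print(obj, prob)
```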
JAX User Experience
1. The Power of Statelessness
JAX's design encouraged a fully functional, stateless architecture. Every component, from feature extraction to similarity computation, was made into a pure modular function. This structure enabled effective use of jit (Just-In-Time) compilation. The XLA compiler could fuse sequences—like normalization, matrix multiplication, and softmax—into fewer kernels, reducing intermediate buffers and lowering GPU overhead.
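As a minimal sketch of this pattern (the function and shapes are illustrative, not our actual pipeline), a pure scoring function wrapped in jax.jit lets XLA fuse the normalization, matrix multiplication, and softmax into a small number of kernels:

```python
import jax
import jax.numpy as jnp

@jax.jit
def similarity_scores(image_feats, text_feats):
    # L2-normalize both sets of embeddings.
    img = image_feats / jnp.linalg.norm(image_feats, axis=-1, keepdims=True)
    txt = text_feats / jnp.linalg.norm(text_feats, axis=-1, keepdims=True)
    # Cosine similarities followed by a softmax over text candidates.
    logits = img @ txt.T
    return jax.nn.softmax(logits, axis=-1)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
scores = similarity_scores(jax.random.normal(key1, (8, 512)),
                           jax.random.normal(key2, (16, 512)))
print(scores.shape)  # (8, 16)
```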
2. Handling Complex Control Flow
Our pipeline requires selecting the "top-k" most relevant objects from a probabilistic scene graph. This introduces complex control flow. JAX provided the primitives we needed to handle this efficiently (see the sketch after this list):
- We used jax.lax.cond to manage control flow inside the probabilistic graph.
- We leveraged jax.nn and jax.numpy for all activation functions and batched math in a JIT-friendly way.
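Below is a hedged sketch of how these primitives fit together: jax.lax.top_k selects the k most promising objects, and jax.lax.cond keeps a data-dependent fallback branch JIT-compatible. The scoring logic and threshold are stand-ins, not ESCA's actual grounding code.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("k",))
def select_relevant(object_probs, relevance, k=5, threshold=0.5):
    # Combine detection confidence with instruction relevance.
    scores = object_probs * relevance
    top_scores, top_idx = jax.lax.top_k(scores, k)

    def keep(_):
        return top_scores, top_idx

    def fallback(_):
        # If even the best candidate is weak, fall back to raw detection confidence.
        return jax.lax.top_k(object_probs, k)

    return jax.lax.cond(top_scores[0] >= threshold, keep, fallback, 0)

probs = jnp.array([0.9, 0.2, 0.7, 0.05, 0.6, 0.8])
rel   = jnp.array([0.1, 0.9, 0.8, 0.20, 0.3, 0.1])
print(select_relevant(probs, rel, k=3))
```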
3. Debugging and Transparency
Migrating to JAX was also a learning experience. Tools like jax.debug.print() and jax.debug.callback() allowed us to inspect values inside jit-compiled functions, while jax.disable_jit() let us easily switch to eager execution and step through the program to see intermediate values.
Furthermore, the transparency of the open-source system was impressive. Being able to read the annotated source code and see how Python functions trace into jaxpr (JAX expression) gave us deep insight into how to design inference logic that scales.
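As a small illustration (the function here is a toy, not part of LASER), jax.debug.print reports values from inside a compiled function, jax.make_jaxpr shows the traced program, and jax.disable_jit drops back to eager execution:

```python
import jax
import jax.numpy as jnp

def normalize(x):
    norm = jnp.linalg.norm(x)
    # Printed even when the function is jit-compiled.
    jax.debug.print("norm = {n}", n=norm)
    return x / norm

x = jnp.arange(4.0)
jax.jit(normalize)(x)

# Inspect the traced program (jaxpr) for the same computation.
print(jax.make_jaxpr(normalize)(x))

# Fall back to eager execution to step through with a regular debugger.
with jax.disable_jit():
    jax.jit(normalize)(x)
```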
4. Seamless Integration with Flax
NNX fits into our workflow perfectly. We used nnx.Module to structure the model and FrozenDict to keep parameters organized and immutable. The TrainState object made managing model parameters and optimizer states straightforward, without adding the complexity often found in other frameworks.
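For reference, a minimal nnx.Module along these lines might look like the following; the layer sizes and the projection-plus-normalization head are illustrative, not our actual model.

```python
import jax.numpy as jnp
from flax import nnx

class SimilarityHead(nnx.Module):
    """Projects features and L2-normalizes them for cosine similarity."""

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.proj = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        x = self.proj(x)
        return x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)

model = SimilarityHead(din=512, dout=256, rngs=nnx.Rngs(0))
feats = model(jnp.ones((4, 512)))
print(feats.shape)  # (4, 256)
```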
JAX Performance: A 25% Speedup
Embodied agents operate in a continuous loop: planning, acting, and updating their understanding of a dynamic world. High latency here is a dealbreaker. We ported LASER from PyTorch to JAX to improve real-time performance. By rewriting our core similarity computations and feature pipelines as pure functions wrapped in jax.jit, we achieved significant gains.
On an NVIDIA H100 GPU, JAX reduced the average time per frame from 18.15 ms (PyTorch) to 14.55 ms (JAX)—a roughly 25% speedup.
| Framework | Hardware | Avg Time Per Frame (ms) ↓ | FPS ↑ |
|---|---|---|---|
| PyTorch | H100 GPU | 18.15 ± 0.73 | 55.15 ± 2.31 |
| JAX | H100 GPU | 14.55 ± 0.64 | 68.82 ± 3.13 |
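For readers reproducing this kind of measurement, note that benchmarking a jitted function requires a warm-up call (to exclude compilation time) and block_until_ready (to account for JAX's asynchronous dispatch). Here is a minimal sketch with a placeholder workload rather than our actual per-frame computation:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def per_frame_step(feats, text_feats):
    # Placeholder for the real per-frame similarity computation.
    feats = feats / jnp.linalg.norm(feats, axis=-1, keepdims=True)
    return jax.nn.softmax(feats @ text_feats.T, axis=-1)

key = jax.random.PRNGKey(0)
feats = jax.random.normal(key, (256, 512))
text_feats = jax.random.normal(key, (128, 512))

per_frame_step(feats, text_feats).block_until_ready()  # warm-up / compile

n = 100
start = time.perf_counter()
for _ in range(n):
    per_frame_step(feats, text_feats).block_until_ready()
elapsed_ms = (time.perf_counter() - start) / n * 1e3
print(f"avg time per frame: {elapsed_ms:.2f} ms")
```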
Conclusion
ESCA demonstrates that better data (structured, grounded scene graphs) can solve the perception bottleneck in Embodied AI. But it also shows that better infrastructure is required to run these systems in the real world. JAX provided the speed, transparency, and modularity needed to turn our research into a real-time agent capable of reliable reasoning.
Acknowledgements
This research was made possible through support from a Google Research Award to the University of Pennsylvania and from the ARPA-H program on Safe and Explainable AI under award D24AC00253-00.
Get Started
You can explore the LASER code, the ESCA framework and documentation for JAX and Flax at:
