Add vocabulary tiling to reduce redundant memory #2242
copybara-service[bot] merged 1 commit into main
Conversation
Force-pushed from b9ad751 to c1277e3
Force-pushed from f231602 to 6eb981c
bf90996 to
9cfd8ef
Compare
src/MaxText/layers/models.py
Outdated
@@ -297,6 +311,19 @@ def no_op(self, *args, **kwargs):
  def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
    return True

  def logits_from_hidden_states(self, hidden_states, deterministic):
I'm confused - is this function defined twice, here and above?
I will only keep one of them. Sorry for the confusion.
src/MaxText/layers/models.py
Outdated
@@ -410,6 +441,20 @@ class ZeroOneTransformer(nn.Module):
  def setup(self):
    self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode)

  def logits_from_hidden_states(self, hidden_states, deterministic):
I will remove them. There are three types of Transformers defined in models.py, but only one of them is actually used.
Force-pushed from 9cfd8ef to 82250c2
📋 Review Summary
This pull request introduces vocabulary tiling, a memory-saving optimization that computes the cross-entropy loss in chunks to reduce peak memory usage. The implementation is well-structured, with the core logic encapsulated in maxtext_utils.py and comprehensive tests to ensure correctness across various sharding configurations.
🔍 General Feedback
- The addition of thorough unit tests for different parallelism strategies is excellent and ensures the reliability of this new feature.
- The code is clean and the changes are well-integrated into the existing structure.
- The TODOs for future optimizations are noted and will be important for maximizing the benefits of this feature.
richjames0 left a comment:
PTAL at Gemini's nits but otherwise LGTM!
Force-pushed from 82250c2 to ea86e5a
Force-pushed from ea86e5a to 7775ebf
Description
This PR introduces vocabulary tiling, a new feature for MaxText designed to significantly reduce peak memory consumption during training.
This optimization is particularly beneficial for two key scenarios:
The core idea of vocabulary tiling is to avoid explicitly materializing the full final logits tensor. Instead, the logits activation is chunked (or "tiled") along the batch-sequence dimension.
As illustrated in the diagram below, the forward and backward passes are repeated num_vocab_tiling times. In each iteration, a small slice of the logits is computed, used to calculate the loss, and the vector-Jacobian product is immediately backpropagated. This iterative process avoids holding the complete, memory-intensive logits tensor.

Figure 1: The vocabulary tiling process. The forward and backward passes are repeated for each tile, preventing the full logits tensor from being stored in memory.
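To make the idea concrete, here is a minimal, self-contained JAX sketch of the tiling loop. It is not the MaxText implementation from this PR; the function name, shapes, and the plain matmul output projection are illustrative assumptions. Each tile computes only its slice of the logits, takes the cross-entropy loss on that slice, and immediately obtains the vector-Jacobian product before moving to the next tile.

```python
import jax
import jax.numpy as jnp


def tiled_xent_loss_and_grads(hidden, emb, targets, num_tiles):
  """Cross-entropy loss computed tile-by-tile along the batch*seq dimension.

  hidden:  [N, D] decoder activations (batch and sequence flattened together)
  emb:     [V, D] embedding / output-projection table
  targets: [N] integer token ids
  Returns the mean loss plus gradients w.r.t. `hidden` and `emb`, without
  ever materializing the full [N, V] logits tensor at once.
  """
  n = hidden.shape[0]
  chunk = n // num_tiles  # assume N divides evenly, for simplicity

  def chunk_loss(h_c, emb_c, t_c):
    logits = h_c @ emb_c.T                                   # [chunk, V] slice only
    logp = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logp, t_c[:, None], axis=-1)  # [chunk, 1]
    return nll.sum()

  loss = 0.0
  hidden_grads = []
  emb_grad = jnp.zeros_like(emb)
  for i in range(num_tiles):
    sl = slice(i * chunk, (i + 1) * chunk)
    # Forward pass + immediate backward pass (VJP) for this tile only.
    l, (g_h, g_e) = jax.value_and_grad(chunk_loss, argnums=(0, 1))(
        hidden[sl], emb, targets[sl])
    loss += l
    hidden_grads.append(g_h)
    emb_grad += g_e
  # Rescale so the result matches a mean-reduced cross-entropy over all tokens.
  return loss / n, jnp.concatenate(hidden_grads, axis=0) / n, emb_grad / n
```

A production version would typically fold this loop into a scan or a custom VJP rather than a Python loop, but the memory behavior is the same: only one [chunk, V] logits slice is live at any time.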
For a more in-depth technical explanation, please see the design document.
Doc: go/maxtext-vocab-tiling
FIXES: b/429255841
Tests
Correctness Tests
Test losses and embedding table gradient differences in MaxText/tests/vocab_tiling_test.py with a 1% relative error tolerance across several sharding configurations (with logits_via_embedding=False).
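As an illustration of what such a check might look like, here is a hypothetical sketch that reuses the tiled_xent_loss_and_grads helper sketched in the Description above, with made-up shapes; it is not the actual test in vocab_tiling_test.py. It compares the tiled loss and embedding-table gradient against a reference that materializes the full logits tensor.

```python
import jax
import jax.numpy as jnp
import numpy as np


def full_xent_loss(hidden, emb, targets):
  # Reference path: materialize the full [N, V] logits tensor in one shot.
  logits = hidden @ emb.T
  logp = jax.nn.log_softmax(logits, axis=-1)
  return -jnp.take_along_axis(logp, targets[:, None], axis=-1).mean()


key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
hidden = jax.random.normal(k1, (32, 16))          # toy [N, D] activations
emb = jax.random.normal(k2, (100, 16))            # toy [V, D] embedding table
targets = jax.random.randint(k3, (32,), 0, 100)   # toy token ids

ref_loss, ref_emb_grad = jax.value_and_grad(full_xent_loss, argnums=1)(
    hidden, emb, targets)
tiled_loss, _, tiled_emb_grad = tiled_xent_loss_and_grads(
    hidden, emb, targets, num_tiles=4)

# 1% relative tolerance, mirroring the tolerance used in the PR's tests.
np.testing.assert_allclose(tiled_loss, ref_loss, rtol=1e-2)
np.testing.assert_allclose(tiled_emb_grad, ref_emb_grad, rtol=1e-2)
```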
Performance Tests
See doc.
Checklist
Before submitting this PR, please make sure (put X in square brackets):