This repo has JAX implementations of the following:
- the encoder-decoder model from the "Attention Is All You Need" paper
- the same encoder-decoder model as above, with KV caching
- a decoder-only model with KV caching (KV caching is sketched below)
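For context on the KV-caching variants: during autoregressive decoding, the keys and values of already-generated tokens are cached so that each new token only computes its own projections instead of re-running attention over the whole prefix. Here's a minimal single-head sketch of the idea (function and variable names are illustrative, not the ones used in `layers.py`):

```python
import jax
import jax.numpy as jnp

def attend_with_cache(q_new, k_new, v_new, k_cache, v_cache):
    """One decoding step of single-head attention with a KV cache.

    q_new, k_new, v_new: (1, d) projections of the newest token.
    k_cache, v_cache:    (t, d) keys/values from the previous t steps.
    Returns the attention output for the new token plus the grown cache.
    """
    # Append the new key/value instead of recomputing the whole prefix.
    k_cache = jnp.concatenate([k_cache, k_new], axis=0)  # (t+1, d)
    v_cache = jnp.concatenate([v_cache, v_new], axis=0)  # (t+1, d)
    # The new token attends over all cached positions; causality holds
    # automatically because the cache only ever contains the past.
    scores = q_new @ k_cache.T / jnp.sqrt(q_new.shape[-1])  # (1, t+1)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v_cache, (k_cache, v_cache)
```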
All models were trained on the following tasks:
- string reversal (max input character limit: 8)
- addition of two operands (max input character limit per operand: 5; example pairs are sketched below)
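To make the tasks concrete, here is a hypothetical sketch of how such pairs could be generated; the real logic lives in `data.py`, and the character set for reversal is an assumption:

```python
import random
import string

def make_reversal_pair(max_len=8):
    # e.g. ("abcdefgh", "hgfedcba")
    n = random.randint(1, max_len)
    s = "".join(random.choices(string.ascii_lowercase, k=n))
    return s, s[::-1]

def make_addition_pair(max_digits=5):
    # e.g. ("12345+678", "13023")
    a = random.randint(0, 10**max_digits - 1)
    b = random.randint(0, 10**max_digits - 1)
    return f"{a}+{b}", str(a + b)
```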
In practice there are four tasks, since each task is split by model type:
- `string_reverse_encoder_decoder`
- `string_reverse_decoder_only`
- `addition_encoder_decoder`
- `addition_decoder_only`
You can test out the models on the tasks above by running `python evaluation.py`. Edit the main block to try out different tasks.
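The exact contents of the main block depend on the checked-in script, but the edit is roughly of this shape (the `run_task` helper is hypothetical; use whatever function `evaluation.py` actually exposes):

```python
# Hypothetical main block in evaluation.py; swap the task string
# for any of the four task names listed above.
if __name__ == "__main__":
    task = "addition_decoder_only"
    run_task(task)
```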
The repo's files:

- `evaluation.py`: test the JAX models I pre-trained on the tasks.
- `train.py`: train your own checkpoints of the JAX models. To test a saved checkpoint, copy its filename (excluding the file extension), paste it into the main block in `evaluation.py`, and run `python evaluation.py`.
- `data.py`: generate new data for the tasks. (WARNING: this will overwrite the current data unless you move the existing data elsewhere.)
- `config.py`: configs for training, one per task. The default config values guarantee convergence on the default generated data. (WARNING: with the default config values, the `addition_decoder_only` task converges to a validation loss of `0.03668` at step/epoch 30880, but the script won't stop until step/epoch 50900 because the `max_patience` hyperparameter is set very high. All the other tasks should converge to a validation loss of `0.03` at a much earlier step/epoch with the default config values.)
- `parameters.py`: a function to instantiate the model parameters.
- `layers.py`: the layers of the transformer, implemented as inference/forward functions.
- `inference.py`: teacher-forcing and autoregressive decoding functions that call the layer functions from `layers.py`.
- `tokenizer.py`: the tokenizer class for each task (see the sketch below).
- `codebase_string.py`: script that copies the entire codebase as a string so I can paste it into Gemini for feedback/debugging.
- `visualize_attention.py`: for visualizing attention weights. (WARNING: does not work currently.)
- `data/`: directory where the generated data for each task is written.
- `checkpoints/`: directory where the model checkpoints for each task are saved.
- `basic_jax_examples/`: unrelated directory where I practiced basic JAX scripts for linear regression.
- `requirements.txt`: install by running `pip install -r requirements.txt`. (NOTE: there are some unnecessary dependencies here; I just copied my virtual environment's dependencies over with `pip freeze`.) I'm using Python `3.10.10`.
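Since both tasks are defined over characters, the per-task tokenizers are presumably character-level. A minimal sketch of that pattern (an assumption about the design in `tokenizer.py`, not its actual code):

```python
class CharTokenizer:
    """Character-level tokenizer sketch (assumed design; see
    tokenizer.py for the real per-task implementations)."""

    def __init__(self, alphabet, specials=("<pad>", "<bos>", "<eos>")):
        self.specials = list(specials)
        self.vocab = self.specials + sorted(set(alphabet))
        self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}

    def encode(self, text):
        return [self.stoi[ch] for ch in text]

    def decode(self, ids):
        return "".join(
            self.itos[i] for i in ids if self.itos[i] not in self.specials
        )

# e.g. a tokenizer for the addition tasks, over digits and '+'
tok = CharTokenizer("0123456789+")
assert tok.decode(tok.encode("12+34")) == "12+34"
```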