Introduction
Welcome back, future LLM expert! In our previous chapters, we laid the groundwork by setting up Tunix and understanding its core philosophy. Now, it’s time to peek under the hood and explore how Tunix, built on the powerful JAX ecosystem, handles the intricate dance of model architectures and their ever-evolving state.
Understanding how your Large Language Model (LLM) is represented and how its parameters (the “knowledge” it holds) are managed is absolutely crucial for effective post-training. Unlike traditional imperative frameworks where model state might be implicitly updated, JAX operates on a functional paradigm. This means state management is explicit, predictable, and incredibly powerful when you know how to wield it. Tunix leverages this power, often integrating with libraries like Flax NNX, to give you granular control over your LLM’s internal workings.
By the end of this chapter, you’ll have a solid grasp of JAX’s functional approach to models, how Tunix utilizes Flax NNX for defining and managing LLM components, and the critical concept of explicit state management. This knowledge is fundamental for building sophisticated post-training routines, debugging effectively, and ultimately, achieving peak performance with your models. Let’s dive in!
Core Concepts: The Functional Heart of Tunix
Tunix’s strength comes from its foundation in JAX, a high-performance numerical computation library. JAX’s design principles, particularly its functional programming paradigm, significantly influence how models are built and how their state is handled.
JAX’s Functional Approach
At its core, JAX treats computations as pure functions. This means a function, given the same inputs, will always produce the same outputs and have no side effects. This might seem restrictive at first, especially if you’re used to frameworks where model parameters are internal attributes that get modified in place.
Think of it like a meticulous chef:
- Imperative Frameworks: The chef might grab a pot, add ingredients, stir, and modify the pot’s contents directly. The “pot” (model) itself changes.
- JAX’s Functional Approach: The chef always takes fresh ingredients (input data and current model parameters), performs a cooking step (the model’s forward pass), and produces a new, updated dish (output and new parameters). The original ingredients and pot remain untouched.
This immutability has profound benefits:
- Predictability: Easier to reason about and debug.
- Parallelism: JAX can safely and efficiently parallelize operations (`vmap`) and compile them for various hardware (`jit`) because there are no hidden side effects.
- Explicit State: You always know exactly what state is being used and produced.
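To make the functional contract concrete, here is a minimal sketch in plain JAX (no Flax yet; the tiny linear model and its parameter names are just illustrative): parameters go in as arguments and come out as a new pytree, never mutated in place.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # A pure function: the output depends only on its inputs.
    return x @ params["w"] + params["b"]

def sgd_step(params, x, y, lr=0.1):
    # Gradients are computed functionally too; we return *new* params
    # rather than mutating the old ones.
    loss_fn = lambda p: jnp.mean((forward(p, x) - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.ones((2, 1)), "b": jnp.zeros((1,))}
x = jnp.array([[1.0, 2.0]])
y = jnp.array([[0.0]])

# jit is safe precisely because sgd_step has no side effects.
new_params = jax.jit(sgd_step)(params, x, y)

# The original params are untouched; the update produced a new pytree.
print(params["w"].ravel())      # still all ones
print(new_params["w"].ravel())  # updated values
```

The same structure (pure step function, state threaded through explicitly) is what `jit`, `vmap`, and `grad` rely on throughout this chapter.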
Enter Flax NNX: Building Blocks for JAX Models
While JAX provides the numerical backbone, building complex neural networks efficiently often requires a higher-level API. This is where Flax comes in, and specifically, Flax NNX (Neural Network eXperiments), which Tunix often integrates with. NNX is designed to give you maximum flexibility and control over your model’s state within JAX’s functional paradigm.
NNX introduces a few key concepts:
- `nnx.Module`: The base class for all neural network layers and models. It's a container for parameters and other state.
- `nnx.Param`: How you declare trainable parameters (like weights and biases) inside an `nnx.Module`. These are the values your optimizer updates during training.
- `nnx.State`: A tree of a module's variables, not just parameters but also things like batch normalization statistics. You obtain it explicitly with `nnx.split` or `nnx.state`, and it is what you pass through pure JAX functions and update functionally.
- `nnx.Rngs`: JAX uses a deterministic pseudo-random number generator (PRNG), so operations that require randomness (like parameter initialization or dropout) need explicit random keys. `nnx.Rngs` manages named key streams (e.g., `params`, `dropout`) for you.
Tunix’s “White-Box” Design
The combination of JAX’s functional purity and Flax NNX’s explicit state management underpins Tunix’s “white-box” design philosophy. What does “white-box” mean here?
It means that Tunix allows you to:
- See everything: Every parameter, every layer, every internal state variable is accessible and explicit.
- Modify anything: Because state is passed explicitly, you can inspect, modify, or even replace parts of your model’s state during the post-training process. This is incredibly powerful for advanced techniques like parameter-efficient fine-tuning (PEFT), model editing, or targeted knowledge injection.
Contrast this with a “black-box” approach where you might only interact with an LLM via its inputs and outputs, without direct access to its internal components. Tunix empowers you to dive deep.
Let’s trace this flow:

- Explanation: The `nnx.Module` defines the architecture, and its variables are initialized using `nnx.Rngs`. Calling `nnx.split` on the module separates it into a static graph definition and an `nnx.State` holding every variable. That state is what you pass, explicitly, through pure JAX functions such as a jitted forward pass or an optimizer step; each such function returns a new, updated state. `nnx.merge` then reassembles a module from the graph definition and any state, so the cycle can repeat for the next step.
Step-by-Step Implementation: Building a Simple NNX Module
Let’s put these concepts into practice by defining a simple multi-layer perceptron (MLP) using Flax NNX. We’ll see how to define parameters, initialize the model, and manage its state.
First, make sure you have Tunix and its dependencies installed. If you skipped Chapter 1, you can install it via pip:
pip install "tunix[full]>=0.1.0" jax flax
(Note: Tunix is under active development. Always check the official Tunix GitHub releases for the latest version and the exact install extras.)
Now, let’s create our Python file, say mlp_model.py.
1. Setting up a Basic Flax NNX Module
We’ll start by importing nnx and defining a simple MLP class.
# mlp_model.py
import jax
import jax.numpy as jnp
from flax import nnx
# Let's define a simple Multi-Layer Perceptron (MLP)
# using Flax NNX.
class SimpleMLP(nnx.Module):
# The __init__ method is where we define the layers
# and any parameters our module will have.
def __init__(self, features_out: int, *, rngs: nnx.Rngs):
# We'll create a linear layer.
# nnx.Linear automatically creates nnx.Param objects
# for its weights and biases.
# It also needs an RNG key for initialization.
self.layer1 = nnx.Linear(2, 4, rngs=rngs)
self.layer2 = nnx.Linear(4, features_out, rngs=rngs)
# The __call__ method defines the forward pass of our module.
# It takes the input data and applies the layers.
def __call__(self, x: jax.Array) -> jax.Array:
# Apply the first layer, then a ReLU activation.
x = nnx.relu(self.layer1(x))
# Apply the second layer.
x = self.layer2(x)
return x
print("SimpleMLP module defined!")
Explanation:

- `import jax`, `import jax.numpy as jnp`: Standard JAX imports.
- `from flax import nnx`: Imports the `nnx` module. NNX started life under `flax.experimental` but has since graduated and is now Flax's recommended API for explicit state management.
- `class SimpleMLP(nnx.Module)`: Our MLP inherits from `nnx.Module`, making it an NNX-compatible component.
- `def __init__(self, features_out: int, *, rngs: nnx.Rngs)`:
  - `features_out`: The number of output features for our MLP.
  - `rngs: nnx.Rngs`: This is crucial! We explicitly require an `nnx.Rngs` object, which holds the random number generator keys needed for operations like parameter initialization.
- `self.layer1 = nnx.Linear(2, 4, rngs=rngs)`: Our first linear layer. `nnx.Linear` is a pre-built NNX module that creates `nnx.Param` objects for its weights and biases internally; we pass `rngs` so it can initialize them randomly. Our input has 2 features, and this layer outputs 4.
- `self.layer2 = nnx.Linear(4, features_out, rngs=rngs)`: Our second linear layer, taking the 4 features from the first layer and outputting `features_out`.
- `def __call__(self, x: jax.Array) -> jax.Array`: Defines how data flows through our model, like the `forward` method in other frameworks. We apply the first linear layer followed by a ReLU activation, then the second linear layer.
2. Initializing Model State
Now that we have our SimpleMLP definition, let’s create an instance of it and see how its state is initialized.
Add the following code to mlp_model.py:
# ... (previous code for SimpleMLP) ...
# To initialize our model, we need a random number generator (RNG) key.
# JAX's PRNG is deterministic, so we manage keys explicitly.
# nnx.Rngs manages named key streams for different purposes (e.g., params, dropout).
# We'll create an initial key for the 'params' stream.
key = jax.random.PRNGKey(0)  # A fixed seed for reproducibility
rngs = nnx.Rngs(params=key)

# Instantiate our SimpleMLP module.
# NNX modules are eager: their parameters are created immediately and
# live on the Python object as nnx.Param attributes.
features_out = 1
model = SimpleMLP(features_out, rngs=rngs)

print("\nModel initialized!")
print(f"Model type: {type(model)}")

# nnx.split separates the module into a static GraphDef and an nnx.State
# that holds all of its variables.
graphdef, state = nnx.split(model)
print(f"Model state type: {type(state)}")

# We can filter the state down to just the trainable parameters.
# nnx.State is a tree structure; printing it shows kernels and biases.
print("\nModel parameters:")
print(nnx.state(model, nnx.Param))

# Let's inspect a specific parameter, e.g., the weights of layer1.
# Parameters are directly accessible as attributes.
print("\nWeights of layer1:")
print(model.layer1.kernel.value)
Explanation:
- `key = jax.random.PRNGKey(0)`: We create a JAX PRNG key with a seed of 0. A fixed seed makes parameter initialization reproducible.
- `rngs = nnx.Rngs(params=key)`: We wrap our key in an `nnx.Rngs` object under the `params` stream. If we had dropout, we might also pass `dropout=jax.random.PRNGKey(1)`.
- `model = SimpleMLP(features_out, rngs=rngs)`: During `__init__`, each `nnx.Linear` draws keys from the `rngs` object to initialize its weights and biases, which are stored as `nnx.Param` attributes on the module itself.
- `graphdef, state = nnx.split(model)`: Splits the module into a static graph definition and an `nnx.State` object holding all of its variables. This is the explicit state you pass through pure JAX functions.
- `nnx.state(model, nnx.Param)`: Returns a filtered `nnx.State` containing only the trainable parameters. You'll see a `kernel` (weights) and `bias` for both `layer1` and `layer2`.
- `model.layer1.kernel.value`: Parameters are ordinary attributes, so we can directly read the array value of a specific parameter.
3. Performing a Forward Pass and Understanding State Updates
Now, let’s perform a forward pass with some dummy data. Here a key NNX design choice becomes clear: unlike Linen’s apply(), an NNX module carries its state with it, so calling the model returns only the output. In this simple MLP the parameters don’t change during a forward pass, but a stateful layer like batch normalization would update its statistics in place on the module, and you would extract the updated values explicitly with nnx.state or nnx.split.
Add the following to mlp_model.py:
# ... (previous code for initialization) ...
# Let's create some dummy input data.
# Our first layer expects 2 input features.
dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # A batch of 2 samples
print(f"\nInput data shape: {dummy_input.shape}")

# To perform a forward pass, we call the model instance like a function.
# NNX modules hold their own state, so the call returns only the output;
# stateful layers (like BatchNorm) would update their variables in place.
output = model(dummy_input)
print(f"Output from model (shape {output.shape}):")
print(output)

# In a pure forward pass, the model's parameters themselves don't change.
# We can confirm by snapshotting a parameter before and after a call.
weights_before = model.layer1.kernel.value
_ = model(dummy_input)
weights_after = model.layer1.kernel.value
print(f"Are layer1 kernel values identical? {jnp.array_equal(weights_before, weights_after)}")

# When you need the explicit, functional view (e.g., inside jax.jit),
# split the module into a GraphDef and an nnx.State, and merge it back:
graphdef, state = nnx.split(model)
restored = nnx.merge(graphdef, state)
print(f"Restored model output matches? {jnp.array_equal(restored(dummy_input), output)}")
Explanation:
- `dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]])`: A small batch of input data. Each sample has 2 features, matching our `nnx.Linear(2, 4, ...)` layer.
- `output = model(dummy_input)`: Calling an `nnx.Module` instance runs `__call__` and returns the output. The module's variables travel with the object, so nothing needs to be threaded through the call by hand; a stateful layer such as batch normalization would update its moving averages in place on the module.
- `nnx.split(model)` / `nnx.merge(graphdef, state)`: This pair gives you the explicit, purely functional view when you need it, for example inside `jax.jit`. The `nnx.State` is a plain pytree you can inspect, transform, and pass around, which is the key to JAX's functional paradigm and to Tunix's white-box design.
This step-by-step example shows you how to define a model, initialize its state, and perform a forward pass using Flax NNX, which is the foundation for how Tunix manages LLMs.
Mini-Challenge: Adding Dropout to Your MLP
Now it’s your turn to get hands-on!
Challenge: Modify the SimpleMLP to include a dropout layer after the first relu activation. Observe how nnx.Rngs are used for this.
Task Description:

- Add `self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)` in the `__init__` method of `SimpleMLP`. `nnx.Dropout` draws its keys from the `dropout` stream of the `rngs` object you pass in.
- In the `__call__` method, apply `self.dropout` after `nnx.relu` and before `self.layer2`.
- Call the model for a forward pass in training mode, then switch to evaluation mode and call it again.
- Print the model's parameters again and verify the dropout layer itself doesn't add trainable parameters (it's a process, not a parameter holder).

Hint:

- You'll need a separate RNG key stream for `dropout` in your `nnx.Rngs` object, e.g., `rngs = nnx.Rngs(params=key, dropout=jax.random.PRNGKey(1))`.
- Dropout is active by default; `model.eval()` sets its `deterministic` flag (disabling it), and `model.train()` re-enables it.
What to Observe/Learn:
- How `nnx.Rngs` streams are compartmentalized for different random operations.
- The difference in model behavior (output values change from call to call due to dropout).
- That `nnx.Dropout` itself doesn't add `nnx.Param` objects to the state.
Feel free to experiment and try to solve this before looking up solutions!
Click for Solution Hint!
# mlp_model_solution.py
import jax
import jax.numpy as jnp
from flax import nnx

class SimpleMLP(nnx.Module):
    def __init__(self, features_out: int, *, rngs: nnx.Rngs):
        self.layer1 = nnx.Linear(2, 4, rngs=rngs)
        # Dropout draws keys from the 'dropout' stream of rngs.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.layer2 = nnx.Linear(4, features_out, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = nnx.relu(self.layer1(x))
        x = self.dropout(x)  # active in training mode, a no-op in eval mode
        x = self.layer2(x)
        return x

key = jax.random.PRNGKey(0)
# Provide a separate key stream for dropout.
rngs = nnx.Rngs(params=key, dropout=jax.random.PRNGKey(1))

features_out = 1
model = SimpleMLP(features_out, rngs=rngs)
print("Model initialized with Dropout!")

print("\nModel parameters (dropout doesn't add params):")
print(nnx.state(model, nnx.Param))

dummy_input = jnp.array([[1.0, 2.0], [3.0, 4.0]])

output = model(dummy_input)
print(f"\nOutput from model with dropout (shape {output.shape}):")
print(output)

# Each call consumes fresh keys from the 'dropout' stream,
# so a second call produces a different dropout mask.
output_2 = model(dummy_input)
print(f"Are outputs from two dropout calls identical? {jnp.array_equal(output, output_2)}")

# Now switch to evaluation mode: .eval() sets deterministic=True
# on the Dropout layer, so no units are dropped.
model.eval()
output_eval = model(dummy_input)
print(f"\nOutput from model in evaluation mode (shape {output_eval.shape}):")
print(output_eval)
Common Pitfalls & Troubleshooting
Working with JAX and Flax NNX, especially when coming from other frameworks, can present a few unique challenges.
Forgetting to Capture Updated State: This is perhaps the most common pitfall when using the functional API. Because JAX functions are pure and pytrees are immutable, any function that “updates” state (an optimizer step, or a jitted training step that threads an nnx.State through) returns a new object. If you don’t capture the returned value and pass it to the next operation, you’ll be using stale parameters or statistics, leading to incorrect or non-converging training.

- Fix: Always assign the returned values, e.g., `updates, new_opt_state = optimizer.update(grads, opt_state, params)` followed by `new_params = optax.apply_updates(params, updates)` when using optax.
Incorrect RNG Management: JAX’s explicit PRNG system requires careful handling of random keys. Reusing the same key for multiple random operations produces identical (not independent) draws, while forgetting to derive a fresh key for each step means the same “random” numbers are generated every time.

- Fix:
  - For independent random calls (e.g., initializing multiple layers), use `jax.random.split(key, num_splits)`.
  - For per-step randomness (e.g., dropout in each training step), derive a step key with `jax.random.fold_in(key, step_id)`, or split off a fresh subkey each step.
  - `nnx.Rngs` helps by managing separate key streams for different purposes (e.g., `params`, `dropout`). Ensure the streams your modules need exist in the `rngs` object you construct them with.
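The key-handling patterns above can be sketched in a few lines of plain JAX:

```python
import jax

root = jax.random.PRNGKey(42)

# Independent random operations get independent keys via split.
k1, k2 = jax.random.split(root)
a = jax.random.normal(k1, (3,))
b = jax.random.normal(k2, (3,))
print(bool((a == b).all()))  # False: different keys, different draws

# Reusing a key reproduces the same "random" numbers, which is
# usually a bug when you wanted fresh randomness.
c = jax.random.normal(k1, (3,))
print(bool((a == c).all()))  # True: same key, identical draws

# For per-step randomness, derive a key from the step index.
step_key = jax.random.fold_in(root, 7)
```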
Confusion between Flax NNX and Flax Linen: Flax has two main API styles: Linen (the older, established API built around `apply()` and lazy initialization) and NNX (the newer API, now Flax’s recommended default, designed around eager Python objects and explicit state extraction). While they share some concepts, their state-management patterns are different, and Tunix leans on NNX.

- Fix: Be mindful of which API you’re using. If you see `nnx.Module`, `nnx.Param`, `nnx.State`, you’re in NNX land. If you see `flax.linen.Module`, `self.sow`, `self.param`, that’s Linen. Stick to NNX when working with Tunix’s explicit state patterns.
Summary
Phew! You’ve navigated the functional depths of JAX and discovered how Tunix leverages Flax NNX to manage LLM architectures and their state. Here are the key takeaways from this chapter:
- JAX’s Functional Core: JAX emphasizes pure functions and immutable data, meaning operations produce new states rather than modifying existing ones in place.
- Flax NNX for Explicit Models: Tunix integrates with Flax NNX, which provides `nnx.Module`, `nnx.Param`, `nnx.State`, and `nnx.Rngs` for defining model architectures and managing their variables explicitly.
- “White-Box” Control: This explicit state management enables Tunix’s “white-box” design, giving you fine-grained access and control over every part of your LLM during post-training.
- RNGs are Crucial: JAX’s deterministic PRNG requires careful management of random keys, often handled conveniently by `nnx.Rngs`.
- State is Explicit: Use `nnx.split` / `nnx.state` to extract an `nnx.State`, and always capture the new state returned by pure functions that perform updates.
Understanding these concepts is not just theoretical; it’s the bedrock for building robust, scalable, and highly customizable post-training workflows with Tunix. In the next chapter, we’ll start putting this model architecture knowledge to use as we explore how Tunix orchestrates the training loop for LLMs. Get ready to train some models!
References
- Tunix Official Documentation
- JAX Official Documentation
- Flax NNX GitHub Repository
- Introducing Tunix: A JAX-Native Library for LLM Post-Training - Google Developers Blog
This page is AI-assisted and reviewed. It references official documentation and recognized resources where relevant.