Welcome back, future LLM deployment expert! So far in our Tunix journey, you’ve mastered setting up your environment, pre-training, fine-tuning, and evaluating Large Language Models (LLMs) using the power of JAX. You’ve transformed raw data into intelligent, specialized models. But what’s the point of having a brilliant model if it’s just sitting on your hard drive?

This chapter is all about bringing your fine-tuned LLMs to life by deploying them for real-world use. We’ll explore the critical steps and considerations for taking your Tunix-trained models and making them accessible for inference, whether for a small internal tool or a large-scale application. We’ll cover everything from exporting your model to setting up a robust API and even containerizing it for consistent deployment. Get ready to turn your training efforts into tangible, interactive AI!

To make the most of this chapter, you should be comfortable with:

  • The Tunix fine-tuning process (Chapters 10-14).
  • Basic Python programming and command-line usage.
  • An understanding of what an API is and why it’s used.

The Deployment Challenge for LLMs

Deploying LLMs isn’t quite like deploying a simple web application. These models are often massive, requiring significant computational resources (especially GPUs or TPUs) and careful optimization to deliver responses quickly.

Let’s look at the core challenges:

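  1. Computational Demand: Generating each token takes roughly two floating-point operations per model parameter, so a 7-billion-parameter model performs about 14 billion operations for every token it produces. This demands powerful hardware and efficient software.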
  2. Model Size: Models can range from hundreds of megabytes to hundreds of gigabytes, impacting loading times and memory footprint.
  3. Low Latency Requirements: Users expect near-instant responses. Slow inference leads to poor user experience.
  4. Scalability: How do you handle 1 request per minute versus 10,000 requests per second?
  5. Cost: Powerful hardware is expensive. Optimizing for efficiency directly impacts your budget.
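To put rough numbers on items 1, 2, and 5, here is a back-of-envelope sketch in Python. It assumes the common rule of thumb of about two floating-point operations per parameter per generated token, and it counts only the memory needed to hold the weights. The KV cache, activations, and framework overhead come on top, so treat the results as lower bounds. The 7B parameter count is a hypothetical example.

```python
def weight_memory_gb(num_params: int, bytes_per_param: float) -> float:
    """Memory needed just to hold the weights, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def forward_flops_per_token(num_params: int) -> float:
    """Rule-of-thumb cost of one decoder forward pass: ~2 FLOPs per parameter
    (one multiply and one add per weight). Ignores attention over the context."""
    return 2.0 * num_params


# Bytes per parameter for common storage formats.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}

num_params = 7_000_000_000  # a hypothetical 7B-parameter model
for dtype, nbytes in BYTES_PER_PARAM.items():
    print(f"{dtype:>8}: {weight_memory_gb(num_params, nbytes):.0f} GB of weights")
print(f"~{forward_flops_per_token(num_params) / 1e9:.0f} GFLOPs per generated token")
```

The same arithmetic explains why halving the bytes per parameter (for example, float32 to bfloat16) is often the first lever to pull when a model does not fit in accelerator memory.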

Key Deployment Paradigms

There are several ways to deploy an LLM, each with its own trade-offs:

  • Local Inference (for Development/Testing): Running the model directly on your machine. Great for quick tests but not suitable for production.
  • On-Premise Servers: Deploying on your own hardware in a data center. Offers full control but requires significant infrastructure management.
  • Cloud-based Managed Services: Leveraging platforms like Google Cloud’s Vertex AI, AWS SageMaker, or Azure ML. These services handle much of the infrastructure, scaling, and monitoring for you.
  • Cloud-based Custom Infrastructure: Deploying on virtual machines or container orchestration platforms (like Kubernetes) in the cloud. Offers flexibility but requires more hands-on management.

For this chapter, we’ll focus on a common and flexible approach: creating a lightweight API using FastAPI and then containerizing it with Docker. This method provides excellent control and can be adapted for both local execution and cloud-based custom infrastructure.

JAX and Deployment Considerations

JAX brings unique strengths and considerations to deployment:

  • JIT Compilation: JAX’s Just-In-Time (JIT) compilation is a superpower. jax.jit traces your Python function and compiles it with XLA (Accelerated Linear Algebra) into optimized, fused kernels for your hardware (CPU, GPU, or TPU). Later calls with the same input shapes and dtypes reuse the compiled program and skip Python overhead, which is where the speed comes from.
  • Cold Start Latency: Compilation happens on the first call for each new combination of input shapes, dtypes, and static arguments. This “cold start” can make the first request to an endpoint, or the first request with a new shape, much slower than subsequent ones. We’ll need strategies to mitigate this.
  • Model State: A Tunix training run typically produces the model’s parameters (the weights) and often an optimizer state as well. For deployment, we only need the parameters.
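The cold-start behavior is easy to observe on a toy function. The sketch below uses plain JAX; `toy_layer` and the array sizes are illustrative assumptions, not part of Tunix. It counts how often JAX traces the function: the counter increases only when JAX compiles, which happens on the first call and again when the input shape changes, while a repeat call with the same shape reuses the cached executable.

```python
import time

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (and then compiles) the function


@jax.jit
def toy_layer(x, w):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jax.nn.relu(x @ w)


w = jnp.ones((512, 512))
x_a = jnp.ones((8, 512))
x_b = jnp.ones((16, 512))  # different batch size, therefore a different input shape

for label, x in [("first call, shape (8, 512)", x_a),
                 ("second call, same shape", x_a),
                 ("new shape (16, 512)", x_b)]:
    start = time.perf_counter()
    toy_layer(x, w).block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    print(f"{label}: {time.perf_counter() - start:.4f}s, traces so far: {trace_count}")
```

Exact timings depend on your hardware, but the first call and the new-shape call are typically the slow ones. This is why a deployed model should see a small, fixed set of input shapes.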

Let’s visualize a typical deployment flow for an LLM API:

flowchart TD
    User[User Client Application] -->|HTTP Request| API[FastAPI Server]
    API -->|Load Model State| Model_Loader[Model Loader]
    Model_Loader --> Tunix_Model[Tunix Fine Tuned Model]
    API -->|Inference Request| JAX_Runtime[JAX Runtime]
    JAX_Runtime -->|Generated Text| API
    API -->|HTTP Response| User

Step-by-Step Implementation: Deploying with FastAPI

We’ll deploy a simple text generation model. For this example, let’s assume you’ve already fine-tuned a model using Tunix and have saved its parameters.

Step 1: Exporting Your Tunix Model

After fine-tuning, your Tunix model’s parameters (weights) are typically saved. Tunix often works with Flax, so you’ll save the Flax params and the tokenizer.

Let’s imagine you saved your model in a previous chapter like this:

# Assuming you have a fine-tuned model and its parameters
# from a Tunix training run.
# For simplicity, we'll use a placeholder structure.
import os

from flax.serialization import to_bytes
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# --- Placeholder for your Tunix-trained model and tokenizer ---
# In a real scenario, these would come from your actual Tunix training output.
# For demonstration, we load a small pre-trained Flax causal LM and simulate saving.
# "gpt2" is a decoder-only (causal) model with Flax weights; encoder-decoder models
# such as google/flan-t5-small need FlaxAutoModelForSeq2SeqLM instead.
model_name = "gpt2"  # Or your actual Tunix base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)

# In a real Tunix scenario, `fine_tuned_params` would be the output
# of your training loop. Here, we'll just use the pre-trained model's params
# as a stand-in.
fine_tuned_params = model.params
# --- End Placeholder ---

output_dir = "./deployed_model"
os.makedirs(output_dir, exist_ok=True)

# 1. Save the tokenizer (vocabulary, merges, special tokens, tokenizer config)
tokenizer.save_pretrained(output_dir)
print(f"Tokenizer saved to {output_dir}")

# 2. Save the model parameters (weights)
# Flax parameters are a PyTree (a nested dict of arrays). flax.serialization.to_bytes
# encodes a PyTree in the MessagePack format.
# (This is a simplified approach; for large or sharded models, consider Orbax Checkpoint)
params_bytes = to_bytes(fine_tuned_params)
params_path = os.path.join(output_dir, "flax_model.msgpack")

with open(params_path, "wb") as f:
    f.write(params_bytes)

print(f"Model parameters saved to {params_path}")

Explanation:

  1. We load a stand-in model. We use gpt2 because it is a decoder-only (causal) model with Flax weights; google/flan-t5-small is an encoder-decoder model that FlaxAutoModelForCausalLM cannot load. Replace it with your actual Tunix base model.
  2. We define output_dir where our model artifacts will live.
  3. The tokenizer.save_pretrained(output_dir) method saves all necessary tokenizer files (vocabulary, special tokens, etc.) into the specified directory. This is crucial for consistent text processing.
  4. We then save fine_tuned_params. Flax parameters are a PyTree (a nested dictionary of arrays), and flax.serialization.to_bytes encodes that PyTree in the MessagePack format, which we write to flax_model.msgpack. This file contains the learned weights of your LLM but no architecture information, which is why the API below rebuilds the architecture from the base model name.

Run this script to create your deployed_model directory.

Step 2: Setting up a Basic Inference API with FastAPI

Now, let’s create a web API that can load this saved model and perform inference.

First, install the necessary libraries:

pip install fastapi uvicorn "jax[cuda12]" flax "transformers<5" msgpack

Note: jax[cuda12] installs JAX with NVIDIA GPU support on Linux and pulls in the CUDA 12 libraries as pip packages (you still need a recent NVIDIA driver). If you’re on a CPU, just install jax; for Cloud TPUs, use jax[tpu]. JAX wheels are built against specific CUDA versions, so check the official JAX installation guide for the precise command for your environment. The transformers<5 pin keeps the Flax model classes used in this chapter available, since transformers 5 drops the Flax backend.
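A quick way to confirm which hardware JAX will use, before starting the server, is to ask it directly. This is a small sketch using `jax.default_backend()` and `jax.devices()`; the warning message is illustrative.

```python
import jax

# Which backend did JAX pick up? Expect "gpu" or "tpu" on an accelerator
# machine; "cpu" means the accelerator wheels or drivers were not found.
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")

if backend == "cpu":
    print("Warning: running on CPU; generation will be slow for real models.")
```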

Create a new file named main.py:

# main.py
import os
import secrets
from functools import partial

import jax
import numpy as np
from fastapi import FastAPI, HTTPException
from flax.core import freeze
from flax.serialization import from_bytes
from pydantic import BaseModel
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# --- Configuration ---
MODEL_PATH = "./deployed_model"
MODEL_NAME_OR_PATH = "gpt2"  # The base model architecture used during Tunix training
MAX_PROMPT_TOKENS = 64       # Every prompt is padded/truncated to this length (fixed JIT shape)
# --- End Configuration ---

app = FastAPI(
    title="Tunix LLM Deployment API",
    description="API for inference with a Tunix fine-tuned LLM.",
    version="0.1.0"
)

# Global variables to hold model and tokenizer
tokenizer = None
model = None
model_params = None


# Define Pydantic models for the request and response bodies
class InferenceRequest(BaseModel):
    prompt: str
    max_length: int = 50  # Number of new tokens to generate after the prompt
    num_return_sequences: int = 1
    temperature: float = 0.7
    top_k: int = 50
    do_sample: bool = True


class InferenceResponse(BaseModel):
    generated_text: list[str]


# Generation settings are static (compile-time) arguments because Flax generate()
# branches on them in ordinary Python. Each new combination of settings or input
# shapes triggers one compilation; later calls reuse the cached executable.
@partial(jax.jit, static_argnames=("max_length", "do_sample", "temperature", "top_k"))
def generate_text_jitted(input_ids, attention_mask, params, prng_key, *,
                         max_length, do_sample, temperature, top_k):
    """
    JIT-compiled function for text generation.
    """
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        params=params,
        prng_key=prng_key,
        max_length=max_length,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
    ).sequences


def tokenize(prompts):
    """Left-pad (or truncate) prompts to MAX_PROMPT_TOKENS so input shapes never change."""
    return tokenizer(
        prompts,
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
    )


@app.on_event("startup")
async def load_model():
    """
    Load the tokenizer and model parameters when the FastAPI application starts up.
    This ensures the model is loaded once and ready for all requests.
    """
    global tokenizer, model, model_params

    print(f"Loading tokenizer from {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    tokenizer.padding_side = "left"  # decoder-only models continue from the right edge
    print("Tokenizer loaded.")

    print(f"Loading base model architecture for {MODEL_NAME_OR_PATH}...")
    # With _do_init=False, from_pretrained skips random initialization and returns
    # a (model, params) tuple. The base params are only a structural template here;
    # they are replaced by the fine-tuned weights below.
    model, base_params = FlaxAutoModelForCausalLM.from_pretrained(
        MODEL_NAME_OR_PATH, _do_init=False
    )
    print("Base model architecture loaded.")

    print(f"Loading fine-tuned parameters from {MODEL_PATH}/flax_model.msgpack...")
    with open(os.path.join(MODEL_PATH, "flax_model.msgpack"), "rb") as f:
        model_params = freeze(from_bytes(base_params, f.read()))
    del base_params  # free the template weights
    print("Fine-tuned parameters loaded.")

    # Warm-up: compile with the same padded shape and static settings as a
    # default request, so the first real default request skips compilation.
    print("Performing initial JIT compilation...")
    defaults = InferenceRequest(prompt="warm-up")
    inputs = tokenize([defaults.prompt] * defaults.num_return_sequences)
    generate_text_jitted(
        inputs["input_ids"],
        inputs["attention_mask"],
        model_params,
        jax.random.PRNGKey(0),
        max_length=MAX_PROMPT_TOKENS + defaults.max_length,
        do_sample=defaults.do_sample,
        temperature=defaults.temperature,
        top_k=defaults.top_k,
    ).block_until_ready()
    print("JIT compilation complete.")
    print("Model and tokenizer ready for inference!")


@app.post("/infer", response_model=InferenceResponse)
def infer(request: InferenceRequest):
    """
    Performs text generation based on the provided prompt and parameters.
    Declared with `def` (not `async def`) so FastAPI runs the blocking JAX call
    in a worker thread instead of stalling the event loop.
    """
    if tokenizer is None or model is None or model_params is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry shortly.")

    # Repeat the prompt once per requested sequence; sampling makes each row differ
    inputs = tokenize([request.prompt] * request.num_return_sequences)

    # Perform generation using the JIT-compiled function
    generated_output_ids = generate_text_jitted(
        inputs["input_ids"],
        inputs["attention_mask"],
        model_params,
        jax.random.PRNGKey(secrets.randbits(31)),  # fresh randomness per request
        max_length=MAX_PROMPT_TOKENS + request.max_length,  # HF max_length includes the prompt
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_k=request.top_k,
    )

    # Decode the generated tokens back to text (padding tokens are skipped)
    generated_texts = tokenizer.batch_decode(
        np.asarray(generated_output_ids), skip_special_tokens=True
    )

    return InferenceResponse(generated_text=generated_texts)

Explanation of main.py:

  1. Imports: We bring in jax, flax (freeze and from_bytes), transformers (for AutoTokenizer and FlaxAutoModelForCausalLM), FastAPI, and Pydantic (for data validation), plus a few standard-library helpers.
  2. Configuration: MODEL_PATH points to our saved model, MODEL_NAME_OR_PATH identifies the base model architecture, and MAX_PROMPT_TOKENS fixes the length every prompt is padded or truncated to.
  3. FastAPI App: We initialize our FastAPI application.
  4. InferenceRequest & InferenceResponse: These Pydantic models define the expected structure of incoming requests (e.g., prompt, max_length) and outgoing responses (generated_text). This provides automatic validation and clear API documentation. In this API, max_length is the number of new tokens to generate; the code adds MAX_PROMPT_TOKENS because the max_length understood by generate() includes the prompt.
  5. @app.on_event("startup"): This decorator ensures that the load_model function runs once when the FastAPI server starts. (Newer FastAPI releases deprecate on_event in favor of lifespan handlers, but it still works.)
    • Inside load_model, we load the tokenizer from our MODEL_PATH. GPT-2 ships without a padding token, so we reuse its end-of-sequence token, and we pad on the left because decoder-only models generate from the right edge of the sequence.
    • We then call FlaxAutoModelForCausalLM.from_pretrained with _do_init=False. This skips random weight initialization and returns a (model, params) tuple instead of a model with params attached; the returned base parameters serve only as a structural template for our own weights.
    • We read flax_model.msgpack and deserialize it with from_bytes into the template’s structure. freeze is used to make the parameters immutable, which is good practice for inference.
    • JIT Compilation Warm-up: We call generate_text_jitted once with a dummy prompt that has the same padded shape and the same generation settings as a default request. JAX caches one compiled program per combination of input shapes, dtypes, and static arguments, so a warm-up only helps requests that match it exactly.
  6. @jax.jit for Inference: generate_text_jitted wraps model.generate in jax.jit. The generation settings are listed in static_argnames because generate() branches on them in ordinary Python, which fails if they are traced values. The input IDs, attention mask, parameters, and PRNG key are traced arrays, so JAX reuses the compiled program for any prompt with the same padded shape and settings; a new temperature, top_k, max_length, or num_return_sequences triggers one extra compilation.
  7. @app.post("/infer"): This defines our API endpoint. It’s a POST request to /infer.
    • It takes an InferenceRequest object, automatically validated by FastAPI. The handler is a plain def, so FastAPI runs the blocking JAX call in a worker thread instead of stalling the event loop.
    • If the model is still loading, it returns HTTP 503 instead of an unhandled server error.
    • The prompt is repeated num_return_sequences times and tokenized into a fixed-length, left-padded batch with an attention mask.
    • Each request gets a fresh PRNG key; otherwise sampling would fall back to the same fixed default key and return identical text for identical prompts.
    • The generate_text_jitted function performs the actual LLM inference, and the generated token IDs are decoded back into human-readable text.
    • Finally, an InferenceResponse object is returned, which FastAPI serializes to JSON.

Step 3: Running the API Locally

To run your API, navigate to the directory containing main.py and deployed_model, and execute:

uvicorn main:app --host 0.0.0.0 --port 8000 --reload

Explanation:

  • uvicorn: The ASGI server that runs your FastAPI application.
  • main:app: Tells Uvicorn to find the app object in main.py.
  • --host 0.0.0.0: Makes the server accessible from outside your local machine (useful for Docker later).
  • --port 8000: Runs the server on port 8000.
  • --reload: (Optional, for development) Automatically reloads the server when code changes. Remove this for production.

You’ll see output indicating the model and tokenizer are loading, followed by the JIT compilation message. Once “Application startup complete.” appears, your API is ready!

Open your browser and go to http://localhost:8000/docs. You’ll see the interactive API documentation provided by FastAPI (Swagger UI). You can try out your /infer endpoint directly from there!

Alternatively, you can send a curl request:

curl -X POST "http://localhost:8000/infer" \
     -H "Content-Type: application/json" \
     -d '{
           "prompt": "The capital of France is",
           "max_length": 20,
           "temperature": 0.5
         }'

You should receive a JSON response with the generated text!
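If you prefer testing from Python instead of curl, the standard library is enough. This sketch mirrors the curl request above; `API_URL` and the default values are assumptions to adjust for your setup, and the generous timeout allows for a first request that may still have to compile.

```python
import json
import urllib.request

API_URL = "http://localhost:8000/infer"  # adjust host/port for your deployment


def build_infer_request(prompt: str, max_length: int = 20,
                        temperature: float = 0.5) -> urllib.request.Request:
    """Build the same JSON POST request as the curl example."""
    payload = {"prompt": prompt, "max_length": max_length, "temperature": temperature}
    return urllib.request.Request(
        API_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def infer(prompt: str, timeout: float = 120.0) -> list[str]:
    """Send the request and return the generated texts (the server must be running)."""
    with urllib.request.urlopen(build_infer_request(prompt), timeout=timeout) as resp:
        return json.loads(resp.read())["generated_text"]


# Usage, with the server from Step 3 running:
#   print(infer("The capital of France is"))
```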

Step 4: Containerization with Docker (Brief Introduction)

For consistent and portable deployment, Docker is invaluable. It packages your application and all its dependencies into a single, isolated “container.”

Create a file named Dockerfile in the same directory as main.py and deployed_model:

# Use an official Python runtime as a parent image
# (Debian "buster" images are end-of-life, so use a current slim tag)
FROM python:3.11-slim

# Set the working directory in the container
WORKDIR /app

# This Dockerfile targets CPU inference.
# For NVIDIA GPUs, the pip wheels installed by jax[cuda12] bundle the CUDA and
# cuDNN libraries, so a plain Python base image usually suffices; the host still
# needs an NVIDIA driver and the NVIDIA Container Toolkit.

# Copy the requirements file first so Docker can cache the dependency layer
# when only the application code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code and the deployed model
COPY . /app

# Expose the port the API will run on
EXPOSE 8000

# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

And create a requirements.txt file:

fastapi
uvicorn[standard]
jax[cpu]  # Use jax[cuda12] for NVIDIA GPUs; match the extra to your hardware
flax
transformers<5  # transformers 5 drops the Flax model classes used here
pydantic
msgpack

Explanation of Dockerfile:

  1. FROM python:3.11-slim: We start with a lightweight, currently supported Python 3.11 image.
  2. WORKDIR /app: Sets the working directory inside the container.
  3. COPY requirements.txt . and RUN pip install -r requirements.txt: Copies your dependency list and installs it. Copying requirements.txt before the rest of the code lets Docker reuse this slow step when only your code changes.
  4. COPY . /app: Copies all your application files (including main.py and the deployed_model directory) into the container.
  5. EXPOSE 8000: Documents that the container listens on port 8000; you still publish the port with -p when running it.
  6. CMD ["uvicorn", "main:app", ...]: The command that runs when the container starts. Because main.py loads the base architecture by name, the container downloads it from the Hugging Face Hub at startup unless it is already cached in the image.

Building and Running the Docker Image:

# 1. Build the Docker image
docker build -t tunix-llm-api .

# 2. Run the Docker container
# For CPU:
docker run -p 8000:8000 tunix-llm-api

# For GPU (requires the NVIDIA Container Toolkit on the host and an image built
# with jax[cuda12] instead of jax[cpu] in requirements.txt):
# docker run --gpus all -p 8000:8000 tunix-llm-api

Now, your LLM API is running inside a Docker container! You can access it just like before at http://localhost:8000. This container can now be easily moved to any server that has Docker installed, providing a consistent runtime environment.

Mini-Challenge: Batch Inference Endpoint

Our current /infer endpoint handles one prompt at a time. For efficiency, especially with LLMs, it’s often beneficial to process multiple prompts in a single request (batching).

Challenge: Modify the main.py file to add a new endpoint, /batch_infer, that accepts a list of prompts and returns a list of generated texts.

Hint:

  • You’ll need a new Pydantic model for the batch request that contains a list[str] for prompts.
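  • The tokenizer can handle a list of strings directly; with padding enabled, it returns input_ids and an attention_mask with a batch dimension. Decoder-only models such as GPT-2 need a pad token (reusing the EOS token is common) and left padding.
  • The model.generate function already supports batching if input_ids has a batch dimension; pass the attention_mask as well so padding positions are ignored.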

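What to Observe/Learn: After implementing and testing, you should notice that one batch request for several prompts usually finishes faster than sending the same prompts as separate requests. Each decoding step then runs a single forward pass over all prompts, which amortizes per-request overhead and keeps your hardware busier. Keep in mind that the first request with a new batch size triggers a fresh JIT compilation, so compare timings after warm-up.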


Mini-Challenge Solution
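# Add to main.py, after the InferenceResponse class
# (assumes the globals `model`, `tokenizer`, and `model_params` are loaded at startup)
import secrets
from functools import partial

import jax
import numpy as np
from fastapi import HTTPException

BATCH_PROMPT_TOKENS = 64  # fixed padded prompt length, so input shapes stay constant


class BatchInferenceRequest(BaseModel):
    prompts: list[str]
    max_length: int = 50  # new tokens to generate after each prompt
    num_return_sequences: int = 1
    temperature: float = 0.7
    top_k: int = 50
    do_sample: bool = True


# Generation settings are static because Flax generate() branches on them in Python
@partial(jax.jit, static_argnames=("max_length", "do_sample", "temperature", "top_k"))
def batch_generate_jitted(input_ids, attention_mask, params, prng_key, *,
                          max_length, do_sample, temperature, top_k):
    return model.generate(
        input_ids, attention_mask=attention_mask, params=params, prng_key=prng_key,
        max_length=max_length, do_sample=do_sample, temperature=temperature, top_k=top_k,
    ).sequences


@app.post("/batch_infer", response_model=InferenceResponse)  # Reusing InferenceResponse
def batch_infer(request: BatchInferenceRequest):
    """
    Performs text generation for a batch of prompts.
    """
    if tokenizer is None or model is None or model_params is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry shortly.")

    # Decoder-only models need a pad token and left padding for batched generation
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Repeat each prompt num_return_sequences times; rows stay grouped by prompt
    prompts = [p for p in request.prompts for _ in range(request.num_return_sequences)]
    inputs = tokenizer(
        prompts,
        return_tensors="jax",
        padding="max_length",  # Important for batching: every row gets the same length
        truncation=True,
        max_length=BATCH_PROMPT_TOKENS,
    )

    # Perform generation using the JIT-compiled function
    generated_output_ids = batch_generate_jitted(
        inputs["input_ids"],
        inputs["attention_mask"],
        model_params,
        jax.random.PRNGKey(secrets.randbits(31)),  # fresh randomness per request
        max_length=BATCH_PROMPT_TOKENS + request.max_length,  # HF max_length includes the prompt
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_k=request.top_k,
    )

    # Decode the generated tokens back to text
    # Output shape: (len(prompts) * num_return_sequences, BATCH_PROMPT_TOKENS + max_length)
    generated_texts = tokenizer.batch_decode(
        np.asarray(generated_output_ids), skip_special_tokens=True
    )

    return InferenceResponse(generated_text=generated_texts)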

Test with curl:

curl -X POST "http://localhost:8000/batch_infer" \
     -H "Content-Type: application/json" \
     -d '{
           "prompts": [
             "The capital of France is",
             "What is the largest ocean on Earth?"
           ],
           "max_length": 25,
           "temperature": 0.6
         }'
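Because `jax.jit` compiles once per input shape, every new batch size or padded prompt length reaching the batch endpoint costs a fresh compilation. A common mitigation is bucketing: round both dimensions up to a small, fixed set of sizes and warm up each one at startup. Below is a pure-Python sketch of the padding side; the bucket sizes and helper names are hypothetical, and filler rows repeat the first prompt so that no row is entirely masked.

```python
import bisect

PROMPT_LENGTH_BUCKETS = [32, 64, 128, 256]  # hypothetical bucket sizes
BATCH_SIZE_BUCKETS = [1, 2, 4, 8]


def pick_bucket(n: int, buckets: list[int]) -> int:
    """Smallest bucket >= n; raises if n exceeds the largest bucket."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"{n} exceeds the largest bucket {buckets[-1]}")
    return buckets[i]


def left_pad_batch(token_ids: list[list[int]], pad_id: int):
    """Left-pad a batch of token-id lists to bucketed (batch, length) shapes.

    Returns (input_ids, attention_mask) as nested lists. Filler rows that round
    the batch up to its bucket repeat the first prompt; drop their outputs.
    """
    length = pick_bucket(max(len(t) for t in token_ids), PROMPT_LENGTH_BUCKETS)
    batch = pick_bucket(len(token_ids), BATCH_SIZE_BUCKETS)
    rows = token_ids + [token_ids[0]] * (batch - len(token_ids))
    input_ids = [[pad_id] * (length - len(r)) + r for r in rows]
    attention_mask = [[0] * (length - len(r)) + [1] * len(r) for r in rows]
    return input_ids, attention_mask
```

You would convert the resulting lists with `jnp.asarray` before calling the model, and discard the outputs for filler rows after decoding.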

Common Pitfalls & Troubleshooting

  1. Out of Memory (OOM) Errors:

    • Problem: Your GPU/TPU runs out of memory, especially with large models or high max_length / batch_size.
    • Solution:
      • Reduce batch_size in your API requests.
      • Reduce max_length for generation.
      • Store weights in bfloat16, which halves memory relative to float32, or quantize further (e.g., to int8) if your model, libraries, and hardware support it. TPUs and recent NVIDIA GPUs support bfloat16 natively.
      • Upgrade to a more powerful GPU/TPU.
      • Use techniques like offloading parts of the model to CPU memory if latency is less critical.
  2. Cold Start Latency:

    • Problem: The very first request to your API takes significantly longer than subsequent requests due to JAX’s JIT compilation.
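    • Solution: As shown in our main.py, calling the JIT-compiled function with dummy inputs during application startup (@app.on_event("startup")) “warms up” the compilation cache. The warm-up only covers the input shapes and static settings it uses, so pad inputs to a few fixed lengths and warm up each configuration you expect to serve. JAX’s persistent compilation cache (the jax_compilation_cache_dir option) can additionally reuse compiled programs across server restarts.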
  3. Dependency Mismatch/Environment Issues:

    • Problem: Your model runs fine during training but fails in deployment. This is often due to different versions of JAX, Flax, Transformers, or Python itself.
    • Solution:
      • Always use a requirements.txt (or pyproject.toml) to specify exact versions of all dependencies.
      • Use Docker or other containerization technologies to ensure the deployment environment is identical to your testing environment.
      • Verify CUDA/cuDNN versions on your deployment machine match what JAX expects.
  4. Slow Inference (Beyond Cold Start):

    • Problem: Even after warm-up, inference is too slow.
    • Solution:
      • Ensure JAX is correctly configured to use your GPU/TPU (check JAX device count: print(jax.devices())).
      • Optimize your model (e.g., knowledge distillation, smaller architecture if possible).
      • Utilize batching effectively (as in our mini-challenge).
      • Consider hardware upgrades or distributed inference setups for very high throughput.
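For the bfloat16 suggestion under Out of Memory errors, casting a parameter PyTree takes one `jax.tree_util.tree_map` call. This sketch uses a hypothetical parameter tree and leaves non-floating leaves untouched; transformers’ Flax models also provide a `to_bf16(params)` helper. The dtype of the stored weights and the dtype used for computation are separate settings, and bfloat16 can slightly change outputs, so re-run your evaluation after casting.

```python
import jax
import jax.numpy as jnp


def cast_floating_params(params, dtype=jnp.bfloat16):
    """Cast every floating-point leaf of a parameter PyTree to `dtype`;
    integer leaves (if any) are left unchanged."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )


def tree_nbytes(params) -> int:
    """Total bytes held by all array leaves of a PyTree."""
    return sum(x.size * x.dtype.itemsize for x in jax.tree_util.tree_leaves(params))


# Hypothetical parameter tree standing in for real model weights.
params = {
    "embed": jnp.zeros((1000, 64), dtype=jnp.float32),
    "layer": {"kernel": jnp.zeros((64, 64), dtype=jnp.float32), "step": jnp.array(3)},
}
bf16_params = cast_floating_params(params)
print(tree_nbytes(params), "->", tree_nbytes(bf16_params))
```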

Summary

Congratulations! You’ve successfully navigated the waters of LLM deployment. In this chapter, you learned:

  • The unique challenges of deploying large language models, including computational demands and latency.
  • How to prepare your Tunix-fine-tuned model for deployment by saving its tokenizer and parameters.
  • To build a robust and efficient inference API using FastAPI, complete with Pydantic for request validation and JAX’s JIT compilation for performance.
  • The importance of “warming up” your JAX functions during application startup to reduce cold start latency.
  • A brief introduction to containerizing your application with Docker for consistent and portable deployment.
  • Strategies for troubleshooting common deployment pitfalls like OOM errors and slow inference.

Deploying LLMs is a critical step in bringing your AI projects to fruition. With these foundational skills, you’re well-equipped to make your fine-tuned Tunix models accessible and impactful.

In the next chapter, we’ll dive into advanced MLOps practices for Tunix, including monitoring, logging, and potentially A/B testing, to ensure your deployed models remain performant and relevant over time.

