Welcome back, future LLM deployment expert! So far in our Tunix journey, you’ve mastered setting up your environment, pre-training, fine-tuning, and evaluating Large Language Models (LLMs) using the power of JAX. You’ve transformed raw data into intelligent, specialized models. But what’s the point of having a brilliant model if it’s just sitting on your hard drive?
This chapter is all about bringing your fine-tuned LLMs to life by deploying them for real-world use. We’ll explore the critical steps and considerations for taking your Tunix-trained models and making them accessible for inference, whether for a small internal tool or a large-scale application. We’ll cover everything from exporting your model to setting up a robust API and even containerizing it for consistent deployment. Get ready to turn your training efforts into tangible, interactive AI!
To make the most of this chapter, you should be comfortable with:
- The Tunix fine-tuning process (Chapters 10-14).
- Basic Python programming and command-line usage.
- An understanding of what an API is and why it’s used.
The Deployment Challenge for LLMs
Deploying LLMs isn’t quite like deploying a simple web application. These models are often massive, requiring significant computational resources (especially GPUs or TPUs) and careful optimization to deliver responses quickly.
Let’s look at the core challenges:
- Computational Demand: LLMs perform billions of operations per inference. This demands powerful hardware and efficient software.
- Model Size: Models can range from hundreds of megabytes to hundreds of gigabytes, impacting loading times and memory footprint.
- Low Latency Requirements: Users expect near-instant responses. Slow inference leads to poor user experience.
- Scalability: How do you handle 1 request per minute versus 10,000 requests per second?
- Cost: Powerful hardware is expensive. Optimizing for efficiency directly impacts your budget.
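To make the model-size and cost points concrete, here is a quick back-of-envelope sketch of the memory needed just to hold a model's weights at different precisions (the 7-billion-parameter figure is illustrative, not tied to any specific Tunix model):

```python
def param_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed just to hold the model weights."""
    return num_params * bytes_per_param / (1024 ** 3)

# Illustrative: a 7-billion-parameter model at different precisions.
for label, nbytes in [("float32", 4), ("bfloat16", 2), ("int8", 1)]:
    print(f"7B params @ {label}: ~{param_memory_gb(7e9, nbytes):.1f} GiB")
```

Note that this is only the weights; activations, the KV cache, and framework overhead add to the total, which is why quantizing to `bfloat16` or `int8` has such a direct impact on hardware requirements.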
Key Deployment Paradigms
There are several ways to deploy an LLM, each with its own trade-offs:
- Local Inference (for Development/Testing): Running the model directly on your machine. Great for quick tests but not suitable for production.
- On-Premise Servers: Deploying on your own hardware in a data center. Offers full control but requires significant infrastructure management.
- Cloud-based Managed Services: Leveraging platforms like Google Cloud’s Vertex AI, AWS SageMaker, or Azure ML. These services handle much of the infrastructure, scaling, and monitoring for you.
- Cloud-based Custom Infrastructure: Deploying on virtual machines or container orchestration platforms (like Kubernetes) in the cloud. Offers flexibility but requires more hands-on management.
For this chapter, we’ll focus on a common and flexible approach: creating a lightweight API using FastAPI and then containerizing it with Docker. This method provides excellent control and can be adapted for both local execution and cloud-based custom infrastructure.
JAX and Deployment Considerations
JAX brings unique strengths and considerations to deployment:
- JIT Compilation: JAX’s Just-In-Time (JIT) compilation is a superpower. It compiles your Python code into highly optimized XLA (Accelerated Linear Algebra) computations for your hardware (GPU/TPU). This means incredible speed!
- Cold Start Latency: The first time a JIT-compiled function runs, there’s a compilation overhead. This “cold start” can cause the first request to an endpoint to be slower than subsequent ones. We’ll need to consider strategies to mitigate this.
- Model State: Tunix models, being JAX/Flax-native, typically consist of a set of parameters (the weights) and potentially an optimizer state. For deployment, we primarily care about the parameters.
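You can observe the cold-start effect directly with a tiny standalone experiment (assuming JAX is installed; the matrix sizes here are arbitrary). The counter increments only while JAX traces the function, so it shows exactly when compilation happens:

```python
import time

import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def matmul(a, b):
    global trace_count
    trace_count += 1  # Runs only while JAX traces (compiles) the function
    return a @ b

x = jnp.ones((256, 256))

t0 = time.perf_counter()
matmul(x, x).block_until_ready()   # First call: traces + compiles ("cold start")
cold = time.perf_counter() - t0

t0 = time.perf_counter()
matmul(x, x).block_until_ready()   # Same shapes: reuses the compiled version
warm = time.perf_counter() - t0

print(f"cold: {cold * 1e3:.1f} ms, warm: {warm * 1e3:.1f} ms, traces so far: {trace_count}")

# A new input shape forces a retrace and recompile.
y = jnp.ones((128, 128))
matmul(y, y).block_until_ready()
print(f"traces after a new shape: {trace_count}")
```

This is also why warming up an endpoint with dummy input only helps for the shapes (and static settings) you warmed up with.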
A typical deployment flow for an LLM API looks like this: export the fine-tuned model (parameters + tokenizer) to disk → load it inside a FastAPI inference service at startup → package the service into a Docker container → run the container on a server or cloud platform, where clients send HTTP requests to the API.
Step-by-Step Implementation: Deploying with FastAPI
We’ll deploy a simple text generation model. For this example, let’s assume you’ve already fine-tuned a model using Tunix and have saved its parameters.
Step 1: Exporting Your Tunix Model
After fine-tuning, your Tunix model’s parameters (weights) are typically saved. Tunix often works with Flax, so you’ll save the Flax params and the tokenizer.
Let’s imagine you saved your model in a previous chapter like this:
# Assuming you have a fine-tuned model and its parameters
# from a Tunix training run.
# For simplicity, we'll use a placeholder structure.
import os

from flax.serialization import to_bytes
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# --- Placeholder for your Tunix-trained model and tokenizer ---
# In a real scenario, these would come from your actual Tunix training output.
# For demonstration, we load a small pre-trained Flax model and simulate saving.
# Note: the base model must be a causal (decoder-only) architecture to work
# with FlaxAutoModelForCausalLM; "gpt2" is used here as a stand-in.
model_name = "gpt2"  # Or your actual Tunix base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForCausalLM.from_pretrained(model_name)

# In a real Tunix scenario, `fine_tuned_params` would be the output of your
# training loop. Here, we just use the pre-trained model's params as a stand-in.
fine_tuned_params = model.params
# --- End Placeholder ---

output_dir = "./deployed_model"
os.makedirs(output_dir, exist_ok=True)

# 1. Save the tokenizer
tokenizer.save_pretrained(output_dir)
print(f"Tokenizer saved to {output_dir}")

# 2. Save the model parameters (weights).
# Flax parameters are a PyTree (nested dict of JAX arrays). For production,
# consider Orbax Checkpoint; for demonstration, `flax.serialization.to_bytes`
# (which serializes via MsgPack) is enough:
params_bytes = to_bytes(fine_tuned_params)
with open(os.path.join(output_dir, "flax_model.msgpack"), "wb") as f:
    f.write(params_bytes)
print(f"Model parameters saved to {output_dir}/flax_model.msgpack")
Explanation:
- We define `output_dir`, where our model artifacts will live.
- `tokenizer.save_pretrained(output_dir)` saves all necessary tokenizer files (vocabulary, special tokens, etc.) into the specified directory. This is crucial for consistent text processing.
- We then save `fine_tuned_params`. Since Tunix often uses Flax, these parameters are typically a nested dictionary of JAX arrays. We use `flax.serialization.to_bytes` (MsgPack-based) to serialize them into a binary file containing the learned weights of your LLM.

Run this script to create your deployed_model directory.
Step 2: Setting up a Basic Inference API with FastAPI
Now, let’s create a web API that can load this saved model and perform inference.
First, install the necessary libraries:
pip install fastapi uvicorn "jax[cuda12_pip]" "flax" "transformers" "msgpack"
Note: The cuda12_pip extra targets CUDA 12; substitute the extra that matches your system’s CUDA version, or install plain jax if you’re on CPU. JAX wheels are tightly coupled to specific CUDA versions, so check the official JAX documentation for the precise installation command for your environment.
Create a new file named main.py:
# main.py
import os
from functools import partial

import jax
import jax.numpy as jnp
from fastapi import FastAPI
from flax.core import freeze
from flax.serialization import from_bytes
from pydantic import BaseModel
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# --- Configuration ---
MODEL_PATH = "./deployed_model"
MODEL_NAME_OR_PATH = "gpt2"  # The base (causal) model used during Tunix training
# --- End Configuration ---

app = FastAPI(
    title="Tunix LLM Deployment API",
    description="API for inference with a Tunix fine-tuned LLM.",
    version="0.1.0",
)

# Global variables to hold the model and tokenizer
tokenizer = None
model = None
model_params = None


# Define a Pydantic model for our request body
class InferenceRequest(BaseModel):
    prompt: str
    max_length: int = 50
    num_return_sequences: int = 1
    temperature: float = 0.7
    top_k: int = 50
    do_sample: bool = True


class InferenceResponse(BaseModel):
    generated_text: list[str]


@app.on_event("startup")
async def load_model():
    """
    Load the tokenizer and model parameters when the FastAPI application starts up.
    This ensures the model is loaded once and ready for all requests.
    """
    global tokenizer, model, model_params

    print(f"Loading tokenizer from {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print("Tokenizer loaded.")

    print(f"Loading base model architecture for {MODEL_NAME_OR_PATH}...")
    # With `_do_init=False`, `from_pretrained` skips random weight
    # initialization and returns a (model, params) tuple; `model.params`
    # itself is not populated. We load our own fine-tuned weights next.
    model, init_params = FlaxAutoModelForCausalLM.from_pretrained(
        MODEL_NAME_OR_PATH, _do_init=False
    )
    print("Base model architecture loaded.")

    print(f"Loading fine-tuned parameters from {MODEL_PATH}/flax_model.msgpack...")
    with open(os.path.join(MODEL_PATH, "flax_model.msgpack"), "rb") as f:
        params_bytes = f.read()
    # `from_bytes` expects a template PyTree (not a type) plus the raw bytes.
    model_params = freeze(from_bytes(init_params, params_bytes))
    print("Fine-tuned parameters loaded.")

    # Warm up compilation by running generation once with dummy input.
    # This reduces cold-start latency for the first real request; note that
    # JAX recompiles when input shapes or generation settings change.
    print("Performing initial JIT compilation...")
    _ = model.generate(
        input_ids=jnp.array([[tokenizer.bos_token_id or tokenizer.eos_token_id]]),
        params=model_params,
        max_length=5,  # Small max_length for quick compilation
        do_sample=False,
        num_beams=1,
    ).sequences
    print("JIT compilation complete.")
    print("Model and tokenizer ready for inference!")


# JIT-compile the inference function for performance. Generation settings
# shape the compiled computation, so they must be static: we mark them with
# `static_argnames`, and changing any of them triggers a recompile.
@partial(
    jax.jit,
    static_argnames=("max_length", "num_return_sequences", "temperature", "top_k", "do_sample"),
)
def generate_text_jitted(input_ids, params, max_length, num_return_sequences,
                         temperature, top_k, do_sample):
    """
    JIT-compiled function for text generation.
    """
    output_ids = model.generate(
        input_ids=input_ids,
        params=params,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        top_k=top_k,
        do_sample=do_sample,
    ).sequences
    return output_ids


@app.post("/infer", response_model=InferenceResponse)
async def infer(request: InferenceRequest):
    """
    Performs text generation based on the provided prompt and parameters.
    """
    if tokenizer is None or model is None or model_params is None:
        raise RuntimeError("Model and tokenizer not loaded yet. Please wait for startup.")

    # Tokenize the input prompt
    input_ids = tokenizer(
        request.prompt,
        return_tensors="jax",
        max_length=model.config.max_position_embeddings,
        truncation=True,
    ).input_ids

    # Perform generation using the JIT-compiled function
    generated_output_ids = generate_text_jitted(
        input_ids,
        model_params,
        max_length=request.max_length,
        num_return_sequences=request.num_return_sequences,
        temperature=request.temperature,
        top_k=request.top_k,
        do_sample=request.do_sample,
    )

    # Decode the generated tokens back to text
    generated_texts = [
        tokenizer.decode(output_id, skip_special_tokens=True)
        for output_id in generated_output_ids
    ]

    return InferenceResponse(generated_text=generated_texts)
Explanation of main.py:
- Imports: We bring in `jax`, `flax`, `transformers` (for `AutoTokenizer` and `FlaxAutoModelForCausalLM`), `FastAPI`, `Pydantic` (for data validation), and `os`; `flax.serialization` handles the MsgPack format for us.
- Configuration: `MODEL_PATH` points to our saved model, and `MODEL_NAME_OR_PATH` is the identifier for the base model architecture.
- FastAPI App: We initialize our FastAPI application.
- `InferenceRequest` & `InferenceResponse`: These Pydantic models define the expected structure of incoming requests (e.g., `prompt`, `max_length`) and outgoing responses (`generated_text`). This provides automatic validation and clear API documentation.
- `@app.on_event("startup")`: This decorator ensures that the `load_model` function runs once when the FastAPI server starts.
  - Inside `load_model`, we load the tokenizer from `MODEL_PATH`.
  - We then load the `FlaxAutoModelForCausalLM` architecture using `from_pretrained` with `_do_init=False`, because we want to load our own fine-tuned parameters rather than freshly initialized ones.
  - We load our saved `flax_model.msgpack` and deserialize the parameters with `from_bytes`. `freeze` makes the parameters immutable, which is good practice for inference.
  - JIT compilation warm-up: a crucial step is to run generation once with dummy input. This triggers JAX’s JIT compilation during startup, so the first real user request doesn’t pay the full “cold start” cost.
- `jax.jit` for inference: the `generate_text_jitted` function is JIT-compiled for maximum performance. Because the model parameters are passed as an argument, JAX can reuse the compiled graph across requests with the same input shapes and generation settings.
- `@app.post("/infer")`: This defines our API endpoint, a `POST` request to `/infer`.
  - It takes an `InferenceRequest` object, automatically validated by FastAPI.
  - The input `prompt` is tokenized using our loaded tokenizer.
  - `generate_text_jitted` is called to perform the actual LLM inference.
  - The generated token IDs are decoded back into human-readable text.
  - Finally, an `InferenceResponse` object is returned, which FastAPI serializes to JSON.
Step 3: Running the API Locally
To run your API, navigate to the directory containing main.py and deployed_model, and execute:
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
Explanation:
- `uvicorn`: The ASGI server that runs your FastAPI application.
- `main:app`: Tells Uvicorn to find the `app` object in `main.py`.
- `--host 0.0.0.0`: Makes the server accessible from outside your local machine (useful for Docker later).
- `--port 8000`: Runs the server on port 8000.
- `--reload`: (Optional, for development) Automatically reloads the server when code changes. Remove this for production.
You’ll see output indicating the model and tokenizer are loading, followed by the JIT compilation message. Once “Application startup complete.” appears, your API is ready!
Open your browser and go to http://localhost:8000/docs. You’ll see the interactive API documentation provided by FastAPI (Swagger UI). You can try out your /infer endpoint directly from there!
Alternatively, you can send a curl request:
curl -X POST "http://localhost:8000/infer" \
-H "Content-Type: application/json" \
-d '{
"prompt": "The capital of France is",
"max_length": 20,
"temperature": 0.5
}'
You should receive a JSON response with the generated text!
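If you prefer calling the endpoint from Python, here is a small client sketch using only the standard library. It assumes the server from this step is running on `localhost:8000`; the `build_request` helper is just an illustrative name:

```python
# Minimal Python client for the /infer endpoint (standard library only).
import json
import urllib.request

def build_request(prompt: str, max_length: int = 20, temperature: float = 0.5):
    """Build a POST request carrying the JSON body the /infer endpoint expects."""
    payload = {"prompt": prompt, "max_length": max_length, "temperature": temperature}
    return urllib.request.Request(
        "http://localhost:8000/infer",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# Example (with the server running):
# with urllib.request.urlopen(build_request("The capital of France is")) as resp:
#     print(json.loads(resp.read())["generated_text"])
```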
Step 4: Containerization with Docker (Brief Introduction)
For consistent and portable deployment, Docker is invaluable. It packages your application and all its dependencies into a single, isolated “container.”
Create a file named Dockerfile in the same directory as main.py and deployed_model:
# Use an official Python runtime as a parent image
FROM python:3.10-slim
# Set the working directory in the container
WORKDIR /app
# Install system dependencies needed for JAX/TensorFlow if applicable (e.g., CUDA libraries)
# For GPU support, you would use a base image with CUDA pre-installed,
# e.g., FROM nvidia/cuda:12.2.0-cudnn8-runtime-ubuntu22.04
# and then install Python within it.
# For simplicity, we'll assume CPU for this Dockerfile.
# If using GPU, ensure you have the NVIDIA Container Toolkit installed on your host.
# Copy the requirements file and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy the application code and the deployed model
COPY . /app
# Expose the port the API will run on
EXPOSE 8000
# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
And create a requirements.txt file:
fastapi
uvicorn[standard]
jax[cpu]  # CPU-only; for GPU, install the CUDA-enabled JAX build per the official install docs
flax
transformers
pydantic
msgpack
Explanation of Dockerfile:
- The `FROM` line starts from a lightweight Python 3.10 base image.
- `WORKDIR /app`: Sets the working directory inside the container.
- `COPY requirements.txt .` and `RUN pip install -r requirements.txt`: Copy your dependency list and install it.
- `COPY . /app`: Copies all your application files (including `main.py` and the `deployed_model` directory) into the container.
- `EXPOSE 8000`: Informs Docker that the container listens on port 8000.
- `CMD ["uvicorn", "main:app", ...]`: The command that runs when the container starts.
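One practical detail: since `COPY . /app` copies everything in the build context, a `.dockerignore` file keeps caches, virtual environments, and version-control data out of the image. The entries below are typical examples, not requirements; adjust them to your project:

```
# .dockerignore — exclude files that shouldn't end up in the image
__pycache__/
*.pyc
.git/
.venv/
venv/
*.log
```

This keeps images smaller and avoids accidentally shipping local artifacts into production.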
Building and Running the Docker Image:
# 1. Build the Docker image
docker build -t tunix-llm-api .
# 2. Run the Docker container
# For CPU:
docker run -p 8000:8000 tunix-llm-api
# For GPU (assuming NVIDIA Container Toolkit is installed and configured):
# docker run --gpus all -p 8000:8000 tunix-llm-api
Now, your LLM API is running inside a Docker container! You can access it just like before at http://localhost:8000. This container can now be easily moved to any server that has Docker installed, providing a consistent runtime environment.
Mini-Challenge: Batch Inference Endpoint
Our current /infer endpoint handles one prompt at a time. For efficiency, especially with LLMs, it’s often beneficial to process multiple prompts in a single request (batching).
Challenge: Modify the main.py file to add a new endpoint, /batch_infer, that accepts a list of prompts and returns a list of generated texts.
Hint:
- You’ll need a new Pydantic model for the batch request that contains a `list[str]` of prompts.
- The tokenizer can handle a list of strings directly, returning `input_ids` with a batch dimension.
- The `model.generate` function already supports batching if `input_ids` has a batch dimension.
What to Observe/Learn:
After implementing and testing, you should notice that processing multiple prompts in one batch request is generally faster than sending individual requests for each prompt, especially for smaller max_length values, due to reduced overhead and better utilization of your hardware.
Click for Mini-Challenge Solution
# Add to main.py, after the InferenceResponse class
class BatchInferenceRequest(BaseModel):
    prompts: list[str]
    max_length: int = 50
    num_return_sequences: int = 1
    temperature: float = 0.7
    top_k: int = 50
    do_sample: bool = True


@app.post("/batch_infer", response_model=InferenceResponse)  # Reusing InferenceResponse
async def batch_infer(request: BatchInferenceRequest):
    """
    Performs text generation for a batch of prompts.
    """
    if tokenizer is None or model is None or model_params is None:
        raise RuntimeError("Model and tokenizer not loaded yet. Please wait for startup.")

    # GPT-2-style tokenizers have no pad token by default; fall back to EOS.
    # Left padding keeps each prompt adjacent to its generated continuation
    # for decoder-only models.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Tokenize the input prompts (the tokenizer handles lists directly)
    input_ids = tokenizer(
        request.prompts,
        return_tensors="jax",
        padding=True,  # Important for batching; pads to the longest sequence
        max_length=model.config.max_position_embeddings,
        truncation=True,
    ).input_ids

    # Perform generation using the JIT-compiled function
    generated_output_ids = generate_text_jitted(
        input_ids,
        model_params,
        max_length=request.max_length,
        num_return_sequences=request.num_return_sequences,
        temperature=request.temperature,
        top_k=request.top_k,
        do_sample=request.do_sample,
    )

    # Decode the generated tokens back to text.
    # output_ids has shape (batch_size * num_return_sequences, sequence_length).
    generated_texts = [
        tokenizer.decode(output_id, skip_special_tokens=True)
        for output_id in generated_output_ids
    ]

    return InferenceResponse(generated_text=generated_texts)
Test with curl:
curl -X POST "http://localhost:8000/batch_infer" \
-H "Content-Type: application/json" \
-d '{
"prompts": [
"The capital of France is",
"What is the largest ocean on Earth?"
],
"max_length": 25,
"temperature": 0.6
}'
Common Pitfalls & Troubleshooting
Out of Memory (OOM) Errors:
- Problem: Your GPU/TPU runs out of memory, especially with large models or high `max_length`/batch size.
- Solution:
  - Reduce the batch size in your API requests.
  - Reduce `max_length` for generation.
  - Consider model quantization (e.g., to `bfloat16` or `int8`) if supported by Tunix/Flax/JAX and your hardware. Tunix often works with `bfloat16` by default on TPUs.
  - Upgrade to a more powerful GPU/TPU.
  - Use techniques like offloading parts of the model to CPU memory if latency is less critical.

Cold Start Latency:
- Problem: The very first request to your API takes significantly longer than subsequent requests due to JAX’s JIT compilation.
- Solution: As shown in our `main.py`, calling the generation function with dummy inputs during application startup (`@app.on_event("startup")`) is an effective way to warm up the compilation cache. Keep in mind that JAX recompiles when input shapes change, so padding inputs to a fixed length can also help.

Dependency Mismatch/Environment Issues:
- Problem: Your model runs fine during training but fails in deployment. This is often due to different versions of JAX, Flax, Transformers, or Python itself.
- Solution:
  - Always use a `requirements.txt` (or `pyproject.toml`) to pin exact versions of all dependencies.
  - Use Docker or other containerization technologies to ensure the deployment environment is identical to your testing environment.
  - Verify that the CUDA/cuDNN versions on your deployment machine match what your JAX build expects.

Slow Inference (Beyond Cold Start):
- Problem: Even after warm-up, inference is too slow.
- Solution:
  - Ensure JAX is correctly configured to use your GPU/TPU (check with `print(jax.devices())`).
  - Optimize your model (e.g., knowledge distillation, or a smaller architecture if possible).
  - Utilize batching effectively (as in our mini-challenge).
  - Consider hardware upgrades or distributed inference setups for very high throughput.
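A quick diagnostic you can run inside the deployment environment (or the container) to confirm which backend JAX actually sees, assuming JAX is installed:

```python
# Confirm which backend and devices JAX is using.
import jax

devices = jax.devices()
print("backend:", jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print("devices:", devices)

# If you expected a GPU/TPU but only see CPU devices here, the
# accelerator-enabled JAX build is missing or cannot find your drivers.
```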
Summary
Congratulations! You’ve successfully navigated the waters of LLM deployment. In this chapter, you learned:
- The unique challenges of deploying large language models, including computational demands and latency.
- How to prepare your Tunix-fine-tuned model for deployment by saving its tokenizer and parameters.
- How to build a robust and efficient inference API using FastAPI, with Pydantic for request validation and JAX’s JIT compilation for performance.
- The importance of “warming up” your JAX functions during application startup to reduce cold start latency.
- A brief introduction to containerizing your application with Docker for consistent and portable deployment.
- Strategies for troubleshooting common deployment pitfalls like OOM errors and slow inference.
Deploying LLMs is a critical step in bringing your AI projects to fruition. With these foundational skills, you’re well-equipped to make your fine-tuned Tunix models accessible and impactful.
In the next chapter, we’ll dive into advanced MLOps practices for Tunix, including monitoring, logging, and potentially A/B testing, to ensure your deployed models remain performant and relevant over time.
References
- FastAPI Official Documentation
- Uvicorn Official Documentation
- JAX Official Documentation
- Flax Official Documentation
- Hugging Face Transformers - Flax Models
- Docker Official Documentation
This page is AI-assisted and reviewed. It references official documentation and recognized resources where relevant.