Creating Inference Checkpoints
Overview
This document outlines the process for converting model checkpoints for use with various model servers, such as Jetstream with a MaxText or Pytorch/XLA backend. The core of this process is the `checkpoint_entrypoint.sh` script, packaged within a Docker container, which handles the specific conversion steps required by each server configuration. The goal is to prepare your trained model checkpoints for efficient deployment and inference serving.
Checkpoint creation
The `checkpoint_entrypoint.sh` script converts your inference checkpoint for the model server of your choice. Follow the steps below to build the conversion image and run it.
- Clone the `ai-on-gke/tutorials-and-examples` repository:

  ```bash
  git clone https://github.com/ai-on-gke/tutorials-and-examples
  cd tutorials-and-examples/inference-servers/checkpoints
  ```
- Build the Docker image that contains the conversion script and its dependencies, then tag the image and push it to a container registry (such as Google Container Registry, GCR) accessible from your execution environment (e.g., your Kubernetes cluster):

  ```bash
  docker build -t inference-checkpoint .
  docker tag inference-checkpoint gcr.io/${PROJECT_ID}/inference-checkpoint:latest
  docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest
  ```
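  The commands above assume the `PROJECT_ID` environment variable is set. One way to populate it from your active gcloud configuration (assuming the gcloud CLI is installed and authenticated):

  ```bash
  # Set PROJECT_ID from the currently active gcloud configuration.
  export PROJECT_ID=$(gcloud config get-value project)
  ```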
- Run the conversion as a containerized job, for example a Kubernetes Job. Configure the job to use the `gcr.io/${PROJECT_ID}/inference-checkpoint:latest` image and pass the arguments required by your target inference server and checkpoint, as listed below (a sample Job manifest is sketched after the argument lists).

Jetstream + MaxText
- `-s=INFERENCE_SERVER`
- `-b=BUCKET_NAME`
- `-m=MODEL_PATH`
- `-v=VERSION` (Optional)
Jetstream + Pytorch/XLA
- `-s=INFERENCE_SERVER`
- `-m=MODEL_PATH`
- `-n=MODEL_NAME`
- `-q=QUANTIZE_WEIGHTS` (Optional) (default=False)
- `-t=QUANTIZE_TYPE` (Optional) (default=int8_per_channel)
- `-v=VERSION` (Optional) (default=jetstream-v0.2.3)
- `-i=INPUT_DIRECTORY` (Optional)
- `-o=OUTPUT_DIRECTORY`
- `-h=HUGGINGFACE` (Optional) (default=False)
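As a concrete illustration, here is a minimal sketch of a Kubernetes Job for the Jetstream + Pytorch/XLA path. The job name, model name, and GCS paths below are placeholder assumptions, not values prescribed by this guide; substitute your own:

```bash
# Minimal sketch of a checkpoint-conversion Job (names and GCS paths are placeholders).
kubectl apply -f - <<EOF
apiVersion: batch/v1
kind: Job
metadata:
  name: checkpoint-converter   # hypothetical name
spec:
  backoffLimit: 0
  template:
    spec:
      restartPolicy: Never
      containers:
      - name: inference-checkpoint
        image: gcr.io/${PROJECT_ID}/inference-checkpoint:latest
        args:
        - -s=jetstream-pytorch
        - -m=gs://my-bucket/base-checkpoint/        # placeholder GCS path
        - -n=llama-2
        - -o=gs://my-bucket/converted-checkpoint/   # placeholder GCS path
EOF
```

Note that the pod must be able to read and write the GCS paths involved; on GKE this is typically granted through Workload Identity or a node service account with Storage access.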
Argument descriptions:
The following table details the arguments accepted by the `checkpoint_entrypoint.sh` script:
| Argument | Flag | Type | Description |
|---|---|---|---|
| `BUCKET_NAME` | `-b` | str | Google Cloud Storage (GCS) bucket name, without the `gs://` prefix. Required for certain server types, such as Jetstream + MaxText. |
| `INFERENCE_SERVER` | `-s` | str | Specifies the target inference server. Examples: `jetstream-maxtext`, `jetstream-pytorch`. |
| `MODEL_PATH` | `-m` | str | Path to the base model checkpoint. Its interpretation varies depending on the inference server and on whether it is a local path or a GCS path. |
| `MODEL_NAME` | `-n` | str | Name of the model architecture (e.g., `llama-2`, `llama-3`, `gemma`). Required for servers such as Jetstream + Pytorch/XLA. |
| `HUGGINGFACE` | `-h` | bool | (Optional) Set to `true` if the input checkpoint is in HuggingFace Hub format. Default: `false`. |
| `QUANTIZE_WEIGHTS` | `-q` | str | (Optional) Set to `true` to enable weight quantization. Default: `false`. |
| `QUANTIZE_TYPE` | `-t` | str | (Optional) Specifies the quantization type when `QUANTIZE_WEIGHTS` is `true`. Available types include `int8_per_channel` and `int4_blockwise`. Default: `int8_per_channel`. |
| `VERSION` | `-v` | str | (Optional) Overrides the default version of the inference server components being used (e.g., `jetstream-v0.2.2`, `jetstream-v0.2.3`). |
| `INPUT_DIRECTORY` | `-i` | str | (Optional) Explicitly specifies the input directory for the checkpoint, often a GCS path (`gs://...`). May override parts of `MODEL_PATH`. |
| `OUTPUT_DIRECTORY` | `-o` | str | Specifies the directory where the converted checkpoint should be saved, often a GCS path (`gs://...`). Required for servers such as Jetstream + Pytorch/XLA. |
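For a quick sanity check outside the cluster, the same image can be run directly with Docker. This is a sketch under the assumption that the image's entrypoint is `checkpoint_entrypoint.sh` (so extra `docker run` arguments are passed to it as flags) and that mounting your local gcloud credentials gives the container access to GCS; the bucket and paths are placeholders:

```bash
# Sketch of a local conversion run (bucket and paths are placeholders).
docker run --rm \
  -v "$HOME/.config/gcloud:/root/.config/gcloud" \
  gcr.io/${PROJECT_ID}/inference-checkpoint:latest \
  -s=jetstream-maxtext \
  -b=my-bucket \
  -m=gs://my-bucket/base-checkpoint/
```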