
Overview

pi0 is a vision-language-action model that combines a PaliGemma-style image and language prefix with a Gemma action expert. In PhyAI, the ws1 path runs the full end-to-end inference loop on one GPU: encode cameras and task text, cache the prefix, condition the expert on robot state, and integrate the action chunk with flow matching. This page describes the single-card implementation. There is no tensor parallelism, no continuous batching, and no preemption in PI0WS1Scheduler.
pi0 differs from pi0.5 in how robot state enters the model. pi0 keeps state as a numeric expert-side token. pi0.5 folds discretized state bins into the language prompt.
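The difference can be illustrated with a minimal sketch. It is not PhyAI code: the helper names, the 256-bin count, the [-1, 1] value range, and the prompt format are all assumptions chosen for illustration.

```python
def pad_state(state, max_state_dim=32):
    """pi0-style: keep the state numeric and zero-pad it to a fixed width."""
    if len(state) > max_state_dim:
        raise ValueError("state is wider than max_state_dim")
    return list(state) + [0.0] * (max_state_dim - len(state))


def discretize_state(state, num_bins=256, low=-1.0, high=1.0):
    """pi0.5-style: map each normalized value to an integer bin index.

    The bin count and value range are assumptions for this sketch.
    """
    bins = []
    for value in state:
        clipped = min(max(value, low), high)
        index = int((clipped - low) / (high - low) * num_bins)
        bins.append(min(index, num_bins - 1))  # value == high lands in the last bin
    return bins


def state_prompt(task, state):
    """Hypothetical prompt format that appends the bin indices to the task text."""
    return f"{task} state: " + " ".join(str(b) for b in discretize_state(state))


state = [-1.0, 0.0, 0.5, 1.0]
expert_state = pad_state(state)  # numeric vector behind the expert-side state token
prompt = state_prompt("pick up the object", state)  # state folded into the language prefix
```

In the first case the state never touches the tokenizer; in the second it lengthens the language prefix instead of adding an expert-side token.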

Architecture

PhyAI structures pi0 the same way as the other model integrations. The pi0 path is split across configuration, model modules, runners, and a scheduler:
phyai/src/phyai/models/pi0/
    main_pi0.py
    scheduler_ws1_pi0.py
    model_runner_pi0.py
    modeling_pi0.py
    configuration_pi0.py
| Component | Responsibility |
| --- | --- |
| PI0Entry | Registers the "pi0" engine plugin, builds PI0Model, loads weights, and creates the scheduler |
| PI0Config | Stores vision, text, expert, action chunk, tokenizer, and camera-count geometry |
| PI0Model | Owns the SigLIP/PaliGemma vision tower, PaliGemma text stack, Gemma expert stack, RoPE, and action/time heads |
| PI0VisionRunner | Runs the vision tower, with optional CUDA graph capture |
| PI0LLMRunner | Runs the PaliGemma prefix pass and writes prefix K/V into the shared cache |
| PI0ExpertRunner | Runs the expert state/action passes for each flow-matching step |
| PI0WS1Scheduler | Orchestrates one complete inference request on a single GPU |

Model layout

PI0Model is built from three major stacks:
| Stack | Default shape | Notes |
| --- | --- | --- |
| Vision | SigLIP, 27 layers, 224×224 images, 14×14 patches | Produces image tokens projected into the PaliGemma text width |
| Text | PaliGemma/Gemma, 18 layers, hidden size 2048 | Processes image + language prefix and writes prefix K/V |
| Expert | Gemma action expert, 18 layers, hidden size 1024 | Processes one state token plus the full action chunk |
The top-level config defaults are:
| Field | Default | Meaning |
| --- | --- | --- |
| chunk_size | 50 | Number of action tokens returned per engine step |
| max_state_dim | 32 | Padded robot-state width |
| max_action_dim | 32 | Padded action width |
| num_inference_steps | 10 | Flow-matching Euler steps |
| tokenizer_max_length | 48 | Right-padded PaliGemma task prompt length |
| empty_cameras | 0 | num_images = 3 - empty_cameras; pi0 supports 2 or 3 cameras |
The model uses params_dtype for the language and expert stacks. The vision tower has a separate vision_params_dtype, which defaults to fp32 for reference parity. Set PI0Args(vision_params_dtype=torch.bfloat16) only when you intentionally want bf16 vision execution.
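To make the default geometry concrete, the sketch below derives token counts from the values listed above. It assumes that "14×14 patches" means a 14-pixel patch size, that each patch yields exactly one image token, and that the longest prefix is simply image tokens plus the padded prompt length. The function name and the returned keys are hypothetical.

```python
def token_geometry(
    image_size=224,
    patch_size=14,
    empty_cameras=0,
    tokenizer_max_length=48,
    chunk_size=50,
):
    """Derive token counts from the config defaults (assumed layout, see lead-in)."""
    num_images = 3 - empty_cameras
    patches_per_image = (image_size // patch_size) ** 2  # 16 * 16 = 256 by default
    image_tokens = num_images * patches_per_image
    return {
        "num_images": num_images,
        "patches_per_image": patches_per_image,
        "image_tokens": image_tokens,
        # Longest possible prefix: all image tokens plus a full-length prompt.
        "max_prefix_len": image_tokens + tokenizer_max_length,
        # One state token followed by the action chunk.
        "suffix_len": 1 + chunk_size,
    }


defaults = token_geometry()
two_cameras = token_geometry(empty_cameras=1)
```

Under these assumptions the defaults give 768 image tokens, a prefix of at most 816 tokens, and a 51-token suffix. Dropping one camera removes 256 prefix tokens.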

Request contract

PI0Request is the scheduler’s canonical input:
| Field | Shape | Notes |
| --- | --- | --- |
| pixel_values | (B, num_images, 3, image_size, image_size) | Already resized and normalized camera tensors |
| input_ids | (B, tokenizer_max_length) int64 | Right-padded PaliGemma token ids |
| lang_lens | (B,) int64 | Real task-prompt length for each sample |
| state | (B, max_state_dim) | Numeric robot state, padded before the expert |
| noise | (B, chunk_size, max_action_dim) or None | Optional initial action noise; when None, the scheduler samples Gaussian noise |
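The relationship between input_ids and lang_lens can be sketched with plain Python lists. The helper names, the pad id of 0, and the token ids are assumptions for illustration. In practice PI0Processor produces these tensors.

```python
def right_pad(token_ids, max_length=48, pad_id=0):
    """Pad one token sequence on the right to a fixed length (pad_id is assumed)."""
    if len(token_ids) > max_length:
        raise ValueError("prompt is longer than tokenizer_max_length")
    return list(token_ids) + [pad_id] * (max_length - len(token_ids))


def build_language_inputs(batch_token_ids, max_length=48, pad_id=0):
    """Return (input_ids, lang_lens) for a batch of variable-length prompts."""
    input_ids = [right_pad(ids, max_length, pad_id) for ids in batch_token_ids]
    lang_lens = [len(ids) for ids in batch_token_ids]  # real length before padding
    return input_ids, lang_lens


# Made-up token ids for two prompts of different lengths.
input_ids, lang_lens = build_language_inputs([[2, 101, 102], [2, 201, 202, 203, 204]])
```

Every row of input_ids has the same width, while lang_lens records how many leading entries in each row are real tokens.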
B can be any value in [1, max_batch_size]. The scheduler pads smaller batches to max_batch_size internally and slices the result back to actual_B before returning.
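The pad-then-slice behavior can be sketched as follows. This is an illustration with nested lists, not the scheduler's code: the helper names are hypothetical and zero-filled padding rows are an assumption.

```python
def pad_batch(rows, max_batch_size):
    """Pad a batch of equal-width rows up to max_batch_size with zero rows."""
    actual_b = len(rows)
    if not 1 <= actual_b <= max_batch_size:
        raise ValueError("batch size must be in [1, max_batch_size]")
    width = len(rows[0])
    padding = [[0.0] * width for _ in range(max_batch_size - actual_b)]
    return list(rows) + padding, actual_b


def run_fixed_batch(rows, max_batch_size, model_fn):
    """Run model_fn on a fixed-size batch and return only the real rows."""
    padded, actual_b = pad_batch(rows, max_batch_size)
    outputs = model_fn(padded)  # model_fn always sees max_batch_size rows
    return outputs[:actual_b]   # slice the result back to the caller's batch


def toy_model(batch):
    """Stand-in for the model: doubles every value."""
    return [[2.0 * v for v in row] for row in batch]


result = run_fixed_batch([[1.0, 2.0], [3.0, 4.0]], max_batch_size=4, model_fn=toy_model)
```

Keeping the internal batch at a fixed size is what lets the captured graph shapes stay constant while callers vary B.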

Scheduler phases

One engine.step(request) maps to the following scheduler phases:
| Phase | Work |
| --- | --- |
| pi0.vision_loop | Move camera tensors to the vision dtype and run PI0VisionRunner once per real batch item |
| pi0.lang_pack | Embed language ids, then pack image tokens and language tokens into the per-sample prefix buffer |
| pi0.llm_prefix_plan | Reset static caches and prepare ragged prefix attention metadata |
| pi0.llm_prefix_fwd | Run the PaliGemma text stack and write prefix K/V into KVCachePool |
| pi0.expert_plan | Prepare state and action expert attention metadata over prefix + suffix slots |
| pi0.expert_loop | Initialize or copy action noise and run flow-matching integration |
| pi0.expert_step | One expert velocity prediction and Euler update inside pi0.expert_loop |
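The expert_loop and expert_step phases amount to fixed-step Euler integration of a velocity field. The sketch below assumes the time convention of the pi0 reference implementation, where integration starts from noise at t = 1 and ends at t = 0. The toy velocity function stands in for the expert, and all names are hypothetical.

```python
def euler_integrate(velocity_fn, noise, num_steps=10):
    """Integrate x from t=1 (noise) to t=0 with fixed Euler steps.

    The time direction is an assumption taken from the pi0 reference convention.
    """
    dt = -1.0 / num_steps
    x = list(noise)
    for step in range(num_steps):
        t = 1.0 + step * dt
        v = velocity_fn(x, t)  # one "expert step": predict the velocity at (x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]  # Euler update
    return x


# Toy velocity field for the linear path x_t = t * noise + (1 - t) * target,
# whose velocity is noise - target = (x_t - target) / t.
target = [0.25, -0.5, 1.0]


def toy_velocity(x, t):
    return [(xi - ai) / t for xi, ai in zip(x, target)]


actions = euler_integrate(toy_velocity, noise=[1.5, 0.3, -2.0])
```

With this toy field the integration lands on the target up to floating-point error. A learned expert only approximates the velocity, which is why the step count trades accuracy against latency.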
The prefix tokens are cached once per request. The expert then attends over:
state query  -> prefix + state
action query -> prefix + state + action chunk
This is why pi0’s suffix length is 1 + chunk_size: one state token followed by the action tokens.
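The visibility rules above can be written as a boolean mask with one row per suffix query and one column per key. This is a simplified sketch that ignores prefix padding and batching; the function name and layout are assumptions.

```python
def suffix_visibility(prefix_len, chunk_size):
    """Return a (1 + chunk_size) x (prefix_len + 1 + chunk_size) visibility mask.

    Key order is assumed to be: prefix tokens, the state token, then action tokens.
    """
    total = prefix_len + 1 + chunk_size
    # The state query sees the prefix and itself, but no action tokens.
    state_row = [key <= prefix_len for key in range(total)]
    # Every action query sees the prefix, the state token, and the whole chunk.
    action_rows = [[True] * total for _ in range(chunk_size)]
    return [state_row] + action_rows


mask = suffix_visibility(prefix_len=4, chunk_size=3)
for row in mask:
    print("".join("x" if visible else "." for visible in row))
```

The first printed row stops at the state column, and the remaining rows are fully visible.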

CUDA graphs

When RuntimeConfig(use_cuda_graph=True), the pi0 runners capture CUDA graphs during scheduler.setup():
| Runner | Captured shape |
| --- | --- |
| PI0VisionRunner | (num_images, 3, image_size, image_size) |
| PI0LLMRunner | (max_batch_size * n_per_sample, text_hidden_size) |
| PI0ExpertRunner | state, x_t, and time buffers at fixed max_batch_size |
During scheduler.step(), the runners update static graph input buffers and replay the captured graphs. Attention metadata is staged outside the captured region through the attention backend’s capture-aware metadata buffers.
Disable CUDA graphs when you want a more detailed Nsight Systems trace:
uv run python benchmark/bench_n_batch_ws1_pi0.py \
    --batch-sizes 4 \
    --no-cuda-graph

Running pi0

1. Prepare weights

Prepare an HF-style pi0 PyTorch checkpoint directory that contains config.json and model.safetensors. For random-weight smoke tests with the example script, you can omit --checkpoint.
2. Construct the engine

The plugin name is "pi0". The engine handles model construction, optional weight loading, runner setup, and CUDA graph capture.
import torch
from pathlib import Path

from phyai.engine import Engine, EngineArgs
from phyai.engine_config import DeviceConfig, EngineConfig, RuntimeConfig
from phyai.models.pi0.main_pi0 import PI0Args

engine = Engine(
    EngineArgs(
        plugin="pi0",
        plugin_args=PI0Args(
            checkpoint_dir=Path("/path/to/pi0_pytorch"),
            max_batch_size=4,
            vision_params_dtype=torch.float32,
        ),
        config=EngineConfig(
            device=DeviceConfig(target="cuda", params_dtype=torch.bfloat16),
            runtime=RuntimeConfig(use_cuda_graph=True),
        ),
    )
)
max_batch_size fixes the captured graph shapes. Rebuild the engine if you need a different maximum batch.
3. Build a request

Use PI0Processor to convert raw robot observations into model-ready tensors. The processor lives outside the engine in phyai-utils-tools.
import torch

from phyai.models.pi0.scheduler_ws1_pi0 import PI0Request
from phyai_utils_tools.models.pi0 import PI0Processor

processor = PI0Processor(
    image_size=224,
    num_channels=3,
    num_images=3,
    tokenizer_max_length=48,
    max_state_dim=32,
    action_dim=7,
    device="cuda",
    params_dtype=torch.bfloat16,
)

# cam0, cam1, cam2, and state are placeholders for your raw camera frames and robot state.
processed = processor.preprocess(
    {
        "images": [cam0, cam1, cam2],
        "task": ["pick up the object"],
        "state": state,
    }
)

request = PI0Request(
    pixel_values=processed.pixel_values,
    input_ids=processed.input_ids,
    lang_lens=processed.lang_lens,
    state=processed.state,
)
4. Run one step

actions = engine.step(request)  # (B, chunk_size, max_action_dim)
If you constructed a processor with action_dim, call processor.postprocess(actions) to slice the padded action width and unnormalize actions when dataset stats are available.
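The postprocessing described here can be sketched as slicing followed by optional unnormalization. This is not the PI0Processor implementation: the function name is hypothetical, and mean/std statistics are an assumed normalization scheme that may differ from the one your dataset uses.

```python
def postprocess_actions(actions, action_dim, mean=None, std=None):
    """Slice padded actions to action_dim and optionally undo z-score normalization.

    actions is a nested list shaped (B, chunk_size, max_action_dim).
    mean and std are per-dimension lists of length action_dim (assumed format).
    """
    result = []
    for chunk in actions:
        new_chunk = []
        for step in chunk:
            sliced = step[:action_dim]  # drop the padded tail of the action width
            if mean is not None and std is not None:
                sliced = [a * s + m for a, s, m in zip(sliced, std, mean)]
            new_chunk.append(sliced)
        result.append(new_chunk)
    return result


padded = [[[1.0, -1.0, 0.0, 0.0]]]  # B=1, chunk_size=1, max_action_dim=4
robot_actions = postprocess_actions(padded, action_dim=2, mean=[10.0, 20.0], std=[2.0, 4.0])
```

Without statistics the function only slices, which matches the case where dataset stats are not available.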
5. Close the engine

engine.close()

End-to-end example

examples/pi0/run_pi0.py exercises both raw and processor-backed request paths:
uv run python examples/pi0/run_pi0.py \
    --checkpoint /path/to/pi0_pytorch \
    --batch-size 1
For a random-weight smoke test, omit --checkpoint:
uv run python examples/pi0/run_pi0.py --raw --batch-size 1
Use --num-images 2 when your checkpoint uses one empty camera:
uv run python examples/pi0/run_pi0.py \
    --checkpoint /path/to/pi0_pytorch \
    --num-images 2

Benchmarking and profiling

benchmark/bench_n_batch_ws1_pi0.py sweeps batch sizes and can open a tight profile window for Nsight Systems:
uv run python benchmark/bench_n_batch_ws1_pi0.py \
    --batch-sizes 1 2 4 \
    --n-warmup 5 \
    --n-timed 30 \
    --result-file ./pi0_ws1_results.jsonl
Nsight Systems capture:
nsys profile \
    --capture-range=cudaProfilerApi \
    --capture-range-end=stop \
    -o ./prof/pi0_ws1 \
    uv run python benchmark/bench_n_batch_ws1_pi0.py \
        --batch-sizes 4 \
        --profile-backend nsys \
        --profile-start-step 5 \
        --profile-num-steps 3
Set --vision-dtype bfloat16 only when you intentionally want bf16 vision timing. The default keeps the vision tower in fp32.

Current limitations

  • This path is single-GPU only.
  • max_batch_size is fixed at engine construction.
  • The vision tower is replayed once per real batch item.
  • The scheduler expects already preprocessed tensors. Image resize, tokenization, state padding, and action unnormalization belong to PI0Processor.
  • CUDA graph capture is shape-fixed. To change the camera count, image size, tokenizer length, or maximum batch size, rebuild the engine.

Full example

from pathlib import Path

import torch

from phyai.engine import Engine, EngineArgs
from phyai.engine_config import DeviceConfig, EngineConfig, RuntimeConfig
from phyai.models.pi0.configuration_pi0 import PI0Config
from phyai.models.pi0.main_pi0 import PI0Args
from phyai.models.pi0.scheduler_ws1_pi0 import PI0Request
from phyai.utils import load_config

CHECKPOINT_DIR = Path("/path/to/pi0_pytorch")
BATCH_SIZE = 1

cfg = load_config(CHECKPOINT_DIR, PI0Config)
device = torch.device("cuda")
dtype = torch.bfloat16

engine = Engine(
    EngineArgs(
        plugin="pi0",
        plugin_args=PI0Args(
            checkpoint_dir=CHECKPOINT_DIR,
            max_batch_size=BATCH_SIZE,
            vision_params_dtype=torch.float32,
        ),
        config=EngineConfig(
            device=DeviceConfig(target="cuda", params_dtype=dtype),
            runtime=RuntimeConfig(use_cuda_graph=True),
        ),
    )
)

try:
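    # Synthetic prompt: one real token per sample, the rest is right padding.
    # Token id 2 is assumed to be the tokenizer's BOS id.
    input_ids = torch.zeros(
        BATCH_SIZE, cfg.tokenizer_max_length, dtype=torch.int64, device=device
    )
    input_ids[:, 0] = 2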

    request = PI0Request(
        pixel_values=torch.rand(
            BATCH_SIZE,
            cfg.num_images,
            cfg.vision.num_channels,
            cfg.vision.image_size,
            cfg.vision.image_size,
            dtype=torch.float32,
            device=device,
        ),
        input_ids=input_ids,
        lang_lens=torch.ones(BATCH_SIZE, dtype=torch.int64, device=device),
        state=torch.rand(BATCH_SIZE, cfg.max_state_dim, dtype=dtype, device=device),
    )

    actions = engine.step(request)
    print(f"action chunk shape={tuple(actions.shape)}, dtype={actions.dtype}")
finally:
    engine.close()