
Overview

Cosmos3-Nano-Policy-DROID is the policy model in the Cosmos3 family. Cosmos3 itself is an omnimodal world model for Physical AI; the policy variant takes a language instruction plus an observation from the DROID robot platform and produces robot action trajectories for manipulation and control.

This page covers the multi-GPU Cosmos3 policy path, exposed through the cosmos3_policy_wn plugin. It supports three modes: policy, forward_dynamics, and inverse_dynamics. The video latent and the action latent advance through the same denoising loop, and the final output is an action chunk. If decode_video=True, the plugin also returns rollout video.

PhyAI currently supports two kinds of parallelism for the policy transformer in this path, plus a spatial split for VAE decode:
  • The policy transformer runs tensor parallelism on the tp axis.
  • When cfg=2 and guidance_scale > 1, the cond and uncond CFG branches run in parallel on two TP groups.
  • Rollout video VAE decode is split into spatial tiles across ranks, with halo overlap used to stitch tile boundaries.
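The halo-tiled VAE decode can be pictured with a minimal one-dimensional sketch: each rank owns a contiguous core band of rows, reads a few extra rows (the halo) from its neighbours so its decode sees the context across the boundary, and keeps only its core band when the tiles are stitched back together. The helper names `tile_bounds` and `stitch`, the split along a single axis, and the halo width are illustrative assumptions, not the plugin's actual tiling code; the real split also distributes along the cfg and tp axes.

```python
from typing import List, Sequence, Tuple

# (core_start, core_end, read_start, read_end) for one tile
Bounds = Tuple[int, int, int, int]


def tile_bounds(length: int, num_tiles: int, halo: int) -> List[Bounds]:
    """Split [0, length) into num_tiles core ranges, each widened by `halo`."""
    if not 1 <= num_tiles <= length:
        raise ValueError("num_tiles must be in [1, length]")
    base, extra = divmod(length, num_tiles)
    bounds: List[Bounds] = []
    start = 0
    for i in range(num_tiles):
        # The first `extra` tiles take one extra row so every row is covered.
        end = start + base + (1 if i < extra else 0)
        # Read into the neighbours by `halo`, clamped to the tensor edges.
        bounds.append((start, end, max(0, start - halo), min(length, end + halo)))
        start = end
    return bounds


def stitch(tiles: Sequence[Sequence], bounds: Sequence[Bounds]) -> list:
    """Crop each decoded tile back to its core band and concatenate in order."""
    out: list = []
    for tile, (core_start, core_end, read_start, _read_end) in zip(tiles, bounds):
        offset = core_start - read_start  # halo rows at the front of this tile
        out.extend(tile[offset: offset + (core_end - core_start)])
    return out
```

For example, `tile_bounds(60, 4, 2)` gives the second rank the core rows 15..30 while it reads rows 13..32; after "decoding" (here, the identity), `stitch` recovers the original 60 rows exactly.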

Modes and output

The clean / noisy rules for the three modes are:
| Mode | Clean video | Clean action | Generation target |
|---|---|---|---|
| policy | Latent frame 0 by default, or the frames listed in cond_frame_indexes | None | Action chunk, optional rollout video |
| forward_dynamics | Latent frame 0 by default, or the frames listed in cond_frame_indexes | All action steps in cond_action | Rollout video |
| inverse_dynamics | All video latent frames by default, or the frames listed in cond_frame_indexes | None | Action chunk that explains the observation transition |
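The table can be read as a per-mode rule for which latent frames and action steps are held clean during denoising. The sketch below encodes exactly those rules as boolean masks; the function name `clean_masks` and its signature are hypothetical, chosen only to make the table executable.

```python
from typing import List, Optional, Sequence, Tuple


def clean_masks(
    mode: str,
    t_lat: int,
    action_chunk: int,
    cond_frame_indexes: Optional[Sequence[int]] = None,
) -> Tuple[List[bool], List[bool]]:
    """Return (video_clean, action_clean) masks following the mode table."""
    if mode in ("policy", "forward_dynamics"):
        default_frames = [0]  # latent frame 0 is the observation
    elif mode == "inverse_dynamics":
        default_frames = list(range(t_lat))  # every latent frame is observed
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    # cond_frame_indexes, when given, replaces the per-mode default.
    frames = set(default_frames if cond_frame_indexes is None else cond_frame_indexes)
    if any(not 0 <= f < t_lat for f in frames):
        raise ValueError("cond_frame_indexes must index latent frames")

    video_clean = [i in frames for i in range(t_lat)]
    # Only forward_dynamics conditions on a clean action (all steps of cond_action).
    action_clean = [mode == "forward_dynamics"] * action_chunk
    return video_clean, action_clean
```

With five latent frames, `clean_masks("inverse_dynamics", 5, 16, (0, 1))` keeps frames 0 and 1 clean and leaves the rest, along with the whole action chunk, to be generated.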
Action is always returned. When decode_video=True, the plugin also returns video latent and decoded pixels.
| Key | Shape / Type | Meaning |
|---|---|---|
| action | [B, action_chunk, raw_action_dim] | Padding tail already removed |
| video | [B, C, t_lat, h_lat, w_lat] | Rollout / denoised video latent |
| pixels | [B, 3, T, H, W], optional | Returned only when decode_video=True and the checkpoint includes vae/ |
Cosmos3 policy uses an internal action_dim=64. The real robot action width is raw_action_dim, and the scheduler removes padding before returning output.
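The padding round trip amounts to widening each action step to the internal width and slicing it back to raw_action_dim on the way out. This is a minimal sketch on plain Python lists; zero fill on the right is an assumption (the page only says that padding is added and removed), and the helper names `pad_action` and `strip_padding` are hypothetical.

```python
from typing import List, Sequence

ACTION_DIM = 64  # internal action width stated on this page


def pad_action(raw: Sequence[Sequence[float]], action_dim: int = ACTION_DIM) -> List[List[float]]:
    """Right-pad every action step with zeros up to action_dim (zero fill is assumed)."""
    padded = []
    for step in raw:
        if len(step) > action_dim:
            raise ValueError(f"action step has {len(step)} dims, more than {action_dim}")
        padded.append(list(step) + [0.0] * (action_dim - len(step)))
    return padded


def strip_padding(padded: Sequence[Sequence[float]], raw_action_dim: int) -> List[List[float]]:
    """Drop the padding tail so each step is raw_action_dim wide again."""
    return [list(step[:raw_action_dim]) for step in padded]
```

A 16-step chunk with an illustrative raw width of 7 becomes 16 x 64 after `pad_action`, and `strip_padding(padded, 7)` returns the original values.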

Parallel topology

The example below uses TP=4, CFG=2, and world_size=8. Ranks 0-3 form the TP group for the cond branch, and ranks 4-7 form the TP group for the uncond branch. Within each denoising step, the four TP ranks of a branch run the transformer forward together rather than as a serial pipeline.

[Figure: Cosmos3 WN TP=4 CFG=2 eight-GPU parallel topology]

P.all_gather(axis="cfg") uses the parallel mesh created during engine initialization. ParallelConfig(world_size=cfg_size * tp_size, cfg_size=cfg_size, tp_size=tp_size) maps each rank to (cfg_rank, tp_rank). Gathering along the cfg axis collects only ranks that share the same tp_rank but differ in cfg_rank, so each TP shard receives both the cond and uncond velocity and can complete the CFG combine locally.

The VAE eight-GPU split is shown below, with cfg as the outer axis and tp as the inner axis:

[Figure: Cosmos3 WAN VAE eight-GPU spatial split]
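The rank layout described above (cfg outer, tp inner) can be checked with a few lines of arithmetic. This sketch is not PhyAI's mesh code; the helper names are hypothetical, and it only reproduces the mapping the text states: ranks 0-3 and 4-7 are the two TP groups, and a cfg-axis gather pairs ranks with the same tp_rank.

```python
from typing import List, Tuple


def rank_to_coords(rank: int, cfg_size: int, tp_size: int) -> Tuple[int, int]:
    """Map a global rank to (cfg_rank, tp_rank), with cfg as the outer axis."""
    if not 0 <= rank < cfg_size * tp_size:
        raise ValueError("rank out of range for cfg_size * tp_size")
    return divmod(rank, tp_size)


def tp_group(rank: int, cfg_size: int, tp_size: int) -> List[int]:
    """Ranks that shard the transformer together for this rank's CFG branch."""
    cfg_rank, _ = rank_to_coords(rank, cfg_size, tp_size)
    return [cfg_rank * tp_size + t for t in range(tp_size)]


def cfg_group(rank: int, cfg_size: int, tp_size: int) -> List[int]:
    """Ranks reached by an all-gather on the cfg axis: same tp_rank, every cfg_rank."""
    _, tp_rank = rank_to_coords(rank, cfg_size, tp_size)
    return [c * tp_size + tp_rank for c in range(cfg_size)]


if __name__ == "__main__":
    for r in range(8):
        print(r, rank_to_coords(r, 2, 4), tp_group(r, 2, 4), cfg_group(r, 2, 4))
```

Rank 5, for example, sits at (cfg_rank=1, tp_rank=1), shards the uncond branch with ranks 4-7, and exchanges velocities with rank 1, which holds the matching cond shard.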

Run path

1. Prepare weights and inputs

Prepare a Cosmos3-Nano-Policy-DROID checkpoint. If you want rollout video output, the checkpoint also needs vae/.
/path/to/Cosmos3-Nano-Policy-DROID/
  transformer/
  text_tokenizer/
  scheduler/
  vae/             # required when decode_video=True
policy and inverse_dynamics can take observation image or video input. forward_dynamics also needs action JSON.
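Because decode_video=True fails without vae/, it can help to check the checkpoint layout before launching every rank. This is a hypothetical pre-flight helper, not part of PhyAI; it assumes, based on the layout above, that transformer/, text_tokenizer/, and scheduler/ are always required and vae/ only when decoding video.

```python
from pathlib import Path
from typing import List, Union

# Assumed always-required subdirectories, taken from the layout above.
REQUIRED_DIRS = ("transformer", "text_tokenizer", "scheduler")


def missing_checkpoint_dirs(checkpoint_dir: Union[str, Path], decode_video: bool) -> List[str]:
    """Return the names of expected subdirectories that are absent."""
    root = Path(checkpoint_dir)
    needed = list(REQUIRED_DIRS) + (["vae"] if decode_video else [])
    return [name for name in needed if not (root / name).is_dir()]
```

An empty list means the layout matches; `["vae"]` means the run should either add the VAE weights or set decode_video=False.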

2. Construct the multi-GPU policy engine

The plugin name is "cosmos3_policy_wn". torchrun --nproc_per_node must equal cfg_size * tp_size.
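import os

import torch

from phyai.engine import Engine, EngineArgs
from phyai.engine_config import (
    DeviceConfig,
    EngineConfig,
    ParallelConfig,
    RuntimeConfig,
)
from phyai.models.cosmos3.main_cosmos3_policy_wn import Cosmos3PolicyWNArgs

checkpoint_dir = "/path/to/Cosmos3-Nano-Policy-DROID"
# torchrun sets LOCAL_RANK per process; each rank must bind to its own GPU.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
cfg_size = 1  # use 2 together with guidance_scale > 1 for CFG parallel
tp_size = 4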

engine = Engine(
    EngineArgs(
        plugin="cosmos3_policy_wn",
        plugin_args=Cosmos3PolicyWNArgs(
            checkpoint_dir=checkpoint_dir,
            flow_shift=10.0,
            use_karras_sigmas=None,
            decode_video=True,
        ),
        config=EngineConfig(
            device=DeviceConfig(
                target=f"cuda:{local_rank}",
                params_dtype=torch.bfloat16,
            ),
            parallel=ParallelConfig(
                world_size=cfg_size * tp_size,
                cfg_size=cfg_size,
                tp_size=tp_size,
            ),
            runtime=RuntimeConfig(use_cuda_graph=False),
        ),
    )
)
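Since torchrun --nproc_per_node must equal cfg_size * tp_size, a mismatch is easiest to diagnose before the engine builds its mesh. The sketch reads WORLD_SIZE, which torchrun exports to every process; the helper name `check_world_size` is hypothetical and not a PhyAI API.

```python
import os
from typing import Mapping


def check_world_size(cfg_size: int, tp_size: int, env: Mapping[str, str] = os.environ) -> int:
    """Fail fast if the launched process count does not match cfg_size * tp_size."""
    world_size = int(env.get("WORLD_SIZE", "1"))
    expected = cfg_size * tp_size
    if world_size != expected:
        raise RuntimeError(
            f"torchrun launched {world_size} processes, "
            f"but cfg_size * tp_size = {expected}"
        )
    return world_size
```

Calling `check_world_size(cfg_size, tp_size)` just before `Engine(...)` turns a hang or opaque collective error into an explicit message on every rank.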

3. Construct the input processor

Cosmos3PolicyProcessor handles observation resize / padding, prompt tokenization, action padding, domain id, and output postprocessing.
from phyai_utils_tools.models.cosmos3 import Cosmos3PolicyProcessor

processor = Cosmos3PolicyProcessor(
    tokenizer_name_or_path=f"{checkpoint_dir}/text_tokenizer",
    height=480,
    width=832,
    num_frames=17,
    mode="policy",
    domain_name="droid_lerobot",
    action_chunk_size=16,
    fps=24.0,
    image_size=480,
    prompt_format="json",
    view_point="ego_view",
    cond_frame_indexes=(0,),
    device=f"cuda:{local_rank}",
    params_dtype=torch.bfloat16,
)

processed = processor.preprocess(
    {
        "images": "/path/to/observation.png",
        "task": "robot picks up the cup",
    }
)

4. Build the request

Cosmos3ActionRequest does not carry parallel topology. Parallelism comes from the engine config; the request only describes this policy inference.
from phyai.models.cosmos3 import Cosmos3ActionRequest, pixel_to_latent_shape

device = f"cuda:{local_rank}"
dtype = torch.bfloat16

request = Cosmos3ActionRequest(
    text_ids=processed.text_ids.to(device),
    text_mask=processed.text_mask.to(device),
    neg_text_ids=processed.neg_text_ids.to(device),
    neg_text_mask=processed.neg_text_mask.to(device),
    video_shape=pixel_to_latent_shape(*processed.video_shape),
    mode=processed.mode,
    domain_id=processed.domain_id,
    action_chunk=processed.action_chunk,
    raw_action_dim=processed.raw_action_dim,
    cond_video_pixels=processed.pixel_values.to(device=device, dtype=dtype),
    cond_action=(
        processed.cond_action.to(device=device, dtype=dtype)
        if processed.cond_action is not None
        else None
    ),
    cond_frame_indexes=processed.cond_frame_indexes,
    fps=24.0,
    num_inference_steps=30,
    guidance_scale=1.0,
    seed=42,
)
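pixel_to_latent_shape converts the processor's pixel-space video shape into the latent shape the request expects, and this page does not document its compression factors. The sketch below only illustrates the kind of arithmetic involved, assuming a Wan-style causal video VAE with 4x temporal and 8x spatial compression; both strides and the helper name `approx_latent_shape` are assumptions, so real code should keep calling pixel_to_latent_shape.

```python
from typing import Tuple


def approx_latent_shape(
    num_frames: int, height: int, width: int, t_stride: int = 4, s_stride: int = 8
) -> Tuple[int, int, int]:
    """Estimate (t_lat, h_lat, w_lat) under ASSUMED causal-VAE strides."""
    if (num_frames - 1) % t_stride != 0:
        raise ValueError("num_frames should be 1 + a multiple of t_stride")
    if height % s_stride or width % s_stride:
        raise ValueError("height and width should be multiples of s_stride")
    # A causal VAE keeps the first frame on its own, then compresses the rest.
    return (num_frames - 1) // t_stride + 1, height // s_stride, width // s_stride
```

Under these assumed strides, the processor settings used above (17 frames at 480 x 832) would map to 5 latent frames of 60 x 104.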

5. Run all ranks together

Every rank must call engine.step(request). The scheduler triggers collectives on the tp and cfg axes, so rank 0 cannot run alone.
result = engine.step(request)

6. Save results only on rank 0

The example script only lets rank 0 postprocess, write action JSON, and save mp4 output, so multiple processes do not write the same file.
if local_rank == 0:
    output = processor.postprocess(result)
    action = output["action"]
    pixels = output.get("pixels")
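The rank-0 guard can live inside the writer itself, so every rank can call it unconditionally without racing on the same file. This is a minimal sketch: the helper name, the action.json filename, and the `{"action": ...}` layout are hypothetical and not necessarily what the example script writes. It expects the action as a nested Python list (for example via .tolist() if the value is a tensor).

```python
import json
from pathlib import Path
from typing import Optional, Sequence, Union


def save_action_json(
    action: Sequence, out_dir: Union[str, Path], rank: int, filename: str = "action.json"
) -> Optional[Path]:
    """Write the action chunk as JSON on rank 0 only; other ranks return None."""
    if rank != 0:
        return None  # non-zero ranks took part in engine.step but write nothing
    path = Path(out_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({"action": action}))
    return path
```

Ranks other than 0 return immediately, so only one process touches the output directory.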

Run examples

TP-only four-GPU policy inference:
torchrun --nproc_per_node=4 examples/cosmos3/run_cosmos3_policy_wn.py \
    --tp 4 \
    --checkpoint /path/to/Cosmos3-Nano-Policy-DROID \
    --image observation.png \
    --prompt "robot picks up the cup" \
    --domain-name droid_lerobot \
    --out .cache/cosmos3_policy_wn
Eight-GPU policy inference with CFG parallel + TP:
torchrun --nproc_per_node=8 examples/cosmos3/run_cosmos3_policy_wn.py \
    --cfg 2 \
    --tp 4 \
    --guidance-scale 4.0 \
    --checkpoint /path/to/Cosmos3-Nano-Policy-DROID \
    --image observation.png \
    --prompt "robot picks up the cup" \
    --domain-name droid_lerobot \
    --out .cache/cosmos3_policy_wn
Forward dynamics requires an action file:
torchrun --nproc_per_node=4 examples/cosmos3/run_cosmos3_policy_wn.py \
    --tp 4 \
    --checkpoint /path/to/Cosmos3-Nano-Policy-DROID \
    --image observation.png \
    --prompt "robot pushes the object forward" \
    --domain-name droid_lerobot \
    --mode forward_dynamics \
    --action-file action.json \
    --out .cache/cosmos3_forward_wn
Inverse dynamics usually takes an observation video and specifies clean latent frames:
torchrun --nproc_per_node=4 examples/cosmos3/run_cosmos3_policy_wn.py \
    --tp 4 \
    --checkpoint /path/to/Cosmos3-Nano-Policy-DROID \
    --video obs.mp4 \
    --prompt "robot moves the cup to the right" \
    --domain-name droid_lerobot \
    --mode inverse_dynamics \
    --condition-frames 0,1 \
    --out .cache/cosmos3_inverse_wn
--nproc_per_node must equal --cfg * --tp. The policy example defaults to guidance_scale=1.0, where cfg=2 has no benefit; CFG parallel only matters once --guidance-scale is greater than 1.
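Why guidance_scale=1.0 makes the second branch pointless follows from the usual classifier-free guidance combine, v = v_uncond + s * (v_cond - v_uncond): at s = 1 the uncond terms cancel and the result is just the cond velocity. The sketch uses this common formulation on plain lists; whether the plugin uses exactly this form is an assumption.

```python
from typing import List, Sequence


def cfg_combine(v_cond: Sequence[float], v_uncond: Sequence[float], guidance_scale: float) -> List[float]:
    """Standard CFG combine: v_uncond + s * (v_cond - v_uncond), elementwise."""
    if len(v_cond) != len(v_uncond):
        raise ValueError("cond and uncond velocities must have the same length")
    return [u + guidance_scale * (c - u) for c, u in zip(v_cond, v_uncond)]
```

With guidance_scale=1.0, `cfg_combine(v_cond, v_uncond, 1.0)` equals `v_cond`, so the uncond TP group only adds work; with 4.0 the output moves away from the uncond prediction, which is the case the CFG-parallel launch is for.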

Implementation notes

  • decode_video=True requires vae/ in the checkpoint. Without it, the path can only return action and video latent.
  • forward_dynamics must provide cond_action; the processor pads the raw action to action_dim.
  • This path is still a single-request example / baseline path. It is not a continuous batching scheduler.