Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/open-set-segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ The latest generation of Segment Anything Model (SAM3) optimized for image-based

**Note**: For SAM3 with box/point prompts, see the [sam3-tracker example](../image-segmentation/README.md#sam3-tracker).

### SAM3-LiteText
SAM3-Image with the heavy text encoder replaced by a distilled MobileCLIP text encoder ([arXiv:2602.12173](https://arxiv.org/abs/2602.12173)). The ViT-H vision encoder, geometry encoder and mask decoder are kept intact, so it reuses the SAM3-Image vision/decoder ONNX and only swaps the text encoder.
- Same prompts and interface as SAM3-Image (`Sam3Image` model).
- Variants (`--variant`): `s0` (MobileCLIP-S0), `s1` (MobileCLIP-S1), `l` (MobileCLIP2-L).
- Text-encoder ONNX: [wep21/assets `sam3-litetext`](https://github.com/wep21/assets/releases/tag/sam3-litetext); vision/decoder ONNX: [jamjamjon/assets `sam3`](https://github.com/jamjamjon/assets/releases/tag/sam3).

### YOLOEPromptBased
YOLOE with prompt support for flexible object detection and segmentation.
- `Visual`: Uses a visual prompt (image + bounding box) to find similar objects.
Expand Down Expand Up @@ -129,6 +135,27 @@ cargo run -F cuda-full -F vlm --example open-set-segmentation -- sam3-image \
-p "handle;neg:40,183,278,21"
```

### SAM3-LiteText

Same prompts/format as SAM3-Image (`Sam3Image` model), with a lightweight MobileCLIP text encoder. Select the variant with `--variant {s0,s1,l}`.

```bash
cargo run -F cuda-full -F vlm --example open-set-segmentation -- sam3-litetext \
--variant s0 \
--visual-encoder-dtype f16 --visual-encoder-device cuda:0 \
--textual-encoder-dtype fp16 --textual-encoder-device cuda:0 \
--decoder-dtype f16 --decoder-device cuda:0 \
--processor-device cuda:0 \
--source ./assets/dog.jpg \
-p dog
```

```bash
# CPU
cargo run -F vlm --example open-set-segmentation -- sam3-litetext \
--variant s0 --source ./assets/dog.jpg -p dog
```

#### Prompt Format

| Format | Description | Example |
Expand Down
49 changes: 47 additions & 2 deletions examples/open-set-segmentation/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use usls::{
};

mod sam3_image;
mod sam3_litetext;
#[path = "../utils/mod.rs"]
mod utils;
mod yoloe_prompt_based;
Expand Down Expand Up @@ -38,6 +39,7 @@ struct Cli {
enum Commands {
YOLOEPromptBased(yoloe_prompt_based::YoloePromptArgs),
Sam3Image(sam3_image::Sam3ImageArgs),
Sam3Litetext(sam3_litetext::Sam3LitetextArgs),
}

fn main() -> Result<()> {
Expand All @@ -59,6 +61,12 @@ fn main() -> Result<()> {
.commit()?;
run_sam3_image(config, cli.source, &annotator, args, &cli.prompts)?
}
Commands::Sam3Litetext(args) => {
let config = sam3_litetext::config(args)?
.with_class_confs(&cli.confs)
.commit()?;
run_sam3_litetext(config, cli.source, &annotator, args, &cli.prompts)?
}
Commands::YOLOEPromptBased(args) => {
let config = yoloe_prompt_based::config(args)?
.with_class_confs(&cli.confs)
Expand Down Expand Up @@ -135,6 +143,43 @@ fn run_sam3_image(
annotator: &Annotator,
args: &sam3_image::Sam3ImageArgs,
prompts: &[String],
) -> Result<()> {
run_sam3_image_with_batch(
config,
source,
annotator,
args.visual_encoder_batch,
prompts,
"sam3-image",
)
}

// SAM3-LiteText reuses the SAM3 image model (same vision/geometry/decoder), so it
// shares the Sam3Image inference path and only differs in the config preset.
fn run_sam3_litetext(
config: Config,
source: Source,
annotator: &Annotator,
args: &sam3_litetext::Sam3LitetextArgs,
prompts: &[String],
) -> Result<()> {
run_sam3_image_with_batch(
config,
source,
annotator,
args.visual_encoder_batch,
prompts,
"sam3-litetext",
)
}

fn run_sam3_image_with_batch(
config: Config,
source: Source,
annotator: &Annotator,
visual_encoder_batch: usize,
prompts: &[String],
output_dir: &str,
) -> Result<()> {
if prompts.is_empty() {
anyhow::bail!("No prompt. Use -p \"text\" or -p \"text;pos:x,y,w,h\"");
Expand All @@ -147,7 +192,7 @@ fn run_sam3_image(

let mut model = Sam3Image::new(config)?;
let dl = DataLoader::new(source)?
.with_batch(args.visual_encoder_batch)
.with_batch(visual_encoder_batch)
.with_progress_bar(true)
.stream()?;

Expand All @@ -161,7 +206,7 @@ fn run_sam3_image(
}
annotated.save(
usls::Dir::Current
.base_dir_with_subs(&["runs/open-set-segmentation", "sam3-image"])?
.base_dir_with_subs(&["runs/open-set-segmentation", output_dir])?
.join(format!("{}.jpg", usls::timestamp(None))),
)?;
}
Expand Down
92 changes: 92 additions & 0 deletions examples/open-set-segmentation/sam3_litetext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use anyhow::Result;
use clap::{Args, ValueEnum};
use usls::{Config, DType, Device};

#[derive(Debug, Clone, Copy, ValueEnum)]
pub enum Sam3LitetextVariant {
/// S0: MobileCLIP-S0 text encoder.
S0,
/// S1: MobileCLIP-S1 text encoder.
S1,
/// L: MobileCLIP2-L text encoder.
L,
}

#[derive(Args, Debug)]
pub struct Sam3LitetextArgs {
/// SAM3-LiteText text-encoder variant.
#[arg(long, value_enum, default_value = "s0")]
pub variant: Sam3LitetextVariant,

/// Visual Encoder Dtype: fp32, fp16, q4f16, etc.
#[arg(long, default_value = "f16")]
pub visual_encoder_dtype: DType,

/// Visual Encoder Device: cpu, cuda:0, mps, coreml, openvino:CPU, etc.
#[arg(long, global = true, default_value = "cpu")]
pub visual_encoder_device: Device,

/// Visual encoder batch
#[arg(long, default_value_t = 1)]
pub visual_encoder_batch: usize,

/// Textual Encoder Dtype: fp32, fp16, q4f16, etc.
#[arg(long, default_value = "fp16")]
pub textual_encoder_dtype: DType,

/// Textual Encoder Device: cpu, cuda:0, mps, coreml, openvino:CPU, etc.
#[arg(long, global = true, default_value = "cpu")]
pub textual_encoder_device: Device,

/// Textual encoder batch
#[arg(long, default_value_t = 1)]
pub textual_encoder_batch: usize,

/// Decoder Dtype: fp32, fp16, q4f16, etc.
#[arg(long, default_value = "f16")]
pub decoder_dtype: DType,

/// Decoder Device: cpu, cuda:0, mps, coreml, openvino:CPU, etc.
#[arg(long, global = true, default_value = "cpu")]
pub decoder_device: Device,

/// Decoder batch
#[arg(long, default_value_t = 1)]
pub decoder_batch: usize,

/// Processor device (for pre/post processing)
#[arg(long, global = true, default_value = "cpu")]
pub processor_device: Device,

/// num dry run
#[arg(long, global = true, default_value_t = 0)]
pub num_dry_run: usize,

/// trt_max_workspace_size
#[arg(long, global = true, default_value_t = 3221225472)]
pub trt_max_workspace_size: usize,
}

pub fn config(args: &Sam3LitetextArgs) -> Result<Config> {
let config = match args.variant {
Sam3LitetextVariant::S0 => Config::sam3_litetext_s0(),
Sam3LitetextVariant::S1 => Config::sam3_litetext_s1(),
Sam3LitetextVariant::L => Config::sam3_litetext_l(),
};

let config = config
.with_visual_encoder_batch_min_opt_max(1, args.visual_encoder_batch, 2)
.with_textual_encoder_batch_min_opt_max(1, args.textual_encoder_batch, 2)
.with_decoder_batch_min_opt_max(1, args.decoder_batch, 2)
.with_visual_encoder_device(args.visual_encoder_device)
.with_visual_encoder_dtype(args.visual_encoder_dtype)
.with_textual_encoder_device(args.textual_encoder_device)
.with_textual_encoder_dtype(args.textual_encoder_dtype)
.with_decoder_device(args.decoder_device)
.with_decoder_dtype(args.decoder_dtype)
.with_num_dry_run_all(args.num_dry_run)
.with_image_processor_device(args.processor_device)
.with_tensorrt_max_workspace_size_all(args.trt_max_workspace_size);

Ok(config)
}
29 changes: 29 additions & 0 deletions scripts/sam3-litetext/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SAM3-LiteText text-encoder ONNX export

[SAM3-LiteText](https://huggingface.co/docs/transformers/model_doc/sam3_lite_text)
(arXiv:2602.12173) is SAM3-Image with the heavy text encoder replaced by a
distilled MobileCLIP student; the vision encoder, geometry encoder and mask
decoder are kept intact. The Rust presets reuse the SAM3 image vision/decoder
ONNX and only need this lightweight text encoder.

The exported text encoder is a drop-in replacement for the SAM3 text encoder
(inputs `input_ids[B,32]`, `attention_mask[B,32]`; outputs `text_features[B,32,256]`,
`text_mask[B,32]`).

| variant | HF model | text encoder |
|---|---|---|
| `s0` | `vil-uob/sam3-litetext-s0` | MobileCLIP-S0 |
| `s1` | `vil-uob/sam3-litetext-s1` | MobileCLIP-S1 |
| `l` | `vil-uob/sam3-litetext-l` | MobileCLIP2-L |

## Export

```bash
cd scripts/sam3-litetext
uv run export_text_encoder.py --model vil-uob/sam3-litetext-s0 --prefix sam3-litetext-s0 --precision fp32
uv run export_text_encoder.py --model vil-uob/sam3-litetext-s0 --prefix sam3-litetext-s0 --precision fp16
```

Each run writes to `onnx-sam3-litetext/` (override with `--out-dir`) and verifies
ONNX Runtime matches PyTorch. fp16 uses NVIDIA Model Optimizer AutoCast
(`nvidia-modelopt[onnx]`) for precision-aware conversion.
115 changes: 115 additions & 0 deletions scripts/sam3-litetext/export_text_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Export the SAM3-LiteText MobileCLIP text encoder to ONNX.

SAM3-LiteText (arXiv:2602.12173) keeps the SAM3 ViT-H vision encoder, geometry
encoder and mask decoder intact and only replaces the text encoder with a
distilled MobileCLIP student. usls therefore reuses the existing SAM3 image
vision/decoder ONNX (jamjamjon/assets `sam3` release) and only needs this
lightweight text encoder.

The exported ONNX is a drop-in replacement for the SAM3 text encoder:
inputs : input_ids[B, 32], attention_mask[B, 32]
outputs: text_features[B, 32, 256], text_mask[B, 32] (bool, True = valid)

Variants (HuggingFace `vil-uob/sam3-litetext-{s0,s1,l}`):
s0 -> MobileCLIP-S0, s1 -> MobileCLIP-S1, l -> MobileCLIP2-L

Usage:
uv run export_text_encoder.py --model vil-uob/sam3-litetext-s0 --prefix sam3-litetext-s0 --precision fp32
uv run export_text_encoder.py --model vil-uob/sam3-litetext-s0 --prefix sam3-litetext-s0 --precision fp16

fp16 conversion uses NVIDIA Model Optimizer AutoCast (precision-aware; keeps
numerically unsafe nodes in fp32), which handles MobileCLIP's `.float()`
LayerNorm/Softmax regions that onnxconverter_common cannot.
"""
from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from transformers import AutoModel


class LiteTextTextEncoder(nn.Module):
"""Wraps the HF model's text path into a plain-tensor ONNX interface."""

def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask):
text = self.model.get_text_features(
input_ids=input_ids, attention_mask=attention_mask, return_dict=True
)
# pooler_output is the per-token projected features [B, seq, 256];
# text_mask uses True = valid token (matches the SAM3 decoder).
return text.pooler_output, attention_mask.bool()


def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--model", default="vil-uob/sam3-litetext-s0")
ap.add_argument("--out-dir", default="onnx-sam3-litetext")
ap.add_argument("--prefix", default="sam3-litetext-s0")
ap.add_argument("--precision", choices=["fp32", "fp16"], default="fp32")
ap.add_argument("--seq", type=int, default=32, help="text context length (fixed by the SAM3 decoder)")
args = ap.parse_args()

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
suffix = "-fp16" if args.precision == "fp16" else ""
out = out_dir / f"{args.prefix}-text-encoder{suffix}.onnx"

model = AutoModel.from_pretrained(args.model).eval()
wrapper = LiteTextTextEncoder(model).eval()

# Dummy prompt "dog": BOS, token, EOS, then EOS-padding to `seq`.
ids = torch.full((1, args.seq), 49407, dtype=torch.long)
ids[0, :3] = torch.tensor([49406, 1929, 49407])
attn = torch.ones(1, args.seq, dtype=torch.long)

with torch.no_grad():
ref_tf, ref_tm = wrapper(ids, attn)
print("torch text_features", tuple(ref_tf.shape), "text_mask", tuple(ref_tm.shape))

torch.onnx.export(
wrapper, (ids, attn), str(out),
input_names=["input_ids", "attention_mask"],
output_names=["text_features", "text_mask"],
opset_version=17, do_constant_folding=True, dynamo=False,
dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"},
"text_features": {0: "batch"}, "text_mask": {0: "batch"}},
)
print("exported:", out)

if args.precision == "fp16":
from modelopt.onnx.autocast import convert_to_mixed_precision

model_fp16 = convert_to_mixed_precision(
str(out), low_precision_type="fp16", keep_io_types=True
)
onnx.save(model_fp16, str(out))
print("converted to fp16 (modelopt AutoCast)")

# Verify ONNX Runtime matches PyTorch.
sess = ort.InferenceSession(str(out), providers=["CPUExecutionProvider"])
np_dtype = {"tensor(float)": np.float32, "tensor(float16)": np.float16,
"tensor(int64)": np.int64, "tensor(bool)": np.bool_}
feeds = {"input_ids": ids.numpy(), "attention_mask": attn.numpy()}
cast = {i.name: feeds[i.name].astype(np_dtype[i.type]) for i in sess.get_inputs()}
got = dict(zip([o.name for o in sess.get_outputs()], sess.run(None, cast)))
r = ref_tf.numpy().astype(np.float64)
g = np.asarray(got["text_features"], np.float64)
diff = np.abs(r - g)
cos = float(r.ravel() @ g.ravel() / (np.linalg.norm(r.ravel()) * np.linalg.norm(g.ravel()) + 1e-12))
print(f"verify text_features: max_abs={diff.max():.3e} mean_abs={diff.mean():.3e} cos={cos:.6f}")
mism = int((np.asarray(got["text_mask"]).astype(bool) != ref_tm.numpy()).sum())
print(f"verify text_mask: mismatched={mism}")


if __name__ == "__main__":
main()
13 changes: 13 additions & 0 deletions scripts/sam3-litetext/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[project]
name = "sam3-litetext-tools"
version = "0.1.0"
requires-python = ">=3.10"

dependencies = [
"torch",
"transformers>=5.12",
"onnx",
"onnxruntime",
"nvidia-modelopt[onnx]>=0.44",
"numpy>=2.0",
]
Loading
Loading