
License: mit

What this checkpoint does

The model emits <swi> to enter latent mode and </swi> to exit. Inside the latent block it performs Coconut-style hidden-state recurrence (each step's last-layer hidden state becomes the input embedding of the next <latent> position). Outside the block it decodes ordinary text. The boundary tokens are ordinary discrete vocabulary items, so on-policy GRPO is well-defined at every text position; latent positions contribute no policy-gradient term because no token is sampled there.
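The last point can be made concrete with a minimal sketch, assuming a plain REINFORCE-style objective with one scalar advantage per rollout (full GRPO would add ratio clipping and a KL term). The token ids are placeholders, and `policy_gradient_weights` and `masked_pg_loss` are hypothetical helper names, not the repository's code.

```python
from typing import List, Sequence

# Placeholder ids for illustration only; look up the real ones with
# tokenizer.convert_tokens_to_ids(["<swi>", "</swi>", "<latent>"]).
SWI_ID, END_SWI_ID, LATENT_ID = 1001, 1002, 1003


def policy_gradient_weights(token_ids: Sequence[int], advantage: float) -> List[float]:
    """Weight on log pi(token) at each position of one rollout.

    Sampled positions (ordinary text and the discrete <swi>/</swi> boundary
    tokens) carry the rollout's advantage; <latent> positions were filled by
    hidden-state injection rather than sampling, so they get weight 0.
    """
    return [0.0 if t == LATENT_ID else advantage for t in token_ids]


def masked_pg_loss(logprobs: Sequence[float], token_ids: Sequence[int], advantage: float) -> float:
    """REINFORCE-style loss averaged over sampled positions only.

    Simplification: no importance ratio, clipping, or KL penalty.
    """
    weights = policy_gradient_weights(token_ids, advantage)
    n_sampled = sum(1 for t in token_ids if t != LATENT_ID)
    return -sum(w * lp for w, lp in zip(weights, logprobs)) / max(n_sampled, 1)
```

Whatever log-probability the model assigns at a <latent> position, it is multiplied by zero and never reaches the gradient, while the <swi> and </swi> decisions are trained like any other sampled token.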

This adapter was trained in three phases on Qwen3-8B:

  1. Phase 1 (SFT). Wrap high-entropy CoT spans in <swi>/</swi>.
  2. Phase 2 (Curriculum). Replace text inside <swi> blocks with <latent> placeholders progressively (parallel schedule).
  3. Phase 3 (Switch-GRPO). On-policy RL on the answer reward, with rollouts that perform real hidden-state injection at <latent> positions.
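One possible reading of the Phase 2 transform is sketched below as a hypothetical illustration. It assumes that each <swi> block's inner text is split into newline-separated steps, that "parallel" means every block in a sample advances to the same stage together, and that each replaced step becomes c <latent> tokens, capped by K_max per block and by the per-sample latent cap listed under Training details. `apply_curriculum` is an illustrative name; the paper's exact schedule may differ.

```python
import re

SWI_BLOCK = re.compile(r"<swi>(.*?)</swi>", re.DOTALL)


def apply_curriculum(text: str, stage: int, c: int = 2, k_max: int = 8,
                     sample_cap: int = 48) -> str:
    """Hypothetical Phase-2 transform at curriculum `stage` (0 = no replacement).

    All <swi> blocks are updated in the same pass ("parallel" schedule, assumed).
    In each block the first `stage` newline-separated steps are replaced by
    c <latent> tokens per step, capped at k_max per block and sample_cap per sample.
    """
    budget = sample_cap

    def replace(match: re.Match) -> str:
        nonlocal budget
        if budget <= 0:
            return match.group(0)  # sample-level cap reached: leave block as text
        steps = match.group(1).split("\n")
        n_replace = min(stage, len(steps))
        n_latent = min(c * n_replace, k_max, budget)
        budget -= n_latent
        kept = "\n".join(steps[n_replace:])
        return "<swi>" + "<latent>" * n_latent + kept + "</swi>"

    return SWI_BLOCK.sub(replace, text)
```

At stage 1, "<swi>a\nb\nc</swi>" becomes "<swi><latent><latent>b\nc</swi>"; text outside the blocks is never touched.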

This release is the Phase 3 checkpoint, i.e. the version reported in the paper.

Quick start

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE = "Qwen/Qwen3-8B"
ADAPTER = "LARK-Lab/SWITCH-Phase3-GRPO-LoRA-Qwen3-8B"

# The adapter repo ships the extended tokenizer (contains <swi>, </swi>, <latent>).
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)

model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto"
)
# Grow the embedding matrix and LM head to the extended vocabulary
# before loading the adapter, so the saved weights match in shape.
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, ADAPTER)
model.eval()
```

⚠️ Important: A naïve model.generate(...) call treats <latent> as an ordinary token and does not perform the hidden-state recurrence inside <swi>...</swi> blocks. To run inference exactly as in the paper, use the SWITCH inference loop in src/model/coconut_swi_model.py, which feeds the previous latent step's last-layer hidden state back as the next input embedding and enforces the K_min minimum-dwell constraint (at least K_min latent steps before the block can close) inside the latent block.
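The control flow of such a loop can be sketched framework-agnostically. This is a hypothetical illustration, not the code in src/model/coconut_swi_model.py: `step` stands for one forward pass at a <latent> position, the exit rule (accept </swi> once it is the argmax after at least K_min steps) is an assumption, and `max_steps` is an assumed safety cap. With transformers, `step` could wrap a call such as `model(inputs_embeds=..., past_key_values=..., use_cache=True, output_hidden_states=True)` and return `out.hidden_states[-1][:, -1]` and `out.logits[:, -1]`.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Hypothetical interface: step(input_embedding) -> (last-layer hidden state,
# next-token logits) for one <latent> position.
StepFn = Callable[[Vector], Tuple[Vector, Sequence[float]]]


def run_latent_block(step: StepFn, entry_hidden: Vector, end_swi_id: int,
                     k_min: int = 4, max_steps: int = 8) -> List[Vector]:
    """Run one <swi> ... </swi> block and return the latent hidden states.

    Each step feeds the previous last-layer hidden state back in as the next
    input embedding; no token is sampled. </swi> is only accepted after at
    least k_min latent steps (minimum dwell). max_steps is an assumed cap.
    """
    hidden = entry_hidden  # hidden state produced at the <swi> position
    trajectory: List[Vector] = []
    for k in range(1, max_steps + 1):
        hidden, logits = step(hidden)  # hidden-state injection at this <latent>
        trajectory.append(hidden)
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        if k >= k_min and predicted == end_swi_id:
            break  # caller appends </swi> and resumes ordinary text decoding
    return trajectory
```

With a toy `step` whose logits always favour </swi>, the block still runs exactly K_min steps, which is the behaviour the minimum-dwell constraint is meant to guarantee.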

Headline results

| Benchmark | SWITCH (this checkpoint) | Strongest Coconut-style baseline | Gap (pp) |
|---|---|---|---|
| MATH-500 | 79.3 % | 53.6 % | +25.7 |
| GSM8K | 89.2 % | 78.5 % | +10.7 |
All numbers are reported under matched data, decoding, and Qwen3-8B base-model settings.

Training details

| Setting | Value |
|---|---|
| Base model | Qwen/Qwen3-8B |
| Phase 1 | LoRA (r=32, α=64) on {q,k,v,o,gate,up,down}_proj + resized embeddings + LM head, bf16 |
| Phase 2 | LoRA continued from Phase 1; parallel curriculum schedule, c=2, K_max=8, per-sample latent cap 48 |
| Phase 3 | Switch-GRPO; group size G=5, clip ε=0.2, KL β=1e-3, lr=1e-6; reward = correctness + format + latent-usage |
| Training data | LARK-Lab/SWITCH-Math-Train |
| Hardware | 8 × NVIDIA H20 (95 GB) |
| K_min (inference) | 4 |
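For reference, the training-details values can be collected as Python settings. Mapping "resized embeddings + LM head" to peft's `modules_to_save` is an assumption, as are the dictionary and key names for Phases 2 and 3; only `LoraConfig` and its `r`, `lora_alpha`, `target_modules`, `modules_to_save`, and `task_type` arguments come from the peft library.

```python
# Phase 1 LoRA settings from the Training details table, as LoraConfig kwargs.
# modules_to_save for the resized embeddings + LM head is an assumption.
PHASE1_LORA = dict(
    r=32,
    lora_alpha=64,  # effective LoRA scaling alpha / r = 2
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)

# Phase 2 / Phase 3 / inference values (dictionary key names are hypothetical).
PHASE2_CURRICULUM = dict(c=2, k_max=8, per_sample_latent_cap=48)
PHASE3_GRPO = dict(group_size=5, clip_eps=0.2, kl_beta=1e-3, learning_rate=1e-6)
K_MIN_INFERENCE = 4


def build_lora_config():
    # Imported lazily so the settings above can be inspected without peft installed.
    from peft import LoraConfig
    return LoraConfig(**PHASE1_LORA)
```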

See the paper §3 for the full method and §4–§5 for setup / results.
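The Phase 3 reward and GRPO's group-relative advantage can be sketched as follows. The reward decomposition (each term 0 or 1 with equal weights, the last \boxed{...} compared to the gold answer by string match, balanced <swi>/</swi> as the format check, any <latent> as latent usage) is a hypothetical stand-in for the paper's definition. The advantage normalisation uses the population standard deviation, which is one common choice.

```python
import re
import statistics
from typing import List

BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def switch_reward(completion: str, gold: str) -> float:
    """Hypothetical reward = correctness + format + latent-usage, each 0 or 1."""
    answers = BOXED.findall(completion)
    correct = float(bool(answers) and answers[-1].strip() == gold.strip())
    well_formed = float(completion.count("<swi>") == completion.count("</swi>"))
    used_latent = float("<latent>" in completion)
    return correct + well_formed + used_latent


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: normalise rewards within one group of G rollouts."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]
```

With G=5, a group whose rewards are [3, 1, 1, 1, 1] gives the one fully successful rollout an advantage of about +2 and the other four about -0.5 each; a group with identical rewards produces zero advantage and therefore no update.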

Special tokens

| Token | Purpose |
|---|---|
| `<swi>` | Enter latent reasoning |
| `</swi>` | Exit latent reasoning |
| `<latent>` | Latent placeholder; no token sampled, hidden-state injection happens here |
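To see how these tokens partition a generation, a decoded output can be split into text and latent segments. This sketch assumes the output was decoded with `tokenizer.decode(ids, skip_special_tokens=False)` and that the inference loop writes one <latent> token per latent step; both are assumptions, and the helper names are hypothetical. The real ids can be checked with `tokenizer.convert_tokens_to_ids(["<swi>", "</swi>", "<latent>"])`.

```python
import re
from typing import List, Tuple

SEGMENT = re.compile(r"<swi>(.*?)</swi>", re.DOTALL)


def split_segments(decoded: str) -> List[Tuple[str, str]]:
    """Split decoded output into ("text", ...) and ("latent", ...) segments."""
    segments: List[Tuple[str, str]] = []
    pos = 0
    for m in SEGMENT.finditer(decoded):
        if m.start() > pos:
            segments.append(("text", decoded[pos:m.start()]))
        segments.append(("latent", m.group(1)))  # body between <swi> and </swi>
        pos = m.end()
    if pos < len(decoded):
        segments.append(("text", decoded[pos:]))
    return segments


def latent_steps_per_block(decoded: str) -> List[int]:
    """Number of <latent> steps inside each <swi> ... </swi> block."""
    return [body.count("<latent>") for kind, body in split_segments(decoded) if kind == "latent"]
```

Counting latent steps per block is a quick way to confirm that each block respects the K_min minimum dwell.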

Mechanistic findings (verifiable on this checkpoint)

The boundary tokens make latent computation directly inspectable. Three takeaways from the paper, all reproducible with scripts/interpret_swi.py:

  1. <swi> is a learned switching policy, not a stylistic tag. It is sharply localised (rank ≤ 2 at boundaries vs ~10³ at random positions), forms a clean one-token spike, and is linearly decodable from late hidden states (~91.9 %).
  2. The latent step performs causally important computation. Zeroing the injected hidden states reduces accuracy by roughly two-thirds on the diagnostic subset; same-norm random replacements cost only a few points.
  3. The work is concentrated at a single hidden-state transition on entry; subsequent steps are near-deterministic exits with p(</swi>) ≈ 1. The K_min constraint is what protects this single computational step.
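These diagnostics rely on a few simple primitives, sketched here under stated assumptions and not taken from scripts/interpret_swi.py: the rank of a token in a next-token logit vector (for the localisation finding), zero ablation of an injected hidden state, and a same-norm random replacement that keeps the L2 norm but discards the direction (for the causal-importance finding). In a real experiment these would be applied to the hidden state just before it is fed back as the next input embedding; the helper names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def token_rank(logits: Sequence[float], token_id: int) -> int:
    """1-based rank of token_id in a next-token logit vector (1 = argmax; ties rank best)."""
    target = logits[token_id]
    return 1 + sum(1 for x in logits if x > target)


def zero_ablation(hidden: List[float]) -> List[float]:
    """Replace the injected hidden state with zeros."""
    return [0.0] * len(hidden)


def same_norm_random(hidden: List[float], rng: random.Random) -> List[float]:
    """Random Gaussian direction rescaled to the L2 norm of `hidden`."""
    norm = math.sqrt(sum(x * x for x in hidden))
    noise = [rng.gauss(0.0, 1.0) for _ in hidden]
    scale = norm / math.sqrt(sum(x * x for x in noise))
    return [x * scale for x in noise]
```

Comparing the two ablations separates "the magnitude of the injection matters" from "the specific content of the injection matters": if only zeroing hurts, scale is doing the work; if the same-norm replacement also hurts, the direction carries information.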

Intended use

  • Math-reasoning research on hidden-state-recurrence latent CoT.
  • Reproducing the SWITCH paper's main table and mechanistic analysis.
  • A starting point for further on-policy RL on latent-recurrent models.

Limitations

  • Only English mathematical reasoning is in the training distribution.
  • Visible token count (~1,700 per problem on MATH-500) is much higher than for the pure-Coconut baselines (~10 visible tokens per problem) because we keep visible CoT outside the <swi> block; this is a deliberate design choice for verifiability, not a token-level efficiency claim.
  • Naïve model.generate(...) does not activate the hidden-state recurrence; you must use the SWITCH inference loop to reproduce the paper numbers.

License

MIT.

Citation

```bibtex
@misc{yang2026demystifyinghiddenstaterecurrenceswitchable,
  title         = {Demystifying Hidden-State Recurrence: Switchable Latent Reasoning with On-Policy Reinforcement Learning},
  author        = {Jiayu Yang and Chao Chen and Shengen Wu and Yinhong Liu and Yuxuan Fan and Lujundong Li and Songning Lai and Chengwei Qin and Zhijiang Guo},
  year          = {2026},
  eprint        = {2606.13106},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2606.13106}
}
```

Model provider

LARK-Lab

Model tree

Base: Qwen/Qwen3-8B
Adapter: this model

Modalities

Input: Text
Output: Text
