NeuroDiscoveryAI
OpenMed-qwen3-1.7b-RIM
License: apache-2.0

How it works
A memory block is the fixed token sequence [<rim_b> <rim_m> <rim_m> <rim_eb>].
We append K blocks after the question; their contextual representations form a
latent workspace. A two-stage curriculum (Stage 1 grounds the blocks against
reasoning steps; Stage 2 refines the final answer across the K blocks) teaches the
model to compute through the blocks. At inference the answer is read out after the
blocks in a single forward pass — no reasoning tokens are generated.
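The single-pass readout can be pictured as one forward pass over question, memory blocks and answer prefix, followed by an argmax over the option letters at the final position. The plain-Python sketch below shows only that last step. Restricting the choice to option-letter tokens is an assumption for illustration, the logit values are hypothetical, and the released usage code below instead decodes a few tokens with `generate`.

```python
import math
from typing import Dict, Tuple


def read_out_letter(letter_logits: Dict[str, float]) -> Tuple[str, Dict[str, float]]:
    """Pick the option letter with the highest next-token logit.

    `letter_logits` maps each option letter (e.g. "A".."D") to the logit the model
    assigns to that letter's token at the last input position. Returns the chosen
    letter and a softmax computed over the option letters only.
    """
    best = max(letter_logits, key=letter_logits.get)
    # Numerically stable softmax restricted to the option letters.
    top = letter_logits[best]
    exps = {k: math.exp(v - top) for k, v in letter_logits.items()}
    total = sum(exps.values())
    probs = {k: e / total for k, e in exps.items()}
    return best, probs


# Toy logits standing in for one forward pass (hypothetical values).
letter, probs = read_out_letter({"A": 0.5, "B": -1.0, "C": 3.2, "D": 0.1})
print(letter)  # C
```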
Only the 3 new special-token embeddings are learned from scratch; the rest of the transformer is fine-tuned and the pretrained vocabulary embeddings are frozen.
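Learning only the new rows of the embedding matrix while keeping the pretrained rows frozen amounts to masking the embedding gradient row by row. In PyTorch this is commonly done with `Tensor.register_hook` on the input-embedding weight; the plain-Python sketch below shows only the masking rule, and the token ids are hypothetical placeholders, not the real ids of `<rim_b>`, `<rim_m>` and `<rim_eb>`.

```python
from typing import List, Set

Matrix = List[List[float]]


def mask_embedding_grad(grad: Matrix, trainable_ids: Set[int]) -> Matrix:
    """Keep gradient rows of the new special tokens; zero every other row.

    `grad` has one row per vocabulary id (vocab_size x hidden_size), like the
    gradient of an input-embedding matrix.
    """
    return [
        list(row) if token_id in trainable_ids else [0.0] * len(row)
        for token_id, row in enumerate(grad)
    ]


# Toy 5-token vocabulary; ids 3 and 4 stand in for new special tokens (hypothetical).
toy_grad = [[0.1 * (i + 1), -0.2 * (i + 1)] for i in range(5)]
masked = mask_embedding_grad(toy_grad, trainable_ids={3, 4})
```

A zeroed gradient does not fully freeze a row if the optimizer applies decoupled weight decay (as AdamW does), because the decay shrinks weights independently of the gradient; excluding the embedding matrix from weight decay, or restoring the frozen rows after each step, avoids that drift.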
Results
Greedy-decoding accuracy (N = 1000 questions per cell; chance is 25% on the 4-option OOD sets).
| model | In-dist (held-out) | MedQA (OOD) | MedMCQA (OOD) | latency/query† |
|---|---|---|---|---|
| Base Qwen3-1.7B (zero-shot) | 50.9% | 45.7% | 42.8% | ~7.8 s |
| CoT (explicit SFT) | 47.3% | 42.3% | 42.4% | ~22 s |
| RiM v1 (this model) | 53.6% | 45.1% | 47.2% | 35 ms |
| RiM v2 (MCQ-weighted Stage 2) | 53.2% | 46.9% | 47.2% | 35 ms |
- RiM is best or within noise of the best on all three benchmarks, while answering ~220× faster than the base and ~630× faster than CoT per query, because it reads the answer out of the memory blocks instead of autoregressively generating a reasoning trace.
- In-distribution pass@8 ≈ 85% (vs ~54% greedy), and accuracy is stable across memory budgets K∈{1,2,4,8}.
- Honest notes: differences on MedQA are within noise (~±1.5 percentage points, about one binomial standard error at N = 1000). The explicit-CoT SFT baseline underperforms the zero-shot base here by 0.4 to 3.6 points (fine-tuning on mixed-quality traces, 91% of which are open-ended, modestly hurt the strong base instruct model).
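Two quantities behind the bullets above can be checked with a few lines of standard-library Python: the binomial standard error of an accuracy measured on N = 1000 questions, and a pass@k estimate. The card does not say how pass@8 was computed; the estimator below is the common unbiased one from Chen et al. (2021), used here as an assumption.

```python
import math


def accuracy_std_error(p: float, n: int) -> float:
    """Binomial standard error of an accuracy p measured on n independent questions."""
    return math.sqrt(p * (1.0 - p) / n)


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k samples, drawn without
    replacement from n generations of which c are correct, is correct."""
    if n - c < k:
        return 1.0  # every size-k subset contains a correct sample
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


# MedQA base accuracy from the table: 45.7% on 1000 questions.
se = accuracy_std_error(0.457, 1000)
print(f"{se:.4f}")  # 0.0158
```

One standard error is about 1.6 percentage points here, in line with the ~±1.5% noise band quoted for MedQA; a 95% interval is roughly twice as wide. With n = k = 8 samples per question, pass@8 reduces to the fraction of questions with at least one correct sample.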
†Latency methodology: single-request (batch = 1) answer generation on one RTX PRO 6000 in bf16, after warm-up, averaged over 32 samples. RiM takes 35 ms to generate the answer (the pure forward-pass readout alone takes 12 ms); the base and CoT models must generate ~520 and ~1460 tokens (~7.8 s and ~22 s). Under large-batch serving the per-sample throughput gap is smaller (≈8 ms vs ≈1 s per sample), but the single-query latency above is what a user waits for a single answer.
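The headline speedups follow directly from the footnote's timings. The snippet below redoes that arithmetic and also derives the implied decode rate of the two baselines; it uses only the numbers quoted above, in seconds.

```python
# Mean single-request latencies from the footnote, in seconds.
RIM_S = 0.035
BASE_S, BASE_TOKENS = 7.8, 520
COT_S, COT_TOKENS = 22.0, 1460


def speedup(baseline_s: float, rim_s: float = RIM_S) -> float:
    """How many times faster RiM answers than a baseline."""
    return baseline_s / rim_s


print(round(speedup(BASE_S)), round(speedup(COT_S)))  # 223 629
# Implied decode throughput of the baselines (tokens per second).
print(round(BASE_TOKENS / BASE_S), round(COT_TOKENS / COT_S))  # 67 66
```

Both baselines decode at roughly 66 tokens/s, so their latency is essentially proportional to the number of generated tokens; RiM avoids that cost by not generating a reasoning trace.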
Usage (no generated reasoning: only the short answer is decoded after the memory blocks)
```python
import re
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

REPO = "NDIJayant/OpenMed-qwen3-1.7b-RIM"
K, M = 8, 2  # K memory blocks; M <rim_m> tokens per block

tok = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForCausalLM.from_pretrained(
    REPO, dtype=torch.bfloat16, attn_implementation="sdpa"
).cuda().eval()

# One memory block: [<rim_b>, <rim_m> x M, <rim_eb>]
b, m, eb = (tok.convert_tokens_to_ids(t) for t in ("<rim_b>", "<rim_m>", "<rim_eb>"))
block = [b] + [m] * M + [eb]

# Answer prefix appended after the blocks; the model completes the letter inside \boxed{...}
PREFIX = tok.encode("The final answer is \\boxed{", add_special_tokens=False)


@torch.no_grad()
def answer(question: str) -> Optional[str]:
    q = tok.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=False,  # disable Qwen3's <think> reasoning mode
        return_dict=False,      # return a plain list of token ids
    )
    ids = q + block * K + PREFIX  # question, K memory blocks, answer prefix
    input_ids = torch.tensor([ids]).cuda()
    out = model.generate(
        input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=8,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    gen = tok.decode(out[0, len(ids):], skip_special_tokens=True)
    mtch = re.search(r"([A-J])", gen)  # first option letter (up to 10 options)
    return mtch.group(1) if mtch else None


q = (
    "Which vitamin deficiency causes scurvy?\n"
    "A: Vitamin A\nB: Vitamin B12\nC: Vitamin C\nD: Vitamin D"
)
print(answer(q))  # -> "C"
```
Use attn_implementation="sdpa" (not flash-attention) if you need the masked training path, which relies on a custom 4D attention mask; for single-pass inference like the example above, plain causal attention is fine.
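To reproduce a greedy-accuracy cell, the `answer()` helper from the usage example can be run over labeled questions. The scorer below is a hypothetical harness (the `score` function and the `(question, gold_letter)` pair format are assumptions, not the released eval scripts); an unparseable output (`None`) counts as wrong.

```python
from typing import Callable, Iterable, Optional, Tuple


def score(
    predict: Callable[[str], Optional[str]],
    examples: Iterable[Tuple[str, str]],
) -> float:
    """Greedy accuracy: fraction of (question, gold_letter) pairs answered correctly.

    A prediction of None (no option letter found) is counted as incorrect.
    """
    total = correct = 0
    for question, gold in examples:
        total += 1
        correct += int(predict(question) == gold)
    return correct / total if total else 0.0


# Stand-in predictor for illustration; in practice pass the model-backed answer().
def always_c(question: str) -> Optional[str]:
    return "C"


print(score(always_c, [("q1", "C"), ("q2", "A"), ("q3", "C"), ("q4", "B")]))  # 0.5
```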
Training
- Base: Qwen/Qwen3-1.7B (dense, full attention).
- Data: OpenMed/Medical-Reasoning-SFT-Mega (a mixture of multiple-choice and open-ended questions; trained on the full mixture, evaluated on the MCQ subset).
- Stage 1: 6 epochs, one memory block per reasoning step, linear-relative supervision anneal.
- Stage 2: 2 epochs, K=8 blocks, anytime-answer objective, lower LR and higher dropout.
- Setup: bf16, 8× GPU, custom 4D attention mask (SDPA).
- Code: training/eval/benchmark scripts are released alongside this model.
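The two-stage recipe can be written down as a small configuration object for anyone re-running it. Every value below is taken from the list above; fields the card leaves unspecified (learning rates, dropout rates, the exact anneal schedule) are deliberately omitted, and the `StageConfig` type itself is a hypothetical layout, not the released scripts' format.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StageConfig:
    name: str
    epochs: int
    # Stage 1 sizes blocks to the data (one per reasoning step); Stage 2 uses a fixed K.
    blocks_per_reasoning_step: Optional[int]
    fixed_num_blocks: Optional[int]
    objective: str


STAGES = (
    StageConfig("stage1", epochs=6, blocks_per_reasoning_step=1, fixed_num_blocks=None,
                objective="ground blocks on reasoning steps, linear-relative supervision anneal"),
    StageConfig("stage2", epochs=2, blocks_per_reasoning_step=None, fixed_num_blocks=8,
                objective="anytime answer (lower LR, higher dropout than stage 1)"),
)

# Settings shared by both stages, as stated in the card.
COMMON = {"base_model": "Qwen/Qwen3-1.7B", "precision": "bf16", "num_gpus": 8,
          "attn_implementation": "sdpa", "dataset": "OpenMed/Medical-Reasoning-SFT-Mega"}

print(sum(s.epochs for s in STAGES))  # 8
```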
Limitations
In-distribution evaluation uses auto-extracted answer letters on a held-out slice of the training dataset. Results come from a single model size (1.7B) and a single seed. English only. The OOD sets (MedQA, MedMCQA) are 4-option; in-distribution questions have up to 10 options. Not safe for any real-world medical decision-making.
Citation
```bibtex
@article{aichberger2026rim,
  title  = {Unlocking the Working Memory of Large Language Models for Latent Reasoning},
  author = {Aichberger, Lukas and Hochreiter, Sepp},
  year   = {2026}
}
```
Also cite Qwen/Qwen3-1.7B and OpenMed/Medical-Reasoning-SFT-Mega (both Apache-2.0).
Model provider: NeuroDiscoveryAI
Model tree: Qwen/Qwen3-1.7B (base) → this model (fine-tuned)
Modalities: text input, text output