License: apache-2.0

Training pipeline
Stage 1: Supervised fine-tuning (SFT)
LoRA fine-tuning on ~43k neuroscience multiple-choice questions (MCQs) generated from a knowledge graph: 12,500 single-hop and 30,750 two-hop questions. Each question has four options, and answering it requires chain-of-thought reasoning. The LoRA adapters were merged into the base weights before Stage 2.
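Merging a LoRA adapter folds its low-rank update into the frozen weight, W' = W + (alpha / r) · B · A. The merged checkpoint then runs as a plain dense model, with no adapter at inference time. The sketch below shows that arithmetic in pure Python on toy matrices. The rank, alpha, and matrix values are hypothetical, because the card does not report the SFT LoRA settings. In PEFT, `merge_and_unload()` performs this fold on a real model.

```python
# Toy illustration of merging a LoRA update into a base weight matrix.
# Shapes follow the LoRA convention: W is (d_out, d_in), B is (d_out, r), A is (r, d_in).
# All values, the rank r, and alpha are hypothetical.

def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(w, b, a, alpha, r):
    """Return W + (alpha / r) * (B @ A), the dense weight after merging."""
    scale = alpha / r
    delta = matmul(b, a)
    return [[wij + scale * dij for wij, dij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]        # base weight: d_out=2, d_in=3
B = [[1.0], [2.0]]            # (2, r=1)
A = [[0.5, 0.0, -0.5]]        # (r=1, 3)

merged = merge_lora(W, B, A, alpha=2.0, r=1)
print(merged)
```

After merging, the forward pass needs only the single matrix `merged`, which is why Stage 2 could run full fine-tuning on ordinary dense weights.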
Stage 2: Reinforcement learning with GRPO
Full fine-tuning (no LoRA), starting from the merged SFT checkpoint, on 5,000 validated two-hop neuroscience questions using Group Relative Policy Optimization (GRPO) via TRL. Each prompt receives 4 sampled completions. Training runs on 4 GPUs with 16 gradient-accumulation steps, for an effective batch of 64 completions (16 prompts × 4 generations) per update.
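GRPO does not train a value function. The completions for one prompt form a group, and each completion's advantage is its reward minus the group mean, divided by the group standard deviation. The sketch below computes this for one group of 4 completions. The reward values are made up. The population standard deviation and the small epsilon are illustrative choices, since implementations differ on these details.

```python
import statistics


def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages: (r_i - mean(r)) / (std(r) + eps) within one prompt's group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical combined rewards for the 4 completions of a single prompt.
rewards = [2.0, -1.0, 1.2, -1.0]
adv = group_advantages(rewards)
print([round(a, 3) for a in adv])
```

Completions that beat their siblings get positive advantages and are reinforced. If all four completions score the same, every advantage is zero, so that prompt contributes no learning signal.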
Three reward functions were combined:
- Correctness: +1 for a correct answer, -1 otherwise, minus a smooth length penalty that starts above 550 tokens and reaches its maximum at 1500 tokens
- Format: up to +0.2 for producing all required XML tags (think, explanation, answer) in the correct order
- Path alignment: up to +0.8, based on the F1 score between the knowledge-graph path tokens and the first 550 tokens of the model's thinking; awarded only when the answer is correct and the format is valid
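The three rewards can be sketched as follows. The card specifies the thresholds (550 and 1500 tokens) and the ceilings (+1/-1, +0.2, +0.8), but not the rest. The following are hypothetical stand-ins:

- the cosine shape of the penalty and its maximum of 0.5
- the all-or-nothing format check
- set-based F1 over lowercase whitespace words instead of model tokens

```python
import math
import re

# Thresholds and reward ceilings come from the card; MAX_PENALTY and the cosine
# shape of the penalty are hypothetical stand-ins for unpublished details.
PENALTY_START, PENALTY_CAP, MAX_PENALTY = 550, 1500, 0.5
TAGS = ("think", "explanation", "answer")
ANSWER_RE = re.compile(r"<answer>\s*([A-D])\s*</answer>")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def length_penalty(num_tokens):
    """0 up to 550 tokens, cosine ramp to MAX_PENALTY at 1500, flat beyond."""
    if num_tokens <= PENALTY_START:
        return 0.0
    frac = min(1.0, (num_tokens - PENALTY_START) / (PENALTY_CAP - PENALTY_START))
    return MAX_PENALTY * (1 - math.cos(math.pi * frac)) / 2


def is_correct(text, gold):
    m = ANSWER_RE.search(text)
    return m is not None and m.group(1) == gold


def is_valid_format(text):
    """Every tag pair present, closed, and opened in think -> explanation -> answer order."""
    opens = [text.find(f"<{t}>") for t in TAGS]
    if min(opens) < 0 or opens != sorted(opens):
        return False
    return all(text.find(f"</{t}>", p) > p for t, p in zip(TAGS, opens))


def correctness_reward(text, gold, num_tokens):
    return (1.0 if is_correct(text, gold) else -1.0) - length_penalty(num_tokens)


def format_reward(text):
    # All-or-nothing here; the card's "up to +0.2" suggests partial credit may exist.
    return 0.2 if is_valid_format(text) else 0.0


def token_f1(reference, candidate):
    """Set-based token F1 between two token lists."""
    ref, cand = set(reference), set(candidate)
    overlap = len(ref & cand)
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(cand), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def path_reward(text, gold, path_text, max_think_tokens=550):
    """Up to 0.8, gated on a correct answer and a valid format."""
    if not (is_correct(text, gold) and is_valid_format(text)):
        return 0.0
    think = THINK_RE.search(text).group(1)
    # Whitespace words stand in for model tokens in this sketch.
    think_tokens = think.lower().split()[:max_think_tokens]
    return 0.8 * token_f1(path_text.lower().split(), think_tokens)


def total_reward(text, gold, path_text, num_tokens):
    return (correctness_reward(text, gold, num_tokens)
            + format_reward(text)
            + path_reward(text, gold, path_text))


sample = ("<think>acetylcholine is released by motor neurons at the neuromuscular junction</think>"
          "<explanation>Motor neurons release acetylcholine.</explanation><answer>C</answer>")
path = "motor neuron releases acetylcholine neuromuscular junction"
print(total_reward(sample, "C", path, num_tokens=120))  # about 1.6: 1.0 + 0.2 + 0.8 * 0.5
```

The path term is gated on correctness and format, so the model cannot earn alignment reward by restating graph concepts while answering wrongly.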
Checkpoint 1000 is at step 1000 of 3120 planned steps, approximately 3.2 epochs into training.
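The step and epoch figures can be cross-checked with a few lines of arithmetic. This sketch assumes one completion per GPU per micro-batch, which is not a reported setting. Under that assumption, the effective batch of 64 counts completions (16 prompts × 4 generations). That reading puts step 1000 at 3.2 epochs, and the 3120 planned steps come to roughly 10 passes over the 5,000 prompts.

```python
# Cross-check of the GRPO batch and epoch figures.
NUM_GPUS = 4
PER_DEVICE_COMPLETIONS = 1  # hypothetical: not stated in the card
GRAD_ACCUM = 16
NUM_GENERATIONS = 4
TRAIN_PROMPTS = 5_000
PLANNED_STEPS = 3_120

completions_per_step = NUM_GPUS * PER_DEVICE_COMPLETIONS * GRAD_ACCUM  # 64
prompts_per_step = completions_per_step // NUM_GENERATIONS             # 16


def epochs_at(step):
    """Passes over the RL prompt set after `step` optimizer updates."""
    return step * prompts_per_step / TRAIN_PROMPTS


print(f"{completions_per_step} completions = {prompts_per_step} prompts per update")
print(f"step 1000: {epochs_at(1000):.2f} epochs; "
      f"step {PLANNED_STEPS}: {epochs_at(PLANNED_STEPS):.2f} epochs")
```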
Architecture and training config
| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-14B |
| Architecture | Qwen3ForCausalLM |
| Hidden size | 5120 |
| Layers | 40 |
| Attention heads | 40 (8 KV heads) |
| Context length | 40960 |
| Dtype | bfloat16 |
| Learning rate | 8e-7 |
| KL penalty beta | 0.12 |
| LR schedule | constant with warmup (5%) |
| Optimizer | AdamW |
| Generations per prompt | 4 |
| Max completion length | 1280 tokens |
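Two quantities follow directly from the table.

- KV cache size. With grouped-query attention, the 40 query heads share 8 key/value heads. The cache per token is therefore 2 (K and V) × layers × KV heads × head dim × bytes per value. The sketch assumes head_dim = hidden size / attention heads = 128 and 2-byte bfloat16 cache entries.
- Learning-rate warmup. The constant-with-warmup schedule is a linear ramp to the peak learning rate over the first 5% of steps, then flat. This mirrors the shape of `get_constant_schedule_with_warmup` in Transformers. Taking the 5% of the 3120 planned steps is an assumption.

```python
import math

# Values from the config table.
HIDDEN, LAYERS, Q_HEADS, KV_HEADS = 5120, 40, 40, 8
BYTES_BF16 = 2
HEAD_DIM = HIDDEN // Q_HEADS  # 128, assumed to equal hidden / heads


def kv_cache_bytes(num_tokens):
    """Bytes for the K and V tensors across all layers and KV heads, in bfloat16."""
    return 2 * LAYERS * KV_HEADS * HEAD_DIM * BYTES_BF16 * num_tokens


per_token = kv_cache_bytes(1)                     # 163,840 bytes = 160 KiB
full_context_gib = kv_cache_bytes(40960) / 2**30  # one full 40960-token sequence


def lr_at(step, peak_lr=8e-7, total_steps=3120, warmup_percent=5):
    """Linear warmup from 0 to peak_lr, then constant."""
    warmup = math.ceil(total_steps * warmup_percent / 100)  # 156 steps
    if step < warmup:
        return peak_lr * step / warmup
    return peak_lr


print(per_token, full_context_gib, lr_at(78), lr_at(1000))
```

With 40 KV heads instead of 8, the same cache would be five times larger.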
Output format
The model expects neuroscience MCQ prompts and produces structured output:
```markdown
<think>step-by-step reasoning over the question</think>
<explanation>concise explanation of the correct answer</explanation>
<answer>B</answer>
```
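Downstream code usually needs the final letter plus a check that the tags are present. Below is a minimal parser for the format above, based on regular expressions. The function name and the choice to return None for malformed output are illustrative; neither is part of the model release.

```python
import re

# One named group per tag, in the order the model emits them; whitespace between tags is allowed.
PATTERN = re.compile(
    r"<think>(?P<think>.*?)</think>\s*"
    r"<explanation>(?P<explanation>.*?)</explanation>\s*"
    r"<answer>\s*(?P<answer>[A-D])\s*</answer>",
    re.DOTALL,
)


def parse_response(text):
    """Return a dict with think/explanation/answer strings, or None if the format is broken."""
    m = PATTERN.search(text)
    if m is None:
        return None
    return {key: value.strip() for key, value in m.groupdict().items()}


print(parse_response("<think>ACh acts here.</think>\n<explanation>Motor neurons release ACh.</explanation>\n<answer>C</answer>"))
```

If the decoded text lacks the think block, `parse_response` returns None. In that case, relax the pattern to match only the explanation and answer tags.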
Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "JakeStephen/neuro-si-SFT-RL"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

SYSTEM_PROMPT = (
    "A conversation between user and assistant. The user asks a single-choice "
    "Multiple Choice Question, and the assistant solves it using step-by-step "
    "reasoning. Please answer the multiple choice question by selecting only one "
    "from option A, option B, option C, option D.\n\n"
    "The assistant first thinks through the problem systematically, then provides "
    "the explanation and final answer. Use <think>...</think> tags for internal "
    "reasoning, then provide the explanation process and answer enclosed within "
    "<explanation> </explanation> and <answer> </answer> tags, respectively."
)

# "/think" is Qwen3's soft switch that keeps thinking mode on for this turn.
question = """Which neurotransmitter is primarily released at the neuromuscular junction?

Options:
A. Dopamine
B. Serotonin
C. Acetylcholine
D. GABA
/think"""

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": question},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1280,     # matches the training max completion length
        do_sample=True,          # needed for temperature/top_p to take effect
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.15,
    )

# Decode only the newly generated tokens, not the prompt.
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
```
Test set
neurobench_test_set.json, included in this repo, contains 5,000 two-hop neuroscience MCQs from the same knowledge graph used for training (NeuroBench). Each item includes the question, four options, a chain-of-thought explanation, the correct answer, and the knowledge-graph paths that connect the source and target concepts. Some of these questions overlap with the RL training set, so accuracy on this file is not a clean held-out measure. The file is best used for inference benchmarking or for qualitative inspection of model behavior.
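Below is a minimal scoring loop for the test file. The card does not document the JSON schema. The top-level list and the key names `question`, `options` (a list of four strings), and `answer` (a letter) are hypothetical placeholders; inspect the file and adjust them. `predict` stands in for any function that maps a prompt to the model's decoded text, such as the generation code in Usage.

```python
import json
import re

LETTERS = "ABCD"
ANSWER_RE = re.compile(r"<answer>\s*([A-D])\s*</answer>")


def load_items(path):
    """Load the test file; assumes a top-level JSON list of dicts (hypothetical schema)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def build_prompt(item):
    """Format one item like the Usage example: question, lettered options, /think switch."""
    options = "\n".join(f"{letter}. {opt}" for letter, opt in zip(LETTERS, item["options"]))
    return f"{item['question']}\n\nOptions:\n{options}\n/think"


def accuracy(items, predict):
    """Fraction of items whose <answer> letter matches item['answer']."""
    if not items:
        return 0.0
    correct = 0
    for item in items:
        m = ANSWER_RE.search(predict(build_prompt(item)))
        if m and m.group(1) == item["answer"]:
            correct += 1
    return correct / len(items)
```

To score the model, call `accuracy(load_items("neurobench_test_set.json"), predict)`, with `predict` wrapping `apply_chat_template`, `generate`, and `decode` as in Usage.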
Framework versions
- TRL: 0.29.0
- Transformers: 5.3.0
- PyTorch: 2.6.0+cu124
- Datasets: 4.7.0
Model provider: JakeStephen
Model tree: fine-tuned from base model Qwen/Qwen3-14B
Modalities: text input, text output