License: mit

⚡ Quick Start — One Cell, Any Notebook

Open in Google Colab (Runtime → Change runtime type → T4 GPU) or any Kaggle notebook and paste this single cell. Change QUESTION to anything you want to ask.

```python
# ============================================================
# MedQuery-India-v1: One-Cell Inference
# Works on Google Colab / Kaggle / any notebook with a T4 GPU
# Just change QUESTION in Step 3 and run!
# ============================================================

# --- Step 1: Install dependencies (run once) ---
import subprocess
import sys

# Install into the same Python interpreter the notebook kernel uses
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q",
     "transformers", "peft", "bitsandbytes", "accelerate"],
    check=True,
)

# --- Step 2: Load the model ---
# meta-llama/Llama-3.2-1B-Instruct is a gated repository: accept its license on
# Hugging Face and authenticate (e.g. huggingface_hub.login()) before loading.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER = "kanha98/medquery-india-v1"

# 4-bit NF4 quantization, matching the QLoRA training setup
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # T4 GPUs lack native bfloat16 support
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, ADAPTER)  # attach the LoRA adapter
model.eval()
print("✅ Model loaded successfully!")

# --- Step 3: Ask your question (change this line) ↓ ---
QUESTION = "What are the warning signs of severe dengue?"
# -------------------------------------------------------

SYSTEM = (
    "You are MedQuery-India, a medical AI assistant trained on Indian healthcare context "
    "including AIIMS/NEET clinical protocols, Indian drug brands, regional diseases, "
    "Indian procedural guidelines (NTEP, NVBDCP, RSSDI, IAP), and mental health support. "
    "Answer accurately, safely, and with cultural sensitivity."
)

# Llama-3 chat format, written out by hand
prompt = (
    f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{SYSTEM}<|eot_id|>"
    f"<|start_header_id|>user<|end_header_id|>\n{QUESTION}<|eot_id|>"
    f"<|start_header_id|>assistant<|end_header_id|>\n"
)

# The prompt already starts with <|begin_of_text|>, so stop the tokenizer
# from prepending a second BOS token.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=250,
    temperature=0.3,
    do_sample=True,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens (everything after the prompt), so an
# answer that contains the word "assistant" is not truncated.
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True).strip())
```

Note: This cell uses transformers + peft + bitsandbytes — no Unsloth required. Works on any free-tier Colab/Kaggle T4 instance (~14.5 GB VRAM).


Model Details

| Property | Value |
|---|---|
| Base model | meta-llama/Llama-3.2-1B-Instruct |
| Parameters | 1,235,814,400 (1.24B) |
| Fine-tuning technique | QLoRA (4-bit NF4 quantization) |
| LoRA rank | r = 64 |
| LoRA alpha | 128 |
| Target modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj (7 modules) |
| Trainable parameters | 45,088,768 (3.52% of the 1.28B base + adapter total) |
| Training hardware | Tesla T4 (Kaggle, 14.5 GB VRAM) |
| Final training loss | 1.5468 |
| Training steps | 1,030 |

Why These Decisions

Why Llama-3.2-1B-Instruct?

Three concrete reasons, not vibes:

  1. Tokenizer efficiency on medical vocabulary. Llama-3's 128k BPE vocabulary encodes medical terms like "acetaminophen", "thrombocytopenia", and "leptospirosis" as 1–2 tokens. GPT-2's 50k vocabulary splits the same terms into 4–6 tokens. Fewer tokens per medical term means more semantic context fits inside the 512-token training window (max_seq_length). This matters for QA, where both the question and the answer have to fit.

  2. Grouped Query Attention (GQA). Llama-3.2-1B has 32 query heads but only 8 key/value heads. Each group of 4 query heads therefore shares one KV head (a 4:1 ratio). This makes the KV cache a quarter of the size standard multi-head attention would need, which allows longer contexts at the same VRAM cost.

  3. The 1B sweet spot. It is larger than SmolLM2-360M (better reasoning, longer coherent answers) and smaller than 3B+ models, so it fits on a T4 with 4-bit quantization and trains in ~4 hours. Every architectural decision in this model is explainable. That matters for writing it up as a research paper and for anyone who wants to reproduce it.
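Here is a back-of-the-envelope sketch of the GQA saving described in point 2. It assumes the published Llama-3.2-1B configuration (16 layers, 32 query heads, 8 KV heads, head dimension 64), an fp16 cache, and a single 512-token sequence. The `mha` figure is hypothetical: it is what the cache would cost with one KV head per query head.

```python
# Assumed Llama-3.2-1B attention config (from its public config.json)
NUM_LAYERS = 16
NUM_Q_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 64
BYTES_PER_VALUE = 2  # fp16 cache entries


def kv_cache_bytes(seq_len, kv_heads):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor 2: one key vector and one value vector per KV head, per layer
    return 2 * NUM_LAYERS * kv_heads * HEAD_DIM * BYTES_PER_VALUE * seq_len


gqa = kv_cache_bytes(512, NUM_KV_HEADS)
mha = kv_cache_bytes(512, NUM_Q_HEADS)  # hypothetical: one KV head per query head

print(f"query heads per KV head: {NUM_Q_HEADS // NUM_KV_HEADS}")  # 4
print(f"GQA cache: {gqa / 2**20:.1f} MiB")  # 16.0 MiB
print(f"MHA cache: {mha / 2**20:.1f} MiB")  # 64.0 MiB
```

The cache grows linearly with sequence length, so the 4× ratio holds at any context length.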

Why QLoRA with r=64?

QLoRA (Dettmers et al., 2023) freezes the base model's weights in 4-bit NF4 quantization and trains only low-rank adapter matrices on top of them. This model trains 45M parameters out of 1.28B total (1.24B frozen base + 45M adapter), i.e. 3.52%.

Why r=64 and not r=16? r=16 is a common default. r=64 was chosen here because the task requires cross-domain adaptation. The model has to handle clinical MCQ reasoning (MedMCQA), conversational patient QA (ChatDoctor), structured NIH-style QA (MedQuAD), and Indian-specific synthetic cases all at once. A higher rank (4× the adapter parameters of r=16) gives the adapter more capacity to represent these different response styles without catastrophic interference.
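The next sketch is a minimal pure-Python LoRA linear layer. It shows what "trains only low-rank adapter matrices" means in practice. The frozen weight W is never updated. The output adds (alpha / r) · B(Ax), where A starts random and B starts at zero (the initialization used in the original LoRA paper), so training begins from the base model's behavior. It is illustrative only: the real adapter is built by PEFT/Unsloth, and in QLoRA W is stored in 4-bit NF4 rather than as Python floats. The class name `LoRALinear` is hypothetical.

```python
import random


def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """Hypothetical minimal LoRA layer: y = W x + (alpha / r) * B (A x)."""

    def __init__(self, weight, r, alpha, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen base weight, d_out x d_in
        # Trainable low-rank factors: A is r x d_in (random), B is d_out x r (zeros)
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d_out)]
        self.scaling = alpha / r  # 128 / 64 = 2.0 for this model's config

    def trainable_params(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])

    def forward(self, x):
        base = matvec(self.weight, x)
        update = matvec(self.B, matvec(self.A, x))
        return [b + self.scaling * u for b, u in zip(base, update)]


W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
layer = LoRALinear(W, r=2, alpha=4)
x = [1.0, 0.0, -1.0]
print(layer.forward(x))          # [-2.0, -2.0]: B starts at zero, so output == W x
print(layer.trainable_params())  # r * (d_in + d_out) = 2 * (3 + 2) = 10
```

The adapter's size is r·(d_in + d_out) per layer. It grows linearly with the rank, which is why r=64 carries four times the capacity of r=16.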

Why 7 target modules?

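Many LoRA setups target only the attention projections (q, k, v, o: 4 modules). This training also targets the MLP projections (gate_proj, up_proj, down_proj). The reason is that interpretability research suggests much of a transformer's factual knowledge, medical facts included, is stored in its FFN layers rather than in attention alone. Targeting all 7 modules raises the trainable parameter count from ~13.6M to ~45M. Once the adapter is merged into the base weights, this adds no inference overhead.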
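The trainable-parameter counts can be reproduced from the layer shapes. This assumes the public Llama-3.2-1B dimensions: hidden size 2048, MLP size 8192, 16 decoder layers, and 8 KV heads of dimension 64. An adapted projection with in_features × out_features adds A (r × in) and B (out × r), i.e. r·(in + out) parameters. The helper name `lora_params` is hypothetical.

```python
# Assumed Llama-3.2-1B dimensions (from its public config.json)
HIDDEN = 2048
INTERMEDIATE = 8192
KV_DIM = 8 * 64  # k_proj / v_proj output width under GQA
NUM_LAYERS = 16

# (in_features, out_features) for each projection in one decoder layer
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}
ATTN = ["q_proj", "k_proj", "v_proj", "o_proj"]
ALL7 = list(SHAPES)
BASE_PARAMS = 1_235_814_400  # from Model Details


def lora_params(r, modules):
    """Trainable LoRA parameters: r * (in + out) per module, summed over layers."""
    per_layer = sum(r * (SHAPES[m][0] + SHAPES[m][1]) for m in modules)
    return per_layer * NUM_LAYERS


print(f"{lora_params(64, ATTN):,}")  # 13,631,488 (attention only, r=64)
print(f"{lora_params(64, ALL7):,}")  # 45,088,768 (all 7 modules, r=64)
print(f"{lora_params(16, ALL7):,}")  # 11,272,192 (all 7 modules, r=16)

share = lora_params(64, ALL7) / (BASE_PARAMS + lora_params(64, ALL7))
print(f"{share:.2%}")  # 3.52%
```

The all-module r=64 result matches the 45,088,768 reported in Model Details. Dividing by base + adapter parameters gives the quoted 3.52%.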

Why 4-bit NF4 quantization?

NF4 (4-bit NormalFloat) is not the same as INT4. NF4 places its 16 quantization levels at quantiles of a normal distribution, which matches the roughly normal distribution of pretrained neural network weights. Its levels are packed densely near zero, where most weights cluster, and spread out toward the tails. INT4 uses a uniform grid. It spends as much resolution on the sparsely populated tails as on the crowded center, so it loses more information where most values actually lie. For a medical QA model where precise factual recall matters, NF4 is the better fit.
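The following is a simplified illustration of the two grids, with 16 levels each rescaled to [-1, 1]. The "quantile grid" places levels at evenly spaced quantiles of a standard normal. That captures the idea behind NF4, but it is not bitsandbytes' exact NF4 table: the real table is built from two asymmetric halves so that zero is exactly representable, while this symmetric sketch has no zero level. The uniform grid stands in for INT4.

```python
from statistics import NormalDist


def quantile_grid(bits=4):
    """Simplified NF4-style grid: normal quantiles rescaled to [-1, 1]."""
    n = 2 ** bits
    nd = NormalDist()
    qs = [nd.inv_cdf((i + 0.5) / n) for i in range(n)]  # evenly spaced probabilities
    scale = max(abs(q) for q in qs)
    return [q / scale for q in qs]


def uniform_grid(bits=4):
    """INT4-style grid: 2**bits evenly spaced levels on [-1, 1]."""
    n = 2 ** bits
    return [-1 + 2 * i / (n - 1) for i in range(n)]


def gaps(levels):
    """Distance between neighbouring quantization levels."""
    return [b - a for a, b in zip(levels, levels[1:])]


nf = quantile_grid()
uni = uniform_grid()
print("quantile grid, gap at center:", round(gaps(nf)[7], 3))
print("quantile grid, gap at edge:  ", round(gaps(nf)[-1], 3))
print("uniform grid, every gap:     ", round(gaps(uni)[0], 3))
```

The quantile grid's central gap is smaller than the uniform step and its edge gap is larger. That is exactly the trade NF4 makes: finer resolution where weights are dense, coarser where they are rare.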

Why this optimizer and schedule?

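  • AdamW 8-bit (not Adam): AdamW decouples weight decay from the gradient update, which matters for transformer fine-tuning. Plain Adam implements weight decay as an L2 term added to the gradient. The decay then gets rescaled by each parameter's adaptive step size, so parameters with a large gradient history are decayed less than intended. The 8-bit version stores optimizer states in 8 bits instead of 32, saving ~75% of optimizer memory with negligible quality loss.
  • Cosine LR scheduler: after warmup, the learning rate decays smoothly along a half-cosine to near zero. The smooth tail is meant to avoid the loss spikes seen with a linear schedule near the end of training.
  • lr = 1e-4: A common QLoRA learning rate. In testing, 2e-4 caused loss instability on this medical data, and 5e-5 converged too slowly for 5 epochs.
  • Effective batch = 32 (4 per device × 8 gradient accumulation steps): a larger effective batch stabilizes the loss when mixing heterogeneous data sources.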
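The warmup-plus-cosine schedule above can be written out as a plain function. This sketch assumes three things. First, linear warmup over the 150 `warmup_steps` from the training configuration. Second, a half-cosine decay to zero over the 1,030 total steps from Model Details. Third, the same formula as Hugging Face's `get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`. The name `lr_at` is hypothetical.

```python
import math

BASE_LR = 1e-4
WARMUP_STEPS = 150
TOTAL_STEPS = 1030  # "Training steps" from Model Details


def lr_at(step):
    """Learning rate at a given optimizer step."""
    if step < WARMUP_STEPS:
        # Linear warmup from 0 up to BASE_LR
        return BASE_LR * step / WARMUP_STEPS
    # Half-cosine decay from BASE_LR down to 0 over the remaining steps
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


for s in (0, 75, 150, 590, 1030):
    print(s, f"{lr_at(s):.2e}")
```

The rate peaks at step 150 and falls to half its peak at the midpoint of the decay phase (step 590). It flattens out near zero instead of stopping abruptly.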

Dataset

Total training samples: 6,569 | Val: 780 | Test: 780

| Source | Samples | % | Why included |
|---|---:|---:|---|
| MedMCQA (Indian) | 3,613 | 55.0% | AIIMS/NEET exam questions; directly Indian clinical context |
| ChatDoctor | 1,588 | 24.2% | Real patient-doctor conversations; teaches conversational tone |
| MedQuAD | 802 | 12.2% | NIH structured QA; adds reliable factual grounding |
| PubMedQA | 237 | 3.6% | Expert-annotated research QA; adds clinical reasoning |
| Synthetic Indian (general) | 144 | 2.2% | Indian drug names, regional disease context |
| Synthetic Indian (edge cases) | 135 | 2.1% | Drug safety edge cases specific to India |
| Synthetic Mental Health | 50 | 0.8% | NEET stress, exam pressure, Indian mental health context |
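The dataset table can be cross-checked in a few lines. This re-derives the total, the synthetic subtotal, and the percentage column from the per-source counts above.

```python
# Per-source sample counts copied from the dataset table
counts = {
    "MedMCQA (Indian)": 3613,
    "ChatDoctor": 1588,
    "MedQuAD": 802,
    "PubMedQA": 237,
    "Synthetic Indian (general)": 144,
    "Synthetic Indian (edge cases)": 135,
    "Synthetic Mental Health": 50,
}

total = sum(counts.values())
synthetic = sum(n for name, n in counts.items() if name.startswith("Synthetic"))
percentages = [f"{n / total:.1%}" for n in counts.values()]

print("total:", total)          # 6569
print("synthetic:", synthetic)  # 329
for (name, n), pct in zip(counts.items(), percentages):
    print(f"{name:<32}{n:>6}{pct:>8}")
```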

Why MedMCQA at 55%? It is the only large-scale dataset sourced directly from AIIMS and NEET PG entrance exams — real Indian clinical questions with expert explanations. No other public dataset captures this.

Why synthetic data? Public medical datasets have near-zero coverage of Indian drug brands, DOTS protocol specifics, or mental health in the Indian exam context. 329 hand-crafted synthetic samples fill this gap directly.


Training Configuration

```python
# Hardware: Tesla T4, 14.5 GB VRAM (Kaggle)
# Framework: Unsloth 2026.6.1 + TRL SFTTrainer

# LoRA
r = 64
lora_alpha = 128   # alpha = 2r, a common scaling heuristic
lora_dropout = 0   # dropout off: small dataset, stable training
target_modules = [  # attention + MLP projections (7 modules)
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Training
num_train_epochs = 5
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 32
warmup_steps = 150
learning_rate = 1e-4
lr_scheduler_type = "cosine"
optim = "adamw_8bit"
weight_decay = 0.01
max_seq_length = 512
```
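The batch settings above can be checked against the rest of the card. This sketch assumes that each epoch's final partial batch, and a trailing partial gradient-accumulation window, still count as an optimizer step (ceiling division). Under that assumption it reproduces the 1,030 training steps listed in Model Details. Trainer versions can round this differently, so treat it as an approximation.

```python
import math

# Values copied from the Dataset section and the configuration above
train_samples = 6569
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
num_train_epochs = 5
warmup_steps = 150

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
batches_per_epoch = math.ceil(train_samples / per_device_train_batch_size)  # last batch may be partial
# Assumption: a trailing partial accumulation window still triggers an optimizer step
steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
total_steps = steps_per_epoch * num_train_epochs

print(effective_batch_size)  # 32
print(steps_per_epoch)       # 206
print(total_steps)           # 1030
print(f"warmup share: {warmup_steps / total_steps:.1%}")
```

Under this assumption, about 14.6% of all optimizer steps are spent in warmup.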

Inference (Unsloth — faster, if available)

If you have Unsloth installed, you can use the faster inference path:

```python
from unsloth import FastLanguageModel

# Load the fine-tuned model in 4-bit through Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="kanha98/medquery-india-v1",
    max_seq_length=512,
    dtype=None,         # None = auto-detect (float16 on a T4)
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # switch Unsloth into inference mode

SYSTEM_PROMPT = (
    "You are MedQuery-India, a medical AI assistant trained on Indian "
    "healthcare context including AIIMS/NEET clinical protocols, Indian "
    "drug brands, regional diseases, Indian procedural guidelines (NTEP, "
    "NVBDCP, RSSDI, IAP), and mental health support. Answer accurately, "
    "safely, and with cultural sensitivity."
)
question = "What are the warning signs of severe dengue?"

prompt = (
    f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
    f"{SYSTEM_PROMPT}<|eot_id|>"
    f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>"
    f"<|start_header_id|>assistant<|end_header_id|>\n"
)
# The prompt already contains <|begin_of_text|>; skip the tokenizer's own BOS
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=250,
    temperature=0.3,
    do_sample=True,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the tokens generated after the prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True).strip())
```
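Both inference cells assemble the same single-turn Llama-3 prompt by hand. If you call the model from your own code, a small helper keeps that format in one place. `build_prompt` is a hypothetical name. The single newline after each header mirrors the cells on this card; Meta's reference chat template uses a blank line there instead, so keep whichever format the adapter was actually trained on.

```python
SYSTEM = (
    "You are MedQuery-India, a medical AI assistant trained on Indian healthcare context "
    "including AIIMS/NEET clinical protocols, Indian drug brands, regional diseases, "
    "Indian procedural guidelines (NTEP, NVBDCP, RSSDI, IAP), and mental health support. "
    "Answer accurately, safely, and with cultural sensitivity."
)


def build_prompt(system, question):
    """Render a single-turn Llama-3 prompt that ends at the assistant header."""
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )


prompt = build_prompt(SYSTEM, "What are the warning signs of severe dengue?")
print(prompt)
```

The string already starts with `<|begin_of_text|>`. Tokenize it with `add_special_tokens=False` so the tokenizer does not prepend a second BOS token.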

Sample Outputs

Q: What are the warning signs of severe dengue?

Warning signs of severe dengue require immediate hospitalization. These include: extreme lethargy or inability to talk; persistent vomiting; severe abdominal pain; bleeding from the nose, gums, or mouth; very high fever above 40°C; black spots on the skin (petechiae); swelling in the face, neck, or groin; extremely low blood pressure. If you have any one of these symptoms, call your local hospital immediately.

Q: Patient took Combiflam for dengue fever. Is this dangerous?

Combiflam contains ibuprofen (an NSAID) and paracetamol. Taking NSAIDs during dengue is dangerous because they increase the risk of bleeding by inhibiting platelet function — already compromised in dengue. For dengue fever, only paracetamol (like Crocin or Dolo 650) is safe. NSAIDs including Combiflam, Brufen, and Aspirin must be strictly avoided.

Q: What is the DOTS program for tuberculosis in India?

DOTS (Directly Observed Treatment Short-course) is India's national TB control program under the National Tuberculosis Elimination Programme (NTEP), launched in 1992. Core components: sputum microscopy for diagnosis, a standardized short-course drug regimen, direct supervision of treatment, case detection at facility and community levels, and free treatment under the PM-JAY scheme. DOTS is implemented by state and district health departments under the Central TB Division, Ministry of Health and Family Welfare.


Limitations

  • Not a substitute for medical advice. This model is for research and educational purposes. Do not use for clinical diagnosis or treatment decisions.
  • English only (v1). Hindi, Marathi, and Bengali support is planned for v2.
  • 1B parameter ceiling. Complex multi-step clinical reasoning may produce errors. Hallucination risk exists on rare diseases.
  • Training data cutoff. Drug approvals, protocol updates, or guideline changes after the training data may not be reflected.
  • USMLE-style questions. This model was not optimized for Western clinical board exams.

Roadmap

  • Evaluation: ROUGE-L, BERTScore F1, and BLEU-4 on the 780-sample test set, compared against the base model
  • Hindi language support (v2)
  • Gradio demo on Hugging Face Spaces
  • GGUF conversion for local CPU inference
  • arXiv paper: MedQuery-India: A QLoRA Fine-Tuned LLM for Indian Healthcare Question Answering

Citation

If you use this model in research, please cite:

```bibtex
@misc{gupta2025medqueryindia,
  author = {Kanha98},
  title  = {MedQuery-India-v1: QLoRA Fine-Tuning of Llama-3.2-1B for Indian Medical QA},
  year   = {2025},
  url    = {https://huggingface.co/kanha98/medquery-india-v1}
}
```

Author

Kanha98

Built with Unsloth · Base model: meta-llama/Llama-3.2-1B-Instruct
