License: apache-2.0
Overview
This model is fine-tuned with the Block-Attention mechanism described in "Block-Attention for Efficient Prefilling". Block-Attention divides the input context into independent blocks during the prefill phase, so the KV cache of a document can be reused across different queries on the same documents. This is a key optimization for RAG serving.
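To make the reuse idea concrete, here is a minimal sketch of a per-block cache. `BlockKVCache`, `encode_block`, and `prefill` are hypothetical names introduced for illustration, and the encoder is a placeholder rather than a real model call. A real implementation would also have to handle positional encodings for cached blocks, which this sketch ignores.

```python
import hashlib
from typing import Callable, Dict, List


class BlockKVCache:
    """Hypothetical cache that maps a block's text to its precomputed KV states."""

    def __init__(self, encode_block: Callable[[str], object]):
        self._encode_block = encode_block  # stand-in for a per-block prefill pass
        self._store: Dict[str, object] = {}
        self.misses = 0  # number of blocks that actually had to be encoded

    def get(self, block_text: str) -> object:
        # Key the cache on the block content so identical documents share an entry.
        key = hashlib.sha256(block_text.encode("utf-8")).hexdigest()
        if key not in self._store:
            self.misses += 1
            self._store[key] = self._encode_block(block_text)
        return self._store[key]


def prefill(cache: BlockKVCache, documents: List[str]) -> List[object]:
    """Collect KV states for the document blocks, reusing cached entries."""
    return [cache.get(doc) for doc in documents]


# Placeholder encoder: a real one would run the model over the block.
cache = BlockKVCache(encode_block=lambda text: ("kv", len(text)))
docs = ["- Title: Document 1\n...", "- Title: Document 2\n..."]

prefill(cache, docs)  # first query: both documents are encoded
prefill(cache, docs)  # second query on the same documents: cache hits only
print(cache.misses)
```

Only the query block has to be recomputed for each new question; the document blocks are encoded once.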
Training Data Control Variable: This model was fine-tuned on an 8K subset of the Tulu3-Block-FT-RAG dataset. A companion Llama-3.2-1B model uses the full 80K samples for comparison.
Evaluation Results
On Unseen TriviaQA Validation Set (100 clean samples)
Questions and evidence passages from TriviaQA RC validation split, excluded from training data. Substr-EM checks whether the correct answer appears as a substring in the model's response.
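The two metrics can be sketched as follows. The SQuAD-style normalization (lowercasing, stripping punctuation and articles) and the token-overlap definition of F1 are assumptions; the model card does not state how either metric was implemented, and the function names are illustrative.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace.

    Assumed normalization; the model card does not specify one.
    """
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def substr_em(response: str, answers: list[str]) -> bool:
    """True if any normalized gold answer occurs as a substring of the response."""
    norm_response = normalize(response)
    return any(normalize(a) in norm_response for a in answers if normalize(a))


def token_f1(response: str, answer: str) -> float:
    """Token-overlap F1 between the response and a single gold answer."""
    pred_tokens = normalize(response).split()
    gold_tokens = normalize(answer).split()
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


response = "The capital of France is Paris."
print(substr_em(response, ["Paris"]))
print(round(token_f1(response, "Paris"), 3))
```

Under these definitions a correct but verbose response passes Substr-EM while scoring a low F1, because the extra tokens reduce precision. That would be consistent with the gap between the two columns in the table below.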
| Model | Substr-EM | F1 Score |
|---|---|---|
| meta-llama/Llama-3.2-1B (base) | 56.00% | 12.51% |
| meta-llama/Llama-3.2-1B-Instruct | 86.00% | 23.62% |
| hxia7/Llama-3.2-1B-block-FT (full-attention) | 87.00% | 26.59% |
| hxia7/Llama-3.2-1B-block-FT (block-attention) | 88.00% | 27.53% |
| hxia7/Qwen3-8B-block-FT (full-attention) | 91.00% | 25.18% |
| hxia7/Qwen3-8B-block-FT (block-attention) | 90.00% | 23.71% |
Key observations:
- For Qwen3-8B, full-attention and block-attention produce comparable results (91% vs 90% Substr-EM), indicating that the block-attention structure preserves quality.
- Despite training on only 8K samples (vs 80K for Llama), the Qwen3-8B model achieves the highest Substr-EM at 91% (full-attention mode), suggesting a benefit from the larger base model.
- The evidence passages from TriviaQA differ from the Contriever-retrieved passages used in training, making this a meaningful out-of-distribution test.
Block-Attention Mechanism
In Block-Attention, the context is split into N blocks:
- Blocks 1..N-1 (document blocks): Use local attention — each block attends only to itself
- Block N (query block): Uses global attention — attends to all previous blocks
This isolation allows document blocks' KV states to be computed once and reused across multiple queries.
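The masking rule above can be sketched in plain Python. `build_block_mask` is an illustrative name, not the repository's `build_attention_mask`, and the sketch assumes a causal mask inside every block.

```python
def build_block_mask(block_lengths: list[int]) -> list[list[bool]]:
    """Return mask[i][j] == True when token i may attend to token j.

    block_lengths holds the token count of each block; the last entry is the
    query block, all earlier entries are document blocks.
    """
    total = sum(block_lengths)
    # Map every token position to the index of the block that contains it.
    block_of = [b for b, n in enumerate(block_lengths) for _ in range(n)]
    last_block = len(block_lengths) - 1
    mask = [[False] * total for _ in range(total)]
    for i in range(total):
        for j in range(i + 1):  # causal: never attend to future tokens
            same_block = block_of[i] == block_of[j]
            in_query_block = block_of[i] == last_block
            mask[i][j] = same_block or in_query_block
    return mask


# Two document blocks (2 and 3 tokens) followed by a 2-token query block.
mask = build_block_mask([2, 3, 2])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x......
# xx.....
# ..x....
# ..xx...
# ..xxx..
# xxxxxx.
# xxxxxxx
```

Rows 2 to 4 (the second document) never look at columns 0 and 1 (the first document), so each document's KV states do not depend on which other documents precede it. Only the last two rows, the query block, see the whole context.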
Training Details
- Base Model: Qwen/Qwen3-8B
- Training Data: Tulu3-Block-FT-RAG (8K subset)
- Epochs: 1
- Learning Rate: 2e-6
- Optimizer: AdamW (fused)
- Precision: BF16
- DeepSpeed: ZeRO Stage 2 with CPU optimizer offload
- Loss Reduction: sum (over non-masked tokens)
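For readers who want to reproduce a similar setup, the settings above could be expressed roughly as follows. This is a hypothetical configuration assembled from the list; the author's actual config files are not shown in this model card, and the use of Hugging Face `TrainingArguments` names in `train_args` is an assumption.

```python
import json

# Hypothetical DeepSpeed config: ZeRO Stage 2 with the optimizer state
# offloaded to CPU, and BF16 mixed precision.
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
}

# Hypothetical trainer arguments mirroring the remaining hyperparameters.
# The key names follow Hugging Face TrainingArguments (an assumption).
train_args = {
    "num_train_epochs": 1,
    "learning_rate": 2e-6,
    "optim": "adamw_torch_fused",
    "bf16": True,
}

# DeepSpeed reads its config as JSON, so serialize the dict for inspection.
print(json.dumps(ds_config, indent=2))
```

The sum-style loss reduction is not a DeepSpeed setting and would have to be implemented in the training loop itself.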
During training, each sample produces two variants:
- Full-attention version (standard causal mask)
- Block-attention version (with `[Block-Attention]` prefix token and 4D block mask)
Both variants contribute to the loss, teaching the model to handle both inference modes.
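A minimal sketch of how the two variants could feed one loss with sum reduction is shown below. Adding the two per-variant losses with equal weight, and marking masked label positions with `-100`, are assumptions; the model card only states that both variants contribute and that the reduction is a sum over non-masked tokens.

```python
import math

IGNORE_INDEX = -100  # assumed marker for masked label positions


def summed_cross_entropy(logits: list[list[float]], labels: list[int]) -> float:
    """Sum (not mean) of token-level cross-entropy over non-masked positions."""
    total = 0.0
    for token_logits, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # masked position: contributes nothing to the loss
        # log-sum-exp with max subtraction for numerical stability
        peak = max(token_logits)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in token_logits))
        total += log_z - token_logits[label]
    return total


def sample_loss(full_logits, full_labels, block_logits, block_labels) -> float:
    """Add the losses of the full-attention and block-attention variants."""
    return (summed_cross_entropy(full_logits, full_labels)
            + summed_cross_entropy(block_logits, block_labels))


# Toy example with a vocabulary of 2. The block variant has one extra leading
# position for the prefix token, which is masked out of the loss.
full_labels = [IGNORE_INDEX, 0, 1]
full_logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]
block_labels = [IGNORE_INDEX, IGNORE_INDEX, 0, 1]
block_logits = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
print(sample_loss(full_logits, full_labels, block_logits, block_labels))
```

With sum reduction, longer answers contribute more to the gradient than shorter ones, unlike a per-token mean.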
Inference
Block-Attention Inference (recommended for RAG)
Important: Block-Attention uses a 4D attention mask [1, 1, seq_len, seq_len] during prefill. model.generate() only accepts 2D masks, so inference requires manual prefill + autoregressive decode:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mask helpers; the `src.data.block` module must be importable.
from src.data.block import build_attention_mask, convert_attention_mask_to_model_required

model = AutoModelForCausalLM.from_pretrained(
    "hxia7/Qwen3-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hxia7/Qwen3-8B-block-FT")

# The last entry is the query block (global attention); all earlier entries
# are treated as local-attention blocks.
blocks = [
    "\nYou are an intelligent AI assistant. Please answer questions based on the user's instructions. Below are some reference documents that may help you in answering the user's question.\n\n",
    "- Title: Document 1\nContent of document 1...\n",
    "- Title: Document 2\nContent of document 2...\n",
    "\n\nPlease write a high-quality answer for the given question using only the provided search documents.\nQuestion: What is X?\n\n\n",
]


@torch.no_grad()
def block_generate(model, tokenizer, blocks, max_new_tokens=128):
    # Tokenize each block separately so the block boundaries are known.
    block_token_counts = []
    all_ids = []
    for b in blocks:
        ids = tokenizer.encode(b, add_special_tokens=False)
        all_ids.extend(ids)
        block_token_counts.append(len(ids))

    input_ids = torch.tensor([all_ids], dtype=torch.int64, device=model.device)
    total_len = len(all_ids)

    # Build the block mask and reshape it to [1, 1, seq_len, seq_len].
    helper = torch.tril(torch.ones(total_len + 64, total_len + 64, dtype=torch.bool))
    attn_mask = build_attention_mask(
        local_attention_block_tokens=torch.tensor(block_token_counts[:-1], dtype=torch.long),
        global_attention_block_tokens=torch.tensor(block_token_counts[-1], dtype=torch.long),
        lower_triangular_matrix=helper,
    )
    attn_mask = convert_attention_mask_to_model_required(attn_mask)
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).to(model.device)

    # Prefill with the 4D block mask.
    outputs = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=True)
    past_kv = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)

    # Greedy autoregressive decode, one token at a time.
    generated = []
    for _ in range(max_new_tokens - 1):
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated.append(next_token.item())
        outputs = model(input_ids=next_token, past_key_values=past_kv, use_cache=True)
        past_kv = outputs.past_key_values
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
    if next_token.item() != tokenizer.eos_token_id:
        generated.append(next_token.item())
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


answer = block_generate(model, tokenizer, blocks)
print(answer)
```
Full-Attention Inference (standard)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hxia7/Qwen3-8B-block-FT", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hxia7/Qwen3-8B-block-FT")

prompt = "Your full RAG prompt here..."
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3968).to(model.device)
outputs = model.generate(
    **inputs, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id
)
# Decode only the newly generated tokens, skipping the prompt.
answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)
```