import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.data.block import build_attention_mask, convert_attention_mask_to_model_required
# Hugging Face Hub id of the block-attention fine-tuned checkpoint. Used for both
# the model and its tokenizer so the two cannot silently drift apart.
MODEL_NAME = "hxia7/Qwen3-8B-block-FT"

# Load weights in bfloat16 and let accelerate place layers across available devices.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Prompt split into attention blocks: a system preamble, one block per reference
# document, and the question. block_generate passes the token counts of all
# blocks but the last as local blocks and the last block's count as the global block.
blocks = [
    "\nYou are an intelligent AI assistant. Please answer questions based on the user's instructions. Below are some reference documents that may help you in answering the user's question.\n\n",
    "- Title: Document 1\nContent of document 1...\n",
    "- Title: Document 2\nContent of document 2...\n",
    "\n\nPlease write a high-quality answer for the given question using only the provided search documents.\nQuestion: What is X?\n\n\n",
]
def _prepare_block_inputs(model, tokenizer, blocks):
    """Tokenize ``blocks`` and build the prefill inputs for block attention.

    Returns ``(input_ids, attn_mask)``:

    - ``input_ids`` is a ``(1, total_len)`` int64 tensor of the concatenated
      block tokens, without special tokens.
    - ``attn_mask`` is the output of ``build_attention_mask`` passed through
      ``convert_attention_mask_to_model_required``, with two leading singleton
      dimensions added.

    Both tensors are on ``model.device``.

    Raises:
        ValueError: If the blocks tokenize to zero tokens in total.
    """
    block_token_counts = []
    all_ids = []
    for text in blocks:
        ids = tokenizer.encode(text, add_special_tokens=False)
        all_ids.extend(ids)
        block_token_counts.append(len(ids))
    if not all_ids:
        raise ValueError("blocks tokenized to an empty prompt")

    total_len = len(all_ids)
    input_ids = torch.tensor([all_ids], dtype=torch.int64, device=model.device)

    # Boolean lower-triangular template handed to build_attention_mask; the +64
    # headroom is kept from the original code.
    # NOTE(review): presumably build_attention_mask slices this down to
    # total_len; confirm whether the headroom is actually required.
    side = total_len + 64
    causal_template = torch.tril(torch.ones(side, side, dtype=torch.bool))
    attn_mask = build_attention_mask(
        # Token counts of every block except the last go in as the local blocks...
        local_attention_block_tokens=torch.tensor(block_token_counts[:-1], dtype=torch.long),
        # ...and the last block's count (a 0-d tensor) as the global block.
        global_attention_block_tokens=torch.tensor(block_token_counts[-1], dtype=torch.long),
        lower_triangular_matrix=causal_template,
    )
    attn_mask = convert_attention_mask_to_model_required(attn_mask)
    # Add leading batch and head dims; assumes the converted mask is
    # (q_len, k_len). TODO: confirm.
    # NOTE(review): PyTorch SDPA requires a floating attn_mask to match the
    # query dtype (bfloat16 here). Confirm the converter returns a bool mask or
    # one already in the model dtype.
    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).to(model.device)
    return input_ids, attn_mask


def _stop_token_ids(model, tokenizer):
    """Return the set of token ids that end generation.

    Combines ``tokenizer.eos_token_id`` with ``model.generation_config.eos_token_id``
    (which may be an int or a list of ints), mirroring what ``model.generate``
    stops on. ``None`` entries are ignored.
    """
    candidates = [tokenizer.eos_token_id]
    gen_config = getattr(model, "generation_config", None)
    if gen_config is not None:
        candidates.append(getattr(gen_config, "eos_token_id", None))

    stop_ids = set()
    for cand in candidates:
        if cand is None:
            continue
        if isinstance(cand, int):
            stop_ids.add(cand)
        else:
            stop_ids.update(int(c) for c in cand)
    return stop_ids


@torch.no_grad()
def block_generate(model, tokenizer, blocks, max_new_tokens=128):
    """Greedily decode an answer for a prompt split into attention blocks.

    The prompt is the concatenation of ``blocks``. The single prefill pass uses
    the block attention mask built by ``_prepare_block_inputs`` (all blocks
    but the last as local blocks, the last as the global block). Each new token
    is then produced by one forward pass over the KV cache. Those passes give
    no ``attention_mask``, so the model's default masking applies; presumably
    this is causal attention over the whole cache, but confirm for this
    checkpoint.

    Args:
        model: Causal LM whose forward accepts ``input_ids``,
            ``attention_mask``, ``past_key_values`` and ``use_cache`` and whose
            output exposes ``logits`` and ``past_key_values``.
        tokenizer: Tokenizer matching ``model``.
        blocks: Non-empty sequence of prompt strings; the last one is treated
            as the global block.
        max_new_tokens: Maximum number of tokens to generate. Values <= 0
            return ``""`` without running the model.

    Returns:
        The decoded answer with special tokens skipped and surrounding
        whitespace stripped. Generation stops before any end-of-sequence token
        (see ``_stop_token_ids``), which is not included in the output.

    Raises:
        ValueError: If ``blocks`` is empty or tokenizes to zero tokens.
    """
    if not blocks:
        raise ValueError("blocks must contain at least one block")
    if max_new_tokens <= 0:
        return ""

    input_ids, attn_mask = _prepare_block_inputs(model, tokenizer, blocks)
    stop_ids = _stop_token_ids(model, tokenizer)

    outputs = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=True)
    generated = []
    while True:
        # Greedy pick; keep it as a (1, 1) tensor so it can be fed straight back in.
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
        token_id = next_token.item()
        if token_id in stop_ids:
            break
        generated.append(token_id)
        # Checking the budget here avoids one wasted forward pass after the last token.
        if len(generated) >= max_new_tokens:
            break
        outputs = model(input_ids=next_token, past_key_values=outputs.past_key_values, use_cache=True)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# Demo run: greedily answer the example block prompt and show the result.
answer = block_generate(
    model,
    tokenizer,
    blocks,
    max_new_tokens=128,
)
print(answer)