import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Hugging Face Hub identifiers: the base checkpoint and the LoRA adapter
# that is applied on top of it.
BASE = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER = "ArsenyIvanov/toolace-halu-qwen-lora"

# The tokenizer is loaded from the adapter repo, not the base repo.
# Left padding is the usual choice for decoder-only generation; it only
# matters if prompts are ever padded into a batch.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
tokenizer.padding_side = "left"

# Base weights in bfloat16 on the GPU, in inference (eval) mode.
# NOTE(review): "eager" attention is slower than the default kernels;
# presumably chosen deliberately — confirm before changing.
base = AutoModelForCausalLM.from_pretrained(
    BASE,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
base = base.to("cuda").eval()

# Attach the LoRA adapter weights to the base model for inference.
model = PeftModel.from_pretrained(base, ADAPTER)
model = model.to("cuda").eval()
# System prompt for the detector. It asks the model to echo the assistant
# answer verbatim, wrapping each hallucinated span in one of three
# <halu type="..."> tags. Each fragment carries its own trailing space, so
# joining with "" yields one continuous paragraph.
SYSTEM = "".join(
    (
        "You are a hallucination detector for tool-augmented dialogues. ",
        "Given the tool context, the available tools, the user query and the assistant answer, ",
        "rewrite the assistant answer wrapping every hallucinated span in ",
        '<halu type="contradiction">...</halu>, <halu type="missing_tool">...</halu> ',
        'or <halu type="overgeneration">...</halu> tags. ',
        "Do not alter any other characters. If the answer contains no hallucinations, return it unchanged.",
    )
)
# Matches one tagged span in the model's rewrite; group 1 is the label,
# group 2 the tagged text (non-greedy, may span newlines).
_HALU_RE = re.compile(
    r'<halu type="(contradiction|missing_tool|overgeneration)">(.+?)</halu>',
    re.DOTALL,
)


def _build_prompt(query, tool_context, tool_names, answer):
    """Render the chat-templated prompt string for one detection request."""
    user = (
        f"[Tool context]\n{tool_context}\n\n"
        f"[Available tools]\n{', '.join(tool_names)}\n\n"
        f"[User query]\n{query}\n\n"
        f"[Assistant answer]\n{answer}\n\n"
        "Now rewrite the assistant answer above with <halu> markers around hallucinated spans."
    )
    msgs = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": user},
    ]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)


def _fit_prompt(query, tool_context, tool_names, answer, max_length):
    """Build the prompt, trimming the tail of ``tool_context`` until it fits ``max_length`` tokens.

    Truncating the *rendered* prompt from the right (the tokenizer default)
    would drop the assistant answer and the assistant-turn header that
    generation must continue from. Instead, only the tool context is
    shortened. If the prompt is still too long with an empty context, it is
    returned as-is rather than being cut.
    """
    prompt = _build_prompt(query, tool_context, tool_names, answer)
    n_tokens = len(tokenizer(prompt)["input_ids"])
    if n_tokens <= max_length:
        return prompt

    ctx_ids = tokenizer(tool_context, add_special_tokens=False)["input_ids"]
    # Each pass removes at least one context token, so the loop terminates.
    # Re-measure after each pass because decoding and re-tokenizing can
    # shift token boundaries slightly.
    while n_tokens > max_length and ctx_ids:
        overflow = n_tokens - max_length
        ctx_ids = ctx_ids[: max(len(ctx_ids) - overflow, 0)]
        prompt = _build_prompt(query, tokenizer.decode(ctx_ids), tool_names, answer)
        n_tokens = len(tokenizer(prompt)["input_ids"])
    return prompt


def _extract_spans(completion, answer):
    """Map ``<halu>`` tags in ``completion`` to character spans in ``answer``.

    The model is asked to reproduce ``answer`` verbatim plus tags. The
    primary offset is therefore the tag's position in the completion with
    all earlier markup removed. This locates the correct occurrence even
    when the tagged text appears several times in ``answer``. If that
    alignment fails (the model altered other text), the code falls back to
    searching forward from the previous span, then from the start. Tags
    whose text is not found in ``answer`` are dropped.
    """
    spans = []
    cursor = 0   # end of the furthest span located so far
    removed = 0  # characters of markup stripped before the current match
    for m in _HALU_RE.finditer(completion):
        label, inner = m.group(1), m.group(2)
        aligned = m.start() - removed
        removed += len(m.group(0)) - len(inner)

        if answer.startswith(inner, aligned):
            idx = aligned
        else:
            idx = answer.find(inner, cursor)
            if idx == -1:
                idx = answer.find(inner)
            if idx == -1:
                continue

        end = idx + len(inner)
        spans.append({"start": idx, "end": end, "text": inner, "label": label})
        cursor = max(cursor, end)  # never move the search cursor backwards
    return spans


def detect(query, tool_context, tool_names, answer, *, max_length=1536, max_new_tokens=512):
    """Mark hallucinated spans in ``answer`` using the LoRA-tuned detector.

    Args:
        query: The user query shown to the model.
        tool_context: Tool context text. This is the part trimmed when the
            prompt would exceed ``max_length`` tokens.
        tool_names: Iterable of tool-name strings, joined with ", ".
        answer: The assistant answer to check. Span offsets index into it.
        max_length: Token budget for the prompt (default: 1536).
        max_new_tokens: Generation budget for the rewritten answer
            (default: 512). Tags past this budget are lost.

    Returns:
        A dict with two keys:

        - "marked": the decoded model completion, i.e. the answer with
          ``<halu>`` tags.
        - "spans": a list of ``{"start", "end", "text", "label"}`` dicts.
          ``start``/``end`` are character offsets into ``answer`` (end
          exclusive).
    """
    # Materialise once: the prompt may be rendered several times while fitting.
    tool_names = list(tool_names)
    prompt = _fit_prompt(query, tool_context, tool_names, answer, max_length)
    enc = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        gen = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic rewrite
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(gen[0, enc["input_ids"].shape[1]:], skip_special_tokens=True)
    return {"marked": completion, "spans": _extract_spans(completion, answer)}