!pip -q install -U transformers peft accelerate huggingface_hub
!pip -q uninstall -y torchvision bitsandbytes torchao
import inspect, json, os, sys
from pathlib import Path
from threading import Thread
import torch
from huggingface_hub import snapshot_download
# Adapter source and local destination; each can be overridden through an
# environment variable of the same (or similar) name.
ADAPTER_ID = os.environ.get("ADAPTER_ID", "ehzawad/ec-SFT-qwen25-7b-lora")
ADAPTER_REV = os.environ.get("ADAPTER_REVISION")  # None lets snapshot_download pick its default
ADAPTER_DIR = Path(os.environ.get("ADAPTER_DIR", "/content/adapter"))
ADAPTER_DIR.mkdir(parents=True, exist_ok=True)

# Mirror the adapter repo into ADAPTER_DIR.
snapshot_download(
    repo_id=ADAPTER_ID,
    revision=ADAPTER_REV,
    local_dir=str(ADAPTER_DIR),
)

# Files the loading code later reads from ADAPTER_DIR; fail fast if any is absent.
required = [
    "adapter_config.json",
    "adapter_model.safetensors",
    "system_prompt.txt",
    "tokenizer.json",
    "tokenizer_config.json",
]
missing = [name for name in required if not ADAPTER_DIR.joinpath(name).is_file()]
assert not missing, f"adapter dir {ADAPTER_DIR} missing files: {missing}"
# Set FORCE_RELOAD = True to discard a model already loaded in this session and load again.
FORCE_RELOAD = False
if (not FORCE_RELOAD) and all(n in globals() for n in ("model", "tokenizer", "SYSTEM_PROMPT")):
    print("OK reusing already-loaded model (set FORCE_RELOAD=True to refresh)")
else:
    import gc
    import re

    import transformers

    # Drop references to a previously loaded model before loading a new one.
    # Otherwise, on FORCE_RELOAD, the old and new weights are both resident on
    # the GPU while from_pretrained runs, which can run out of memory.
    for _name in ("model", "base"):
        globals().pop(_name, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Report torchvision / torchao / bitsandbytes as unavailable to transformers
    # and peft (they are uninstalled at the top of this file).
    import transformers.utils.import_utils as _iu
    _iu.is_torchvision_available = _iu.is_torchvision_v2_available = (lambda: False)
    if hasattr(_iu, "is_torchao_available"):
        _iu.is_torchao_available = (lambda: False)

    import peft.import_utils as _piu
    _PATCHED_CHECKS = ("is_bnb_available", "is_bnb_4bit_available", "is_torchao_available")
    for _n in _PATCHED_CHECKS:
        if hasattr(_piu, _n):
            setattr(_piu, _n, lambda *a, **k: False)
    # Modules that did `from peft.import_utils import is_..._available` keep
    # their own binding, so the already-imported peft submodules are patched too.
    for _m in list(sys.modules.values()):
        if getattr(_m, "__name__", "").startswith("peft."):
            for _n in _PATCHED_CHECKS:
                if hasattr(_m, _n):
                    setattr(_m, _n, lambda *a, **k: False)

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    assert torch.cuda.is_available(), "No GPU found. In Colab: Runtime > Change runtime type > GPU."
    print("GPU:", torch.cuda.get_device_name(0))

    # Base model id: taken from the adapter's training_args.json when present.
    _train_args = ADAPTER_DIR / "training_args.json"
    BASE_MODEL = (
        json.loads(_train_args.read_text(encoding="utf-8")).get("base_model", "Qwen/Qwen2.5-7B-Instruct")
        if _train_args.is_file() else "Qwen/Qwen2.5-7B-Instruct"
    )
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    # Tokenizer (and chat template) come from the adapter repo, not the base model.
    tokenizer = AutoTokenizer.from_pretrained(str(ADAPTER_DIR))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.vocab_size > 100000, (
        f"tokenizer load looks degenerate (vocab_size={tokenizer.vocab_size}); "
        f"tokenizer.json should ship in the adapter repo")
    assert tokenizer.chat_template, "no chat_template loaded; expected chat_template.jinja in adapter dir"

    # Newer transformers (4.56+) take `dtype=` and deprecate `torch_dtype=`.
    # AutoModelForCausalLM.from_pretrained accepts both only through **kwargs,
    # so a signature check alone never sees "dtype"; fall back to the version.
    _ver = re.match(r"(\d+)\.(\d+)", transformers.__version__)
    _new_dtype_api = (
        "dtype" in inspect.signature(AutoModelForCausalLM.from_pretrained).parameters
        or (_ver is not None and (int(_ver[1]), int(_ver[2])) >= (4, 56))
    )
    _dtype_kw = "dtype" if _new_dtype_api else "torch_dtype"

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, **{_dtype_kw: DTYPE}, device_map={"": 0}, attn_implementation="sdpa")
    model = PeftModel.from_pretrained(base, str(ADAPTER_DIR)).eval()
    model.config.use_cache = True
    SYSTEM_PROMPT = (ADAPTER_DIR / "system_prompt.txt").read_text(encoding="utf-8")
from transformers import TextIteratorStreamer
def answer(question: str, max_new_tokens: int = 1024) -> str:
    """Answer one question (no chat history) and stream the reply to stdout.

    Renders a [system, user] conversation using ``SYSTEM_PROMPT`` and the
    tokenizer's chat template, runs greedy ``model.generate`` in a background
    thread, and prints text chunks as the ``TextIteratorStreamer`` yields them.

    Args:
        question: the user message.
        max_new_tokens: generation budget forwarded to ``model.generate``.

    Returns:
        The streamed text joined and stripped of surrounding whitespace. It
        may be partial, or "" if nothing was produced. Errors from the
        generation thread or the streamer are printed rather than raised.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    rendered = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    )
    # Depending on the transformers version this is a tensor, a BatchEncoding,
    # or a (nested) list of ids; normalise it to a (1, seq_len) id tensor.
    input_ids = getattr(rendered, "input_ids", rendered)
    if isinstance(input_ids, list):
        flat = not input_ids or isinstance(input_ids[0], int)
        input_ids = torch.tensor([input_ids] if flat else input_ids, dtype=torch.long)
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    input_ids = input_ids.to(next(model.parameters()).device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
    )
    pad_id = tokenizer.pad_token_id
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),  # single, unpadded sequence
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        # `is None` rather than `or`: token id 0 is a legitimate pad id.
        pad_token_id=pad_id if pad_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        streamer=streamer,
    )
    gen_error = {}

    def _gen():
        try:
            with torch.inference_mode():
                model.generate(**gen_kwargs)
        except BaseException as e:
            gen_error["exc"] = e
            # generate() ends the stream itself only on success. Send the
            # end-of-stream signal here so the consumer loop below stops now
            # instead of blocking until the 300 s streamer timeout.
            streamer.end()

    th = Thread(target=_gen)
    th.start()
    chunks = []
    try:
        for chunk in streamer:
            print(chunk, end="", flush=True)
            chunks.append(chunk)
    except Exception as e:  # e.g. queue.Empty when the streamer times out
        print(f"\n[streamer error: {type(e).__name__}: {e}]")
    th.join()
    if gen_error:
        e = gen_error["exc"]
        print(f"\n[generation thread raised {type(e).__name__}: {e}]")
    reply = "".join(chunks).strip()
    if not reply:
        print("[WARN: 0 chars generated — check tokenizer vocab_size and chat template]")
    return reply
# Mutable state for the interactive chat loop below.
history: list = []      # prior turns as {"role", "content"} dicts, appended user/assistant
max_new_tokens = 1024   # per-turn generation budget; changed with `/tokens N`
# Context-window length from the model config, 32768 if the config lacks the field.
MAX_POS = getattr(model.config, "max_position_embeddings", 32768)
def _to_ids(enc):
if hasattr(enc, "input_ids"): enc = enc.input_ids
if isinstance(enc, list):
enc = torch.tensor([enc] if not enc or isinstance(enc[0], int) else enc, dtype=torch.long)
if enc.dim() == 1: enc = enc.unsqueeze(0)
return enc
def _repl_generate(gen_kwargs, streamer, errors):
    """Thread target for the chat REPL: run greedy generation into *streamer*.

    Any exception is stored as ``errors["exc"]`` for the main thread to
    report, and the streamer is ended so the consumer loop stops right away
    instead of blocking until the streamer's timeout.
    """
    try:
        with torch.inference_mode():
            model.generate(**gen_kwargs)
    except BaseException as e:
        errors["exc"] = e
        streamer.end()


# Lower bound on the per-turn generation budget, and positions kept free at
# the end of the context window, used when capping max_new_tokens below.
_REPL_MIN_NEW_TOKENS = 64
_REPL_CTX_MARGIN = 32

# Interactive chat loop. Commands: /exit or /quit, /reset (clear history),
# /tokens N (set per-turn max_new_tokens). EOF or Ctrl-C at the prompt exits.
while True:
    try:
        q = input("USER> ").strip()
    except (EOFError, KeyboardInterrupt):
        print()
        break
    if not q:
        continue
    if q in ("/exit", "/quit"):
        break
    if q == "/reset":
        history.clear()
        print("[cleared]")
        continue
    parts = q.split()
    if parts[0] == "/tokens":
        # isdecimal(), not isdigit(): "²".isdigit() is True but int("²") raises.
        if len(parts) == 2 and parts[1].isdecimal() and int(parts[1]) > 0:
            max_new_tokens = int(parts[1])
            print(f"[max_new_tokens={max_new_tokens}]")
        else:
            print("[usage: /tokens 1024]")
        continue

    # Render system + history + new user turn. While the prompt leaves fewer
    # than _REPL_MIN_NEW_TOKENS free positions, drop the oldest user/assistant
    # pair so long conversations keep fitting in the context window.
    while True:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history,
                    {"role": "user", "content": q}]
        ids = _to_ids(tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"))
        room = MAX_POS - ids.shape[1] - _REPL_CTX_MARGIN
        if room >= _REPL_MIN_NEW_TOKENS or not history:
            break
        del history[:2]
        print("[context full: dropped oldest exchange from history]")
    if room < _REPL_MIN_NEW_TOKENS:
        print("[WARN: prompt alone nearly fills the context window]")
    ids = ids.to(next(model.parameters()).device)
    cap = min(max_new_tokens, max(_REPL_MIN_NEW_TOKENS, room))

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0)
    pad_id = tokenizer.pad_token_id
    gen_kwargs = dict(
        input_ids=ids,
        attention_mask=torch.ones_like(ids),  # single, unpadded sequence
        max_new_tokens=cap,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id if pad_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        streamer=streamer,
    )
    gen_error = {}
    thread = Thread(target=_repl_generate, args=(gen_kwargs, streamer, gen_error))
    thread.start()
    print("BOT > ", end="", flush=True)
    chunks = []
    try:
        for chunk in streamer:
            print(chunk, end="", flush=True)
            chunks.append(chunk)
    except Exception as e:  # e.g. queue.Empty when the streamer times out
        print(f"\n[streamer error: {type(e).__name__}: {e}]")
    thread.join()
    print()
    if gen_error:
        e = gen_error["exc"]
        print(f"[generation thread raised {type(e).__name__}: {e}]")
    reply = "".join(chunks).strip()
    if not reply:
        print("[WARN: 0 chars generated — check tokenizer vocab_size and chat template]")
    if gen_error or not reply:
        # Keep failed/empty turns out of the history so they do not pollute
        # the context of later turns.
        print("[turn not added to history]")
        continue
    history.extend([{"role": "user", "content": q}, {"role": "assistant", "content": reply}])