import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Demo: translate a natural-language question into SQL using
# Phi-3-mini-4k-instruct (4-bit NF4 quantized) with a PEFT adapter on top.

base_model = "microsoft/Phi-3-mini-4k-instruct"
adapter = "Sid9797/querycraft-phi3-sql"

# Instruction block of the prompt. Kept byte-identical to the original inline
# prompt because the adapter was presumably fine-tuned on exactly this layout
# (TODO confirm against the training data format).
SYSTEM_PROMPT = (
    "You are a SQL expert. Given a database schema and a natural language "
    "question, generate a valid SQL query that answers the question. "
    "Output only the SQL query with no explanation."
)

EXAMPLE_SCHEMA = (
    "CREATE TABLE employees (id INTEGER, name VARCHAR, "
    "department VARCHAR, salary FLOAT)"
)
EXAMPLE_QUESTION = "What is the average salary by department?"


def build_prompt(schema: str, question: str) -> str:
    """Return the ``### System/Schema/Question/SQL`` prompt for one request.

    Args:
        schema: DDL text describing the tables the query may use.
        question: Natural-language question to answer.

    Returns:
        The prompt string. It ends with ``"### SQL:\\n"`` so that the model's
        continuation is the query itself.
    """
    return (
        "### System:\n"
        f"{SYSTEM_PROMPT}\n"
        "### Schema:\n"
        f"{schema}\n"
        "### Question:\n"
        f"{question}\n"
        "### SQL:\n"
    )


def extract_sql(generated: str) -> str:
    """Return the first line of the model's completion, whitespace-stripped.

    Only the first line is kept, presumably to discard anything the model
    emits after the query (e.g. a new ``###`` section). ``splitlines`` plus
    the final ``strip`` also drops a trailing ``\\r`` from CRLF output.

    Args:
        generated: Decoded completion text (prompt tokens already removed).

    Returns:
        The extracted query, or ``""`` if the completion is empty/blank.

    NOTE: a query the model spreads over several lines is truncated to its
    first line; this preserves the original script's heuristic.
    """
    lines = generated.strip().splitlines()
    return lines[0].strip() if lines else ""


def load_model(base_model_id=base_model, adapter_id=adapter, device_map="cuda:0"):
    """Load the 4-bit quantized base model, attach the PEFT adapter.

    Args:
        base_model_id: Hugging Face Hub id of the base causal LM.
        adapter_id: Hub id of the PEFT adapter loaded on top of the base model.
        device_map: Forwarded to ``from_pretrained``; defaults to the first GPU.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # NOTE(review): trust_remote_code executes Python shipped in the model
    # repo; it may be unnecessary on transformers versions with native Phi-3
    # support — confirm before keeping it.
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map=device_map,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return tokenizer, model


def generate_sql(tokenizer, model, schema, question, max_new_tokens=128):
    """Generate a SQL query for *question* over *schema* with greedy decoding.

    Args:
        tokenizer: Tokenizer returned by :func:`load_model`.
        model: Model returned by :func:`load_model`.
        schema: DDL text for the available tables.
        question: Natural-language question.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The extracted SQL string (see :func:`extract_sql`).
    """
    prompt = build_prompt(schema, question)
    # Follow the model's actual placement instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
    # generate() returns prompt + completion; decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return extract_sql(completion)


def main():
    """Run the built-in example question and print the generated SQL."""
    tokenizer, model = load_model()
    print(generate_sql(tokenizer, model, EXAMPLE_SCHEMA, EXAMPLE_QUESTION))


if __name__ == "__main__":
    main()