# Install dependencies: pip install torch transformers pillow
# Minimal Captioning Example
import os
import torch
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
# =========================
# Config
# =========================
# Model identifier handed to both from_pretrained() calls below.
MODEL_PATH: str = "Felldude/Ministral-3-8B-Uncensored"

# Directory scanned for images; each caption is written next to its image.
FOLDER: str = "images"

# Text instruction sent to the model together with every image.
PROMPT: str = "Describe this image in detail."

# Upper bound on the number of newly generated tokens per caption.
MAX_TOKENS: int = 512

# File extensions (compared in lower case) that are treated as images.
VALID_EXTS: set[str] = {
    ".jpeg",
    ".jpg",
    ".png",
    ".webp",
}
# =========================
# GPU setup
# =========================
# This script only targets CUDA: stop immediately when PyTorch sees no GPU.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA GPU required")

device = "cuda"

# Use bfloat16 when PyTorch reports support for it, otherwise float16.
if torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float16
# =========================
# Load model
# =========================
# The processor is used later for chat templating and for decoding tokens.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Load the weights in the dtype selected above, using the SDPA attention
# implementation, then move the whole model onto the GPU.
model = Mistral3ForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    attn_implementation="sdpa",
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.to(device)

# Evaluation mode: this script only runs generation, never training.
model.eval()
# =========================
# Caption function
# =========================
def generate_caption(image):
    """Generate a caption for one image with the module-level model.

    Builds a single user message containing the image and ``PROMPT``, runs
    ``model.generate`` with sampling disabled for at most ``MAX_TOKENS`` new
    tokens, and decodes only the tokens that follow the prompt.

    Args:
        image: The image to describe. The caller in this script passes an
            RGB ``PIL.Image.Image``.

    Returns:
        The decoded caption with special tokens skipped and surrounding
        whitespace stripped.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move every tensor to the GPU. Floating-point tensors (presumably the
    # float32 pixel values - confirm against the processor) are additionally
    # cast to the model's dtype, mirroring the documented Mistral3 usage
    # `.to(device, dtype=...)`; feeding them unchanged to bf16/fp16 weights
    # may raise a dtype-mismatch error. Integer tensors keep their dtype, and
    # non-tensor entries are passed through instead of crashing on `.to`.
    model_dtype = model.dtype
    prepared = {}
    for name, value in inputs.items():
        if not torch.is_tensor(value):
            prepared[name] = value
        elif torch.is_floating_point(value):
            prepared[name] = value.to(device, dtype=model_dtype)
        else:
            prepared[name] = value.to(device)

    with torch.inference_mode():
        output = model.generate(
            **prepared,
            max_new_tokens=MAX_TOKENS,
            do_sample=False,
        )

    # Each output row starts with the prompt tokens; slice them off so only
    # the newly generated continuation is decoded.
    prompt_len = prepared["input_ids"].shape[1]
    new_tokens = output[:, prompt_len:]

    decoded = processor.batch_decode(new_tokens, skip_special_tokens=True)
    return decoded[0].strip()
# =========================
# Process folder
# =========================
# Caption every image in FOLDER and write the result to "<image stem>.txt"
# next to it. os.listdir() returns entries in arbitrary order, so sort them
# to make the processing order (and the console log) deterministic.
for filename in sorted(os.listdir(FOLDER)):
    ext = os.path.splitext(filename)[1].lower()
    if ext not in VALID_EXTS:
        continue

    path = os.path.join(FOLDER, filename)
    print("Processing:", filename)

    # Best-effort batch: a file that cannot be read or captioned is reported
    # and skipped so the remaining images are still processed.
    try:
        # Image.open() is lazy and keeps the file open; the context manager
        # guarantees the handle is released. convert() returns a separate
        # in-memory image, so it stays usable after the file is closed.
        with Image.open(path) as source:
            image = source.convert("RGB")

        caption = generate_caption(image)

        # NOTE: images that share a stem (e.g. a.png and a.jpg) map to the
        # same .txt file; the one processed last overwrites the other.
        txt_path = os.path.splitext(path)[0] + ".txt"
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(caption)

        print(caption)
    except Exception as e:
        print("Failed:", e)