import json
import math

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
# Base checkpoint shared by the model and its tokenizer.
BASE_MODEL_ID = "OpenGVLab/InternVL3-1B"

# trust_remote_code=True lets transformers run the modeling/tokenizer code that
# ships in the Hub repository; only enable it for repositories you trust.
model = AutoModel.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,  # load weights directly in bfloat16
    trust_remote_code=True,
    device_map="auto",  # let accelerate place the weights on available devices
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
ADAPTER_REPO = "blind-assist/internvl3-1b-walk-lora-Epoch_1.85-8500-without_ES_v1"
# PEFT saves adapter tensors under this prefix; the base model's state_dict does not use it.
PEFT_KEY_PREFIX = "base_model.model."

adapter_path = hf_hub_download(ADAPTER_REPO, "adapter_model.safetensors")
adapter_weights = load_file(adapter_path)

# PEFT scales every LoRA update by lora_alpha / r (lora_alpha / sqrt(r) with
# rsLoRA). A hard-coded 1.0 is only correct when lora_alpha == r, so read the
# real hyper-parameters from the adapter's own config.
with open(hf_hub_download(ADAPTER_REPO, "adapter_config.json"), encoding="utf-8") as config_file:
    adapter_config = json.load(config_file)

# These variants need a different merge formula; fail loudly instead of
# silently producing wrong weights.
if adapter_config.get("use_dora"):
    raise NotImplementedError(
        "DoRA adapters need a magnitude-aware merge; use PEFT's merge_and_unload() instead."
    )
if adapter_config.get("rank_pattern") or adapter_config.get("alpha_pattern"):
    raise NotImplementedError(
        "Per-module rank_pattern/alpha_pattern is not supported by this manual merge."
    )

lora_rank = adapter_config["r"]
lora_alpha = adapter_config["lora_alpha"]
if adapter_config.get("use_rslora"):
    scaling = lora_alpha / math.sqrt(lora_rank)
else:
    scaling = lora_alpha / lora_rank
# With fan_in_fan_out the base weight is stored as (in, out), so the (out, in)
# product lora_B @ lora_A must be transposed before it is added.
fan_in_fan_out = adapter_config.get("fan_in_fan_out", False)


def _to_model_key(adapter_key):
    """Map a PEFT adapter key to the matching key in ``model.state_dict()``."""
    if adapter_key.startswith(PEFT_KEY_PREFIX):
        return adapter_key[len(PEFT_KEY_PREFIX):]
    return adapter_key


# NOTE(review): with device_map="auto", weights offloaded to CPU/disk may appear
# here as meta tensors; this merge assumes every weight is materialized - confirm.
model_state = model.state_dict()
merged_count = 0
skipped_keys = []
for key, tensor in adapter_weights.items():
    if ".lora_B." in key:
        continue  # consumed together with its lora_A partner below
    if ".lora_A." in key:
        lora_b_key = key.replace(".lora_A.", ".lora_B.")
        model_key = _to_model_key(key.replace(".lora_A.", "."))
        if lora_b_key not in adapter_weights or model_key not in model_state:
            skipped_keys.append(key)
            continue
        base = model_state[model_key]
        lora_a = tensor.float().to(base.device)
        lora_b = adapter_weights[lora_b_key].float().to(base.device)
        delta = torch.matmul(lora_b, lora_a) * scaling
        if fan_in_fan_out:
            delta = delta.T
        if delta.shape != base.shape:
            raise ValueError(
                f"LoRA delta for {model_key} has shape {tuple(delta.shape)}, "
                f"expected {tuple(base.shape)}"
            )
        # Add in float32 for precision, then restore the weight's own dtype.
        model_state[model_key] = (base.float() + delta).to(base.dtype)
        merged_count += 1
    else:
        # Non-LoRA tensors (e.g. fully fine-tuned modules) replace the base
        # weight outright when the shapes agree.
        model_key = _to_model_key(key)
        base = model_state.get(model_key)
        if base is None or base.shape != tensor.shape:
            skipped_keys.append(key)
            continue
        model_state[model_key] = tensor.to(device=base.device, dtype=base.dtype)

# A key-prefix mismatch would otherwise leave the base model untouched without
# any signal that the fine-tune was never applied.
if merged_count == 0:
    raise RuntimeError(
        f"No LoRA pairs from {ADAPTER_REPO} matched the base model's parameter names; "
        "check the adapter key prefix mapping."
    )
if skipped_keys:
    print(f"Warning: {len(skipped_keys)} adapter tensors were not applied, e.g. {skipped_keys[:5]}")

model.load_state_dict(model_state)
model.eval()
prompt = "Given the visual input from the user's forward perspective, generate exactly one short sentence to guide a visually impaired user by identifying critical obstacles or landmarks, describing their locations using clock directions relative to the user (12 o'clock is straight ahead), including relevant details such as size, material, or distance, and giving one clear action, while prioritizing immediate safety and avoiding any extra explanation."


def generate_guidance(pixel_values, question=prompt, max_new_tokens=256):
    """Run one greedy (do_sample=False) chat turn on a preprocessed image.

    Args:
        pixel_values: image tensor in the layout InternVL's ``model.chat``
            expects (presumably the tiled output of the model card's
            ``load_image`` helper - confirm).
        question: text prompt sent alongside the image.
        max_new_tokens: upper bound on generated tokens.

    Returns:
        The response returned by ``model.chat``.
    """
    # The InternVL model card casts pixel_values to bfloat16 and moves them to
    # the GPU before calling model.chat; match the loaded model's dtype/device
    # here so a float32 or CPU tensor does not cause a mismatch.
    # NOTE(review): model.device is the first parameter's device; with a
    # multi-device map the vision tower may live elsewhere - confirm.
    pixel_values = pixel_values.to(device=model.device, dtype=model.dtype)
    return model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=question,
        generation_config=dict(max_new_tokens=max_new_tokens, do_sample=False),
    )


# `your_image_tensor` is a placeholder the user must define (e.g. via the
# preprocessing shown on the InternVL3 model card); fail with a clear message
# instead of a bare NameError.
if "your_image_tensor" not in globals():
    raise SystemExit(
        "Define `your_image_tensor` (a preprocessed InternVL image tensor) before running inference."
    )
response = generate_guidance(your_image_tensor)  # noqa: F821 - user-supplied placeholder
print(response)