Run this model inference on single tenant GPU with unmatched speed and reliability at scale.
Run this model inference with full control and performance in your environment.
Get help setting up a custom Dedicated Endpoints.
Talk with our engineer to get a quote for reserved GPU instances with discounts.
README
License: apache-2.0Pipeline การสร้างโมเดล (ทำซ้ำได้)
ขั้นที่ 1 — Depth Pruning (Layer Dropping)
ตัด decoder layer ตรงกลางทิ้ง (มักทำงานซ้ำซ้อน) เก็บเฉพาะ layer หัว (เข้าใจ input) และ layer ท้าย (สร้าง output) — embedding / lm_head / norm คงเดิม จึงไม่พัง dimension
python
import torchfrom transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)# หา text-decoder layer list (เลี่ยง vision encoder กรณี multimodal)holder, layers = None, Nonefor _, mod in model.named_modules():L = getattr(mod, "layers", None)if isinstance(L, torch.nn.ModuleList) and len(L) and hasattr(L[0], "self_attn"):holder, layers = mod, Lif "language" in _.lower() or "text" in _.lower():breakN = len(layers)# เก็บ 18 จาก 36 layers: หัว + ท้ายkeep = [0,1,2,3,4,5,6,7,8, 27,28,29,30,31,32,33,34,35]holder.layers = torch.nn.ModuleList([layers[i] for i in keep])# อัปเดต config (รองรับ nested text_config ของ Gemma3)for c in {model.config, getattr(model.config, "text_config", model.config)}:if getattr(c, "num_hidden_layers", None) is not None:c.num_hidden_layers = len(keep)lt = getattr(c, "layer_types", None)if isinstance(lt, list) and len(lt) == N:c.layer_types = [lt[i] for i in keep]# reindex layer_idx ของแต่ละ block (สำคัญต่อ KV cache)for i, lyr in enumerate(holder.layers):if hasattr(lyr, "self_attn") and hasattr(lyr.self_attn, "layer_idx"):lyr.self_attn.layer_idx = i
ผลลัพธ์: 3.09B -> 1.70B (ยังไม่ถึง 1B เป๊ะ เพราะ embedding+lm_head+vocab ไม่ลดตาม layer)
หลัง prune โมเดลจะพ่น gibberish ทันที (เส้นประสาทถูกตัดขาด) -> ต้อง Healing ต่อ
ขั้นที่ 2 — Healing SFT
เทรนต่อด้วย causal-LM บน Thai corpus เพื่อให้ layer ที่เหลือกลับมาทำงานร่วมกัน
python
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModelingfrom datasets import load_datasetds = load_dataset("SPAISS6F1/spai-ss6-llm-1b-thai-corpus", split="train") # หรือ SEA-PILE v2 'th'tds = ds.map(lambda e: tok(e["text"], truncation=True, max_length=1024),batched=True, remove_columns=ds.column_names)model.gradient_checkpointing_enable(); model.config.use_cache = Falseargs = TrainingArguments(output_dir="out", num_train_epochs=2,per_device_train_batch_size=4, gradient_accumulation_steps=4,learning_rate=1e-4, lr_scheduler_type="cosine", warmup_ratio=0.03, bf16=True)Trainer(model=model, args=args, train_dataset=tds,data_collator=DataCollatorForLanguageModeling(tok, mlm=False)).train()
Hyperparameters:
- Learning rate:
1e-4(สูงกว่าปกติเพื่อสมานแผล) | Epochs: 2 - Batch 4 x grad-accum 4 (effective 16) | max_len 1024 | bf16
- Optimizer: AdamW + cosine schedule, warmup 3%
- Env: conda myenv (transformers 4.49)
ขั้นที่ 3 — Save
python
model.config.use_cache = Truetry:model.save_pretrained("out", safe_serialization=True)except RuntimeError: # Gemma3: tied embeddings -> fallback .binmodel.save_pretrained("out", safe_serialization=False)
โมเดลนี้ save เป็น: model.safetensors
วิธีใช้ (Inference)
python
import torchfrom transformers import AutoModelForCausalLM, AutoTokenizerm = "SPAISS6F1/qwen-1b-pruned-th"tok = AutoTokenizer.from_pretrained(m)model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16, device_map="cuda")ids = tok("ปัญญาประดิษฐ์ คือ", return_tensors="pt").to(model.device)out = model.generate(**ids, max_new_tokens=120, do_sample=True,temperature=0.7, top_p=0.9, repetition_penalty=1.3)print(tok.decode(out[0], skip_special_tokens=True))
ข้อควรรู้ / ข้อจำกัด
- เป็น pruned base ที่ heal ด้วย raw web corpus -> ไวยากรณ์ไทยลื่นไหลดี แต่ ข้อเท็จจริงและการคิดเลขยังอ่อน (ยังไม่ผ่าน instruction tuning)
- แนะนำ
repetition_penalty >= 1.2กันการวนซ้ำ - เหมาะเป็น base สำหรับ fine-tune ต่อด้วย instruction dataset มากกว่าใช้ตอบตรง ๆ
- การตัด layer 50% เป็นการตัดที่ค่อนข้างหนัก (งานวิจัย เช่น ShortGPT แนะ ~25%); ถ้าต้องการคุณภาพสูงขึ้นควร heal นานขึ้น/ตัดเบาลง
Model provider
SPAISS6F1
Model tree
Base
this model
Modalities
Input
Text
Output
Text
Pricing
Dedicated Endpoints
View detailsSupported Functionality
Model APIs
Dedicated Endpoints
Container
More information