Dedicated Endpoints

Run this model inference on single tenant GPU with unmatched speed and reliability at scale.

Learn more
Container

Run this model inference with full control and performance in your environment.

Learn more

Get help setting up a custom Dedicated Endpoints.

Talk with our engineer to get a quote for reserved GPU instances with discounts.

README

License: apache-2.0

Pipeline การสร้างโมเดล (ทำซ้ำได้)

ขั้นที่ 1 — Depth Pruning (Layer Dropping)

ตัด decoder layer ตรงกลางทิ้ง (มักทำงานซ้ำซ้อน) เก็บเฉพาะ layer หัว (เข้าใจ input) และ layer ท้าย (สร้าง output) — embedding / lm_head / norm คงเดิม จึงไม่พัง dimension

python

import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)
# หา text-decoder layer list (เลี่ยง vision encoder กรณี multimodal)
holder, layers = None, None
for _, mod in model.named_modules():
L = getattr(mod, "layers", None)
if isinstance(L, torch.nn.ModuleList) and len(L) and hasattr(L[0], "self_attn"):
holder, layers = mod, L
if "language" in _.lower() or "text" in _.lower():
break
N = len(layers)
# เก็บ 18 จาก 36 layers: หัว + ท้าย
keep = [0,1,2,3,4,5,6,7,8, 27,28,29,30,31,32,33,34,35]
holder.layers = torch.nn.ModuleList([layers[i] for i in keep])
# อัปเดต config (รองรับ nested text_config ของ Gemma3)
for c in {model.config, getattr(model.config, "text_config", model.config)}:
if getattr(c, "num_hidden_layers", None) is not None:
c.num_hidden_layers = len(keep)
lt = getattr(c, "layer_types", None)
if isinstance(lt, list) and len(lt) == N:
c.layer_types = [lt[i] for i in keep]
# reindex layer_idx ของแต่ละ block (สำคัญต่อ KV cache)
for i, lyr in enumerate(holder.layers):
if hasattr(lyr, "self_attn") and hasattr(lyr.self_attn, "layer_idx"):
lyr.self_attn.layer_idx = i

ผลลัพธ์: 3.09B -> 1.70B (ยังไม่ถึง 1B เป๊ะ เพราะ embedding+lm_head+vocab ไม่ลดตาม layer)

หลัง prune โมเดลจะพ่น gibberish ทันที (เส้นประสาทถูกตัดขาด) -> ต้อง Healing ต่อ

ขั้นที่ 2 — Healing SFT

เทรนต่อด้วย causal-LM บน Thai corpus เพื่อให้ layer ที่เหลือกลับมาทำงานร่วมกัน

python

from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
ds = load_dataset("SPAISS6F1/spai-ss6-llm-1b-thai-corpus", split="train") # หรือ SEA-PILE v2 'th'
tds = ds.map(lambda e: tok(e["text"], truncation=True, max_length=1024),
batched=True, remove_columns=ds.column_names)
model.gradient_checkpointing_enable(); model.config.use_cache = False
args = TrainingArguments(output_dir="out", num_train_epochs=2,
per_device_train_batch_size=4, gradient_accumulation_steps=4,
learning_rate=1e-4, lr_scheduler_type="cosine", warmup_ratio=0.03, bf16=True)
Trainer(model=model, args=args, train_dataset=tds,
data_collator=DataCollatorForLanguageModeling(tok, mlm=False)).train()

Hyperparameters:

  • Learning rate: 1e-4 (สูงกว่าปกติเพื่อสมานแผล) | Epochs: 2
  • Batch 4 x grad-accum 4 (effective 16) | max_len 1024 | bf16
  • Optimizer: AdamW + cosine schedule, warmup 3%
  • Env: conda myenv (transformers 4.49)

ขั้นที่ 3 — Save

python

model.config.use_cache = True
try:
model.save_pretrained("out", safe_serialization=True)
except RuntimeError: # Gemma3: tied embeddings -> fallback .bin
model.save_pretrained("out", safe_serialization=False)

โมเดลนี้ save เป็น: model.safetensors


วิธีใช้ (Inference)

python

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
m = "SPAISS6F1/qwen-1b-pruned-th"
tok = AutoTokenizer.from_pretrained(m)
model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16, device_map="cuda")
ids = tok("ปัญญาประดิษฐ์ คือ", return_tensors="pt").to(model.device)
out = model.generate(**ids, max_new_tokens=120, do_sample=True,
temperature=0.7, top_p=0.9, repetition_penalty=1.3)
print(tok.decode(out[0], skip_special_tokens=True))

ข้อควรรู้ / ข้อจำกัด

  • เป็น pruned base ที่ heal ด้วย raw web corpus -> ไวยากรณ์ไทยลื่นไหลดี แต่ ข้อเท็จจริงและการคิดเลขยังอ่อน (ยังไม่ผ่าน instruction tuning)
  • แนะนำ repetition_penalty >= 1.2 กันการวนซ้ำ
  • เหมาะเป็น base สำหรับ fine-tune ต่อด้วย instruction dataset มากกว่าใช้ตอบตรง ๆ
  • การตัด layer 50% เป็นการตัดที่ค่อนข้างหนัก (งานวิจัย เช่น ShortGPT แนะ ~25%); ถ้าต้องการคุณภาพสูงขึ้นควร heal นานขึ้น/ตัดเบาลง

Model provider

SPAISS6F1

Model tree

Base

this model

Modalities

Input

Text

Output

Text

Pricing

Dedicated Endpoints

View details

Supported Functionality

Model APIs

Dedicated Endpoints

Container

More information

Explore FriendliAI today