
License: MIT

Architecture

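This adapter injects FiLM modulation layers into decoder layers 16 to 28 (inclusive, 13 layers) of Qwen3-4B. Each modulated layer has its own independent scale projection, computed from the latent state S_t maintained by PlaaCore.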

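```
hs = hs * (1 + alpha * tanh(W_l * S_t))
```

Here `hs` is the hidden-state output of modulated layer l, `W_l` is that layer's scale projection (`scale_proj`), `S_t` is the latent state, and `alpha` is a scaling coefficient. The outer product with `hs` is elementwise.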
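To make the modulation concrete, here is a minimal pure-Python sketch of the formula for a single hidden vector. It is not the `FiLMLayer` from `modeling_plaa.py`: the class name `ToyFiLM`, the list-based vectors, the example `alpha`, and treating `W_l * S_t` as a bias-free matrix-vector product are all illustrative assumptions. It only mirrors the pattern visible in the loading code, where the latent state is assigned to each wrapped layer's `_s` attribute before generation.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, s: Vector) -> Vector:
    """Bias-free matrix-vector product W_l * S_t (assumption: one row per hidden unit)."""
    return [sum(wij * sj for wij, sj in zip(row, s)) for row in w]


class ToyFiLM:
    """Hypothetical wrapper computing hs = hs * (1 + alpha * tanh(W_l * S_t)).

    `layer` stands in for a decoder layer (hidden vector -> hidden vector).
    `_s` holds the latent state S_t and must be set before calling,
    mirroring `layers[i]._s = S` in the loading example.
    """

    def __init__(self, layer: Callable[[Vector], Vector], w: Matrix, alpha: float):
        self.layer = layer
        self.w = w          # scale projection W_l: hidden_dim x state_dim
        self.alpha = alpha  # scaling coefficient
        self._s: Vector = []

    def __call__(self, x: Vector) -> Vector:
        hs = self.layer(x)
        gate = matvec(self.w, self._s)
        # Elementwise modulation; a zero projection leaves hs unchanged
        return [h * (1.0 + self.alpha * math.tanh(g)) for h, g in zip(hs, gate)]


if __name__ == "__main__":
    identity = lambda v: list(v)  # stand-in for a real decoder layer
    film = ToyFiLM(identity, w=[[1.0, 0.0], [0.0, -1.0], [0.0, 0.0]], alpha=0.5)
    film._s = [2.0, 2.0]
    print(film([1.0, 1.0, 1.0]))
```

Because tanh is bounded by ±1, the per-unit factor stays within (1 - alpha, 1 + alpha) for a positive alpha, so with alpha < 1 the modulation can rescale a hidden unit but never flip its sign or zero it.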

Contents

  • adapter_model.safetensors — PEFT LoRA adapter (Phase 2.5 persona alignment)
  • plaa_full.pt — PlaaCore GRU + FiLM scale_proj weights
  • modeling_plaa.py — FiLMLayer + PlaaCore definition
  • config.json — PEFT adapter config

Loading

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# modeling_plaa.py ships with this repo; run from the repo directory
# (or add it to sys.path) so this import resolves.
from modeling_plaa import FiLMLayer, PlaaCore

FILM_LAYERS = range(16, 29)  # decoder layers 16-28 (inclusive)

# Load the 4-bit NF4-quantized base model and its tokenizer
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B",
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

# Wrap decoder layers 16-28 in FiLM layers before attaching the adapter
for i in FILM_LAYERS:
    base.model.layers[i] = FiLMLayer(base.model.layers[i])

# Load the PEFT LoRA adapter from this repo
model = PeftModel.from_pretrained(base, "./", adapter_name="plaa")
model.set_adapter("plaa")
model.eval()

# Load the PlaaCore GRU and the per-layer FiLM scale_proj weights
ckpt = torch.load("./plaa_full.pt", map_location="cpu")
plaa_core = PlaaCore()
plaa_core.load_state_dict(ckpt["plaa_core"])
layers = model.base_model.model.model.layers
for i in FILM_LAYERS:
    layers[i].scale_proj.load_state_dict(ckpt["scale_proj"][i])
    layers[i].cuda()  # move the newly loaded FiLM weights to the GPU

# Inference: give every FiLM layer the initial latent state S_t
S = plaa_core.init_state(1)  # batch size 1
for i in FILM_LAYERS:
    layers[i]._s = S

inp = tokenizer(["Hello"], return_tensors="pt").to("cuda")
out = model.generate(**inp, max_new_tokens=50)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Results

| Condition | Pure LM loss (lower is better) |
|---|---|
| Vanilla Qwen3-4B | 3.53 |
| Trained mFiLM | 2.70 |
| FiLM removed | 2.74 |
| State frozen | 2.70 |

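Causal ablation: Δ = 0.044, consistent with the rounded gap between the "FiLM removed" and "Trained mFiLM" rows (2.74 - 2.70). See the paper for details.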
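The row-to-row gaps in the table can be rechecked with simple arithmetic. The sketch below recomputes them from the rounded table values; the perplexity column assumes "Pure LM loss" is a mean per-token cross-entropy in nats, which the table does not state. Because the inputs are rounded to two decimals, the recomputed FiLM-removal gap (0.04) only approximately matches the reported Δ = 0.044.

```python
import math

# Rounded values copied from the Results table
LOSSES = {
    "Vanilla Qwen3-4B": 3.53,
    "Trained mFiLM": 2.70,
    "FiLM removed": 2.74,
    "State frozen": 2.70,
}


def gap(condition: str, reference: str = "Trained mFiLM") -> float:
    """Loss of `condition` minus loss of `reference` (positive = worse)."""
    return round(LOSSES[condition] - LOSSES[reference], 2)


def perplexity(loss: float) -> float:
    """exp(loss); only meaningful if the loss is mean cross-entropy in nats (assumption)."""
    return math.exp(loss)


if __name__ == "__main__":
    for name, loss in LOSSES.items():
        print(f"{name:18s} loss={loss:.2f}  gap={gap(name):+.2f}  ppl~{perplexity(loss):.1f}")
```

At two-decimal precision, the "State frozen" row is indistinguishable from "Trained mFiLM".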

Model provider: JimmyXiao091130

Base model: Qwen/Qwen3-4B (this model is an adapter on top of it)

Modalities: text input, text output
