model: Add support for PhiMoE arch (#11003)
* model: support phimoe

* python linter

* doc: minor

Co-authored-by: ThiloteE <[email protected]>

* doc: minor

Co-authored-by: ThiloteE <[email protected]>

* doc: add phimoe as supported model

ggml-ci

---------

Co-authored-by: ThiloteE <[email protected]>
phymbert and ThiloteE authored Jan 9, 2025
1 parent be0e950 commit f8feb4b
Showing 10 changed files with 208 additions and 31 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
- [x] [Phi models](https://huggingface.co/models?search=microsoft/phi)
- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003)
- [x] [GPT-2](https://huggingface.co/gpt2)
- [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118)
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
57 changes: 57 additions & 0 deletions convert_hf_to_gguf.py
@@ -2562,6 +2562,63 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))


@Model.register("PhiMoEForCausalLM")
class PhiMoeModel(Phi3MiniModel):
model_arch = gguf.MODEL_ARCH.PHIMOE

_experts: list[dict[str, Tensor]] | None = None

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["w1", "w2", "w3"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
model_arch = gguf.MODEL_ARCH.PLAMO
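The heart of the converter change is the expert merge: per-expert 2D weight matrices are stacked into one 3D tensor per projection (`w1`/`w2`/`w3`), which is the merged `ffn_*_exps` layout the GGUF side expects. A minimal standalone sketch of that stacking, with made-up shapes:

```python
import torch

# Sketch of the merge done in PhiMoeModel.modify_tensors (shapes illustrative).
n_experts, n_ff, n_embd = 4, 8, 6
experts = {
    f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# Stack in expert order into a single 3D tensor: (n_experts, n_ff, n_embd).
merged = torch.stack(
    [experts[f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight"]
     for xid in range(n_experts)],
    dim=0,
)
assert merged.shape == (n_experts, n_ff, n_embd)
```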
10 changes: 5 additions & 5 deletions docs/development/HOWTO-add-model.md
@@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
```python
@Model.register("MyModelForCausalLM")
class MyModel(Model):
-    model_arch = gguf.MODEL_ARCH.GROK
+    model_arch = gguf.MODEL_ARCH.MYMODEL
```

2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
@@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
- `Model#set_vocab`
- `Model#write_tensors`

-NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
+NOTE: Tensor names must end with a `.weight` or `.bias` suffix; that is the convention, and several tools like `quantize` rely on this naming to locate the weights.

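For instance, the converter's output names look like the following (illustrative entries, using the GGUF naming that appears later in this commit):

```python
# Illustrative only: every GGUF tensor name ends in .weight or .bias.
names = [
    "token_embd.weight",
    "blk.0.attn_qkv.weight",
    "blk.0.attn_qkv.bias",
    "blk.0.ffn_gate_exps.weight",  # merged MoE expert tensor (see phimoe below)
]
assert all(n.endswith((".weight", ".bias")) for n in names)
```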
### 2. Define the model architecture in `llama.cpp`

The model params and tensors layout must be defined in `llama.cpp`:
1. Define a new `llm_arch`
2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non standard metadata in `llm_load_hparams`
+3. Add any non-standard metadata in `llm_load_hparams`
4. Create the tensors for inference in `llm_load_tensors`
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`

@@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc

This is the most fun part: you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.

-Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
+Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.

-When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
+Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.

Note: to debug the inference graph, you can use [llama-eval-callback](/examples/eval-callback/).

20 changes: 20 additions & 0 deletions gguf-py/gguf/constants.py
@@ -244,6 +244,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
@@ -428,6 +429,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
@@ -940,6 +942,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PHIMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FACTORS_LONG,
MODEL_TENSOR.ROPE_FACTORS_SHORT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.CODESHELL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
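The per-arch tensor list above is what the conversion and loading code validate against. A quick sanity check of the new entry (a sketch; assumes gguf-py's `MODEL_TENSORS` mapping, which holds these per-arch lists):

```python
from gguf.constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS

# Sketch: PHIMOE must declare the merged MoE expert tensors.
phimoe_tensors = MODEL_TENSORS[MODEL_ARCH.PHIMOE]
for t in (MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP):
    assert t in phimoe_tensors
```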
37 changes: 20 additions & 17 deletions gguf-py/gguf/tensor_mapping.py
@@ -55,7 +55,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -68,7 +68,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
@@ -108,7 +108,7 @@ class TensorNameMap:
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
@@ -152,7 +152,7 @@

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
@@ -165,7 +165,7 @@

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
@@ -179,7 +179,7 @@

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
@@ -197,7 +197,7 @@
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
@@ -242,7 +242,7 @@
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
@@ -265,7 +265,7 @@

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
@@ -310,10 +310,11 @@
),

MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
),

MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -342,10 +343,11 @@
),

MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
),

MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -387,6 +389,7 @@
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
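These table entries are what let the converter turn the merged HF expert names into GGUF names. A sketch of the lookup (assumes gguf-py's `get_tensor_name_map` helper built from these tables; the block count is illustrative):

```python
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

# Sketch: resolve a merged phimoe expert tensor to its GGUF name.
tmap = get_tensor_name_map(MODEL_ARCH.PHIMOE, 32)  # 32 blocks, illustrative
name = tmap.get_name("model.layers.0.block_sparse_moe.experts.w1.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: "blk.0.ffn_gate_exps.weight"
```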
22 changes: 22 additions & 0 deletions src/llama-arch.cpp
@@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
@@ -584,6 +585,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PHIMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_PLAMO,
{
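On the C++ side, the `%d` placeholders in these name templates are filled with the block index at load time, yielding the same per-layer names the Python tooling writes. A trivial illustration (Python, to stay consistent with the other sketches):

```python
# Illustration: expanding the "blk.%d.*" templates per layer.
template = "blk.%d.ffn_gate_exps"
print([template % bid for bid in range(3)])
# -> ['blk.0.ffn_gate_exps', 'blk.1.ffn_gate_exps', 'blk.2.ffn_gate_exps']
```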
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -31,6 +31,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
11 changes: 11 additions & 0 deletions src/llama-model.cpp
@@ -76,6 +76,7 @@ const char * llm_type_name(llm_type type) {
case MODEL_8x7B: return "8x7B";
case MODEL_8x22B: return "8x22B";
case MODEL_16x12B: return "16x12B";
case MODEL_16x3_8B: return "16x3.8B";
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
case MODEL_57B_A14B: return "57B.A14B";
case MODEL_27B: return "27B";
@@ -661,6 +662,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
throw std::runtime_error("invalid value for sliding_window");
}
} break;
case LLM_ARCH_PHIMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_16x3_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_PLAMO:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2094,6 +2104,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_STARCODER2:
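`MODEL_16x3_8B` is the human-readable type string for PhiMoE (Phi-3.5-MoE): 16 experts of about 3.8B parameters each. The matching counts come from the HF config keys read in `PhiMoeModel.set_gguf_parameters` above; a sketch (values illustrative of Phi-3.5-MoE):

```python
# Sketch: the HF config keys the converter consumes for PhiMoE.
hparams = {"num_local_experts": 16, "num_experts_per_tok": 2}
assert hparams["num_experts_per_tok"] <= hparams["num_local_experts"]
print(f'{hparams["num_local_experts"]}x3.8B, top-{hparams["num_experts_per_tok"]} routing')
```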
1 change: 1 addition & 0 deletions src/llama-model.h
@@ -73,6 +73,7 @@ enum llm_type {
MODEL_8x7B,
MODEL_8x22B,
MODEL_16x12B,
MODEL_16x3_8B,
MODEL_10B_128x3_66B,
MODEL_57B_A14B,
MODEL_27B,