From cb33bcf4b57b63e7c843c40229444b18c40d943f Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 10 Jan 2025 03:45:46 +0000 Subject: [PATCH 1/2] feat: MiniCPM3. split, broad cast write and transpose is still on working. --- examples/CMakeLists.txt | 1 + include/OpDefined.hpp | 1 + src/Layer.hpp | 23 ++ src/backends/cpu/CPUBackend.cpp | 2 + src/backends/cpu/op/CPUNTKRoPE.cpp | 330 ++++++++++++++++++ src/backends/cpu/op/CPUNTKRoPE.hpp | 115 ++++++ .../minicpm3/configuration_minicpm3.hpp | 88 ++++- src/models/minicpm3/modeling_minicpm3.hpp | 268 ++++++++++---- 8 files changed, 750 insertions(+), 78 deletions(-) create mode 100644 src/backends/cpu/op/CPUNTKRoPE.cpp create mode 100644 src/backends/cpu/op/CPUNTKRoPE.hpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 16211cee..507a81cf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -67,6 +67,7 @@ func_llm_add_executable(demo_yi) func_llm_add_executable(demo_opt) func_llm_add_executable(demo_phi3) func_llm_add_executable(demo_minicpm) +func_llm_add_executable(demo_minicpm3) func_llm_add_executable(demo_minicpm_moe) func_llm_add_executable(demo_smollm) func_llm_add_executable(demo_openelm) diff --git a/include/OpDefined.hpp b/include/OpDefined.hpp index 29ccd1c6..fa9d1818 100644 --- a/include/OpDefined.hpp +++ b/include/OpDefined.hpp @@ -61,6 +61,7 @@ enum OpType { SPLITINPUT, IROPE, OP_NUM, + NTKROPE, // add in xnnpack DIRECT, diff --git a/src/Layer.hpp b/src/Layer.hpp index 978c1e86..282f48c5 100644 --- a/src/Layer.hpp +++ b/src/Layer.hpp @@ -1008,6 +1008,29 @@ class ScaledDotProductAttention final : public Layer { return ts[0].get(); } }; + +class NTKRoPE final : public Layer { +public: + NTKRoPE(float theta, int max_position_embeddings, int original_max_position_embeddings, const std::vector &long_factor, const std::vector &short_factor, std::string name) { + init(std::move(name), OpType::NTKROPE); + param_["theta"] = theta; + param_["max_position_embeddings"] = (float)max_position_embeddings; + param_["original_max_position_embeddings"] = (float)original_max_position_embeddings; + param_["long_factor_n"] = (float)long_factor.size(); + for (int i = 0; i < long_factor.size(); i++) { + param_["long_factor_" + std::to_string(i)] = long_factor[i]; + } + param_["short_factor_n"] = (float)short_factor.size(); + for (int i = 0; i < short_factor.size(); i++) { + param_["short_factor_" + std::to_string(i)] = short_factor[i]; + } + } + + Tensor &operator()(Tensor &input) { + auto ts = run({input}, 1); + return ts[0].get(); + } +}; // Only for QNN END } // namespace mllm diff --git a/src/backends/cpu/CPUBackend.cpp b/src/backends/cpu/CPUBackend.cpp index ba46a5a9..461d56af 100644 --- a/src/backends/cpu/CPUBackend.cpp +++ b/src/backends/cpu/CPUBackend.cpp @@ -8,6 +8,7 @@ #include "memory/SystemMemoryManager.hpp" #include "op/CPULinearInt8.hpp" +#include "op/CPUNTKRoPE.hpp" #include "op/CPUPoEmbedding.hpp" #include "op/CPUSplitInput.hpp" #include "op/CPUView.hpp" @@ -171,6 +172,7 @@ void CPUBackend::registerOps() { addCreator(LINEARINT8SHADOW, (CPUBackend::Creator *)(new CPULinearINT8ShadowCreator())); addCreator(IROPE, (CPUBackend::Creator *)(new CPUIRoPECreator())); addCreator(XP_KVCACHE, (CPUBackend::Creator *)(new CPUKVCacheXpCreator())); + addCreator(NTKROPE, (CPUBackend::Creator *)(new CPUNTKRoPECreator())); } TensorFunction *CPUBackend::funcCreate(const TensorFuncType type) { auto iter = map_function_.find(type); diff --git a/src/backends/cpu/op/CPUNTKRoPE.cpp 
b/src/backends/cpu/op/CPUNTKRoPE.cpp new file mode 100644 index 00000000..39ff95cb --- /dev/null +++ b/src/backends/cpu/op/CPUNTKRoPE.cpp @@ -0,0 +1,330 @@ +/** + * @file CPUNTKRoPE.cpp + * @author chenghua wang (chenghua.wang.edu@gmail.com) + * @brief + * @version 0.1 + * @date 2025-01-08 + * + * @copyright Copyright (c) 2025 + * + */ +#include "CPUNTKRoPE.hpp" +#include "Types.hpp" +#include +#include +#include "backends/cpu/quantize/QuantizeQ8.hpp" + +namespace mllm { + +int CPUNTKRoPE::in_shape_old = 0; +std::vector> CPUNTKRoPE::emb_sin_; +std::vector> CPUNTKRoPE::emb_cos_; + +namespace { +void get_sin_cos_emb_hf( + std::vector> &emb_sin, + std::vector> &emb_cos, + int seq_len, + int output_dim, + float theta, + std::vector &long_factor, + std::vector &short_factor, + int original_max_position_embeddings, + int max_position_embeddings = 2048) { + auto scale = (float)max_position_embeddings / (float)original_max_position_embeddings; + auto scaling_factor = (float)std::sqrt(1 + std::log(scale) / std::log(original_max_position_embeddings)); + + // compute sin and cos + emb_sin.resize(seq_len); + for (int i = 0; i < seq_len; ++i) { + emb_sin[i].resize(output_dim); + } + emb_cos.resize(seq_len); + for (int i = 0; i < seq_len; ++i) { + emb_cos[i].resize(output_dim); + } + + // get ext_factor + std::vector ext_factors; + if (seq_len > original_max_position_embeddings) + ext_factors = long_factor; + else + ext_factors = short_factor; + + // calculate inv_freq + std::vector inv_freq(output_dim / 2, 0.f); + for (int i = 0; i < output_dim / 2; ++i) { + inv_freq[i] = 1.f / (float)(std::pow(theta, (float)i / (float)output_dim)); + } + + std::vector t(seq_len, 0.f); + for (int s = 0; s < seq_len; ++s) t[s] = (float)s; + + std::vector> freqs; + { + int seq_len = t.size(); + int output_dim = inv_freq.size() * 2; // Since inv_freq is half the size of the final output dimension + + for (int i = 0; i < seq_len; ++i) { + freqs.emplace_back(output_dim / 2, 0.f); + for (int j = 0; j < output_dim / 2; ++j) { + freqs[i][j] = t[i] * (1.0f / ext_factors[j]) * inv_freq[j]; + } + } + } + + for (int i = 0; i < seq_len; ++i) { + for (int j = 0; j < output_dim / 2; ++j) { + emb_sin[i][j] = std::sin(freqs[i][j]) * scaling_factor; + emb_cos[i][j] = std::cos(freqs[i][j]) * scaling_factor; + } + for (int j = output_dim / 2; j < output_dim; ++j) { + emb_sin[i][j] = std::sin(freqs[i][j - output_dim / 2]) * scaling_factor; + emb_cos[i][j] = std::cos(freqs[i][j - output_dim / 2]) * scaling_factor; + } + } +} + +void apply_rope_hf( + std::shared_ptr &input, + std::shared_ptr &output, + std::vector> &emb_sin, + std::vector> &emb_cos, + int h_cnt) { + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * 1; + int half = (int)(partial_dimension / 2); + assert(partial_dimension % 2 == 0); + if (output->ctype() == BSHD) { + if (input->dtype() == MLLM_TYPE_F16) { +#pragma omp parallel for collapse(4) num_threads(4) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequence + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = static_cast(v[0]); + float in_value_2 = static_cast(v[half]); + float sin_value = emb_sin[s + h_cnt][d]; + float cos_value = emb_cos[s + h_cnt][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = 
MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + + } else { + if (out_dtype == MLLM_TYPE_F32) { +#pragma omp parallel for collapse(4) num_threads(4) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequence + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = emb_sin[s + h_cnt][d]; + float cos_value = emb_cos[s + h_cnt][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = value; + o[half] = value2; + } + } + } + } + } else if (out_dtype == MLLM_TYPE_F16) { +#pragma omp parallel for collapse(4) num_threads(4) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequence + for (int d = 0; d < partial_dimension / 2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = emb_sin[s + h_cnt][d]; + float cos_value = emb_cos[s + h_cnt][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + } + } + return; + } +#pragma omp parallel for collapse(4) num_threads(4) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequence + for (int d = 0; d < partial_dimension / 2; ++d) { + if (input->dtype() == MLLM_TYPE_F16) { + float in_value = static_cast(input->dataAt(n, h, s, d)); + float in_value_2 = static_cast(input->dataAt(n, h, s, d + partial_dimension / 2)); + float sin_value = emb_sin[s + h_cnt][d]; + float cos_value = emb_cos[s + h_cnt][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d + partial_dimension / 2, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); + } + + } else { + auto in_value = input->dataAt(n, h, s, d); + auto in_value_2 = input->dataAt(n, h, s, d + partial_dimension / 2); + float sin_value = emb_sin[s + h_cnt][d]; + float cos_value = emb_cos[s + h_cnt][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d + partial_dimension / 2, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); + } + } + } + } + } + } +} +} // namespace + +CPUNTKRoPE::CPUNTKRoPE(Backend *bn, string op_name, int pose_type, int thread_count) : + Op(bn, op_name), thread_count_(thread_count), pose_type_(pose_type) { +} + +CPUNTKRoPE::CPUNTKRoPE(Backend *bn, string op_name, int pose_type, float rope_theta, + const std::vector &long_factor, + const std::vector &short_factor, + int 
original_max_position_embeddings, + int max_position_embeddings, + int thread_count) : + Op(bn, op_name), + thread_count_(thread_count), + pose_type_(pose_type), + rope_theta_(rope_theta), + long_factor_(long_factor), + short_factor_(short_factor), + original_max_position_embeddings_(original_max_position_embeddings), + max_position_embeddings_(max_position_embeddings) { +} + +ErrorCode CPUNTKRoPE::doExecute(std::vector> inputs, std::vector> outputs) { + auto &input = inputs[0]; + auto &output = outputs[0]; + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * 1; + switch ((RoPEType)pose_type_) { + case RoPEType::HFHUBROPE: + apply_rope_hf(input, output, emb_sin_, emb_cos_, h_cnt_); + break; + default: + MLLM_LOG_ERROR("RoPEType={} is not supported yet. Currently, only support HFHUBROPE style NTKRoPE", pose_type_); + break; + } + +#pragma omp parallel for collapse(4) num_threads(4) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { + for (int d = partial_dimension; d < input->dimension(); ++d) { + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, input->dataAt(n, h, s, d)); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(input->dataAt(n, h, s, d))); + } + } + } + } + } + + h_cnt_ += input->sequence(); + if (h_cnt_ >= max_position_embeddings_) { + h_cnt_ = 0; + } + return Op::execute(inputs, outputs); +} + +ErrorCode CPUNTKRoPE::reshape(std::vector> inputs, std::vector> outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension()); + in_shape = inputs[0]->dimension(); + if (emb_sin_.empty() || in_shape_old < in_shape) { + in_shape_old = in_shape; + switch ((RoPEType)pose_type_) { + case RoPEType::HFHUBROPE: + get_sin_cos_emb_hf( + emb_sin_, + emb_cos_, + max_position_embeddings_, + inputs[0]->dimension(), + rope_theta_, + long_factor_, + short_factor_, + original_max_position_embeddings_, + max_position_embeddings_); + break; + default: + MLLM_LOG_ERROR("RoPEType={} is not supported yet. 
Currently, only support HFHUBROPE style NTKRoPE", pose_type_); + break; + } + } + return Op::reshape(inputs, outputs); + return MLLM_NO_ERROR; +} + +ErrorCode CPUNTKRoPE::execute(std::vector> inputs, std::vector> outputs) { + if (outputs[0]->dtype() == MLLM_TYPE_Q8_0) { + auto tmp_out = std::make_shared(outputs[0]->backend()); + // tmp_out->setBackend(outputs[0]->backend()); + auto b = outputs[0]->batch(); + auto h = outputs[0]->head(); + auto d = outputs[0]->dimension(); + auto s = outputs[0]->sequence(); + tmp_out->chls() = outputs[0]->chls(); + tmp_out->setCtype(outputs[0]->ctype()); + tmp_out->reshape(b, h, s, d); + tmp_out->setDtype(MLLM_TYPE_F32); + tmp_out->alloc(); + doExecute(inputs, {tmp_out}); +#pragma omp parallel for collapse(3) num_threads(4) + for (int b = 0; b < tmp_out->batch(); b++) { + for (int h = 0; h < tmp_out->head(); h++) { + for (int s = 0; s < tmp_out->sequence(); s++) { + quantize_row_q8_0(tmp_out->hostPtr() + tmp_out->offset(b, h, s, 0), + (char *)outputs[0]->rawHostPtr() + + outputs[0]->offset(b, h, s, 0) * sizeof(block_q8_0) / QK8_0, + tmp_out->dimension()); + } + } + } + return MLLM_NO_ERROR; + } else { + return doExecute(inputs, outputs); + } +} + +ErrorCode CPUNTKRoPE::load(AbstructLoader &loader) { + return Op::load(loader); +} + +ErrorCode CPUNTKRoPE::free(std::vector> inputs, std::vector> outputs) { + return Op::free(inputs, outputs); +} +} // namespace mllm \ No newline at end of file diff --git a/src/backends/cpu/op/CPUNTKRoPE.hpp b/src/backends/cpu/op/CPUNTKRoPE.hpp new file mode 100644 index 00000000..29c23211 --- /dev/null +++ b/src/backends/cpu/op/CPUNTKRoPE.hpp @@ -0,0 +1,115 @@ +/** + * @file CPUNTKRoPE.hpp + * @author chenghua wang (chenghua.wang.edu@gmail.com) + * @brief + * @version 0.1 + * @date 2025-01-08 + * + * @copyright Copyright (c) 2025 + * + */ +#pragma once + +#include "Op.hpp" +#include "../CPUBackend.hpp" + +// 1. Scaling factor +// \text{scale} = \frac{\text{max\_position\_embeddings}}{\text{original\_max\_position\_embeddings}} +// \text{scaling\_factor} = \sqrt{1 + \frac{\log(\text{scale})}{\log(\text{original\_max\_position\_embeddings})}} + +// 2. Frequency Calculation +// t = [0, 1, 2, \dots, \text{seq\_len} - 1] +// \text{ext\_factors} = +// \begin{cases} +// \text{long\_factor} & \text{if } \text{seq\_len} > \text{original\_max\_position\_embeddings} \\ +// \text{short\_factor} & \text{otherwise} +// \end{cases} +// \text{freqs} = \left(t \cdot \frac{1}{\text{ext\_factors}}\right) \otimes \text{inv\_freq} + +// 3. Rotary Position Embedding +// \text{emb} = [\text{freqs}, \text{freqs}] +// \text{cos\_cached} = \cos(\text{emb}) \cdot \text{scaling\_factor} +// \text{sin\_cached} = \sin(\text{emb}) \cdot \text{scaling\_factor} + +// 4. all +// \text{RoPE}(x, t) = +// \begin{bmatrix} +// \cos(\theta_t) & -\sin(\theta_t) \\ +// \sin(\theta_t) & \cos(\theta_t) +// \end{bmatrix} +// \cdot x +// +// \theta_t = t \cdot \frac{1}{\text{ext\_factors}} \cdot \text{inv\_freq} + +namespace mllm { + +class CPUNTKRoPE final : public Op { +public: + CPUNTKRoPE(Backend *bn, string op_name, int pose_type, int thread_count); + CPUNTKRoPE(Backend *bn, string op_name, int pose_type, float rope_theta, + const std::vector &long_factor, + const std::vector &short_factor, + int original_max_position_embeddings, + int max_position_embeddings, + int thread_count); + + ~CPUNTKRoPE() override = default; + ErrorCode reshape(std::vector> inputs, std::vector> outputs) override; + + // FIXME: Typo here !!! 
Abstract + ErrorCode load(AbstructLoader &loader) override; + ErrorCode execute(std::vector> inputs, std::vector> outputs) override; + ErrorCode free(std::vector> inputs, std::vector> outputs) override; + ErrorCode doExecute(std::vector> inputs, std::vector> outputs); + +private: + static int in_shape_old; + static std::vector> emb_sin_; + static std::vector> emb_cos_; + std::vector long_factor_; + std::vector short_factor_; + int pose_type_ = 4; + int thread_count_ = 4; + int h_cnt_ = 0; + float rope_theta_ = 1e-4f; + int max_position_embeddings_ = 32768; + int original_max_position_embeddings_ = 32768; + int in_shape = -1; + + void + clearCache() override { + h_cnt_ = 0; + } +}; + +class CPUNTKRoPECreator : public CPUBackend::Creator { +public: + // FIXME: OpParam is copied. + // FIXME: name is copied, may optimized to move by compiler. + Op *create(OpParam op_param, Backend *bn, string name, int thread_count) const override { + int pose_type = static_cast(op_param["pose_type"]); + float rope_theta = op_param["rope_theta"]; + int max_position_embeddings = static_cast(op_param["max_position_embeddings"]); + + int long_factor_n = static_cast(op_param["long_factor_n"]); + int short_factor_n = static_cast(op_param["short_factor_n"]); + std::vector long_factor(long_factor_n); + std::vector short_factor(short_factor_n); + + // FIXME: the way we pass vector to backend is inefficient. + for (int _i_long_factor_n = 0; _i_long_factor_n < long_factor_n; _i_long_factor_n++) { + long_factor.push_back(op_param["long_factor_" + std::to_string(_i_long_factor_n)]); + } + + for (int _i_short_factor_n = 0; _i_short_factor_n < short_factor_n; _i_short_factor_n++) { + short_factor.push_back(op_param["short_factor_" + std::to_string(_i_short_factor_n)]); + } + + int original_max_position_embeddings = static_cast(op_param["original_max_position_embeddings"]); + + return new CPUNTKRoPE(bn, name, pose_type, rope_theta, long_factor, short_factor, + original_max_position_embeddings, max_position_embeddings, thread_count); + } +}; + +} // namespace mllm \ No newline at end of file diff --git a/src/models/minicpm3/configuration_minicpm3.hpp b/src/models/minicpm3/configuration_minicpm3.hpp index b8fdcd7d..4e5c1168 100644 --- a/src/models/minicpm3/configuration_minicpm3.hpp +++ b/src/models/minicpm3/configuration_minicpm3.hpp @@ -3,9 +3,26 @@ #define CONFIG_MINICPM_HPP #include "Types.hpp" #include "models/transformer/configuration_transformer.hpp" +#include using namespace mllm; +// the model naming method is from minicpm3 hf repo +// model.embed_tokens.weight +// model.norm.weight +// model.layers.0.input_layernorm.weight +// model.layers.0.self_attn.q_b_proj.weight +// model.layers.0.self_attn.q_a_proj.weight +// model.layers.0.self_attn.kv_b_proj.weight +// model.layers.0.self_attn.kv_a_proj_with_mqa.weight +// model.layers.0.self_attn.q_a_layernorm.weight +// model.layers.0.self_attn.kv_a_layernorm.weight +// model.layers.0.self_attn.o_proj.weight +// model.layers.0.post_attention_layernorm.weight +// model.layers.0.mlp.gate_proj.weight +// model.layers.0.mlp.up_proj.weight +// model.layers.0.mlp.down_proj.weight + class MiniCPM3NameConfig : public TransformerNameConfig { public: /** @@ -17,7 +34,9 @@ class MiniCPM3NameConfig : public TransformerNameConfig { blk_name = "model.layers."; _attn_base_name = "self_attn."; _ffn_base_name = "mlp."; - _q_proj_name = "q_proj"; + _q_b_proj_name = "q_b_proj"; + _q_a_proj_name = "q_a_proj"; + _q_a_layernorm = "q_a_layernorm"; _kv_a_proj_with_mqa_name = "kv_a_proj_with_mqa"; 
_kv_a_layernorm_name = "kv_a_layernorm"; _kv_b_proj_name = "kv_b_proj"; @@ -35,6 +54,9 @@ class MiniCPM3NameConfig : public TransformerNameConfig { std::string _kv_a_proj_with_mqa_name; std::string _kv_a_layernorm_name; std::string _kv_b_proj_name; + std::string _q_b_proj_name; + std::string _q_a_proj_name; + std::string _q_a_layernorm; std::string blk_name; std::string token_embd_name; @@ -49,17 +71,65 @@ struct MiniCPM3Config : public TransformerConfig { names_config.init(); }; - int vocab_size = 73448; - int max_position_embeddings = 32768; - int num_hidden_layers = 62; + int bos_token_id = 1; + std::vector eos_token_ids = {2, 73440}; + float initializer_range = 0.1f; int hidden_size = 2560; + int num_hidden_layers = 62; int intermediate_size = 6400; - int num_heads = 40; - int qk_rope_head_dim = 32; // qk_rope_head_dim - int qk_nope_head_dim = 64; // qk_nope_head_dim = qk_rope_head_dim*2 - int kv_lora_rank = 256; // kv_lora_rank = 2568* qk_nope_head_dim; + int max_position_embeddings = 32768; + int num_attention_heads = 40; + int num_key_value_heads = 40; + int qk_nope_head_dim = 64; + int qk_rope_head_dim = 32; + int q_lora_rank = 768; + int kv_lora_rank = 256; + float rms_norm_eps = 1e-06f; + + int vocab_size = 73448; + int scale_emb = 12; + int dim_model_base = 256; + float scale_depth = 1.4f; + + // rope_scaling + std::string rope_type = "longrope"; + std::vector rope_long_factor = {1.0591234137867171, + 1.1241891283591912, + 1.2596935748670968, + 1.5380380402321725, + 2.093982484148734, + 3.1446935121267696, + 4.937952647693647, + 7.524541999994549, + 10.475458000005451, + 13.062047352306353, + 14.85530648787323, + 15.906017515851266, + 16.461961959767827, + 16.740306425132907, + 16.87581087164081, + 16.940876586213285}; + std::vector rope_short_factor = {1.0591234137867171, + 1.1241891283591912, + 1.2596935748670968, + 1.5380380402321725, + 2.093982484148734, + 3.1446935121267696, + 4.937952647693647, + 7.524541999994549, + 10.475458000005451, + 13.062047352306353, + 14.85530648787323, + 15.906017515851266, + 16.461961959767827, + 16.740306425132907, + 16.87581087164081, + 16.940876586213285}; + float rope_theta = 10000.f; + int rope_original_max_position_embeddings = 32768; + + float attention_dropout = 0.f; - float rms_norm_eps = 1e-6; int cache_limit; bool do_mask = true; diff --git a/src/models/minicpm3/modeling_minicpm3.hpp b/src/models/minicpm3/modeling_minicpm3.hpp index 3cc1ca62..591e3f0b 100644 --- a/src/models/minicpm3/modeling_minicpm3.hpp +++ b/src/models/minicpm3/modeling_minicpm3.hpp @@ -5,97 +5,177 @@ #ifndef MODELING_MINICPM_HPP #define MODELING_MINICPM_HPP +#include "Types.hpp" #include "configuration_minicpm3.hpp" using namespace mllm; class MiniCPM3MultiHeadLatentAttention final : public Module { - Layer q_proj; + int hidden_size = 0; + int num_heads = 0; + int max_position_embeddings = 0; + float rope_theta = 0.f; + int q_lora_rank = 0; + int qk_rope_head_dim = 0; + int kv_lora_rank = 0; + int v_head_dim = 0; + int qk_nope_head_dim = 0; + int q_head_dim = 0; + + Layer q_a_proj; + Layer q_a_layernorm; + Layer q_b_proj; Layer kv_a_proj_with_mqa; Layer kv_a_layernorm; Layer kv_b_proj; - Layer k_proj; - Layer v_proj; - RoPE q_rope; - RoPE k_rope; + Layer o_proj; + Layer q_rope; + Layer k_rope; KVCache k_cache; KVCache v_cache; Softmax softmax; - Layer o_proj; - int num_heads{}; - int q_head_dim{}; - int v_head_dim{}; - int qk_nope_head_dim{}; - int qk_rope_head_dim{}; - int kv_lora_rank{}; - float softmax_scale{}; + + float softmax_scale = 0.f; public: 
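+    // Multi-head latent attention (MLA), as wired below: the query path is factored
+    // through a low-rank bottleneck (q_a_proj -> q_a_layernorm -> q_b_proj, rank
+    // q_lora_rank), while K/V are reconstructed from a compressed latent of width
+    // kv_lora_rank (kv_a_proj_with_mqa -> kv_a_layernorm -> kv_b_proj); the same
+    // kv_a_proj_with_mqa also emits a shared qk_rope_head_dim slice (k_pe). Only the
+    // *_pe slices receive rotary embedding; the qk_nope_head_dim slices do not.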
MiniCPM3MultiHeadLatentAttention() = default; MiniCPM3MultiHeadLatentAttention(const MiniCPM3Config config, const MiniCPM3NameConfig &names, const string &base_name) { - num_heads = config.num_heads; - qk_nope_head_dim = config.qk_nope_head_dim; + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + max_position_embeddings = config.max_position_embeddings; + rope_theta = config.rope_theta; + q_lora_rank = config.q_lora_rank; qk_rope_head_dim = config.qk_rope_head_dim; kv_lora_rank = config.kv_lora_rank; - v_head_dim = config.hidden_size / config.num_heads; // config.v_head_dim; + v_head_dim = config.hidden_size / config.num_attention_heads; + qk_nope_head_dim = config.qk_nope_head_dim; q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim; - q_proj = Linear( - config.hidden_size, + + q_a_proj = Linear( + hidden_size, + q_lora_rank, + false, + base_name + names._q_a_proj_name); + + q_a_layernorm = RMSNorm( + q_lora_rank, + config.rms_norm_eps, + base_name + names._q_a_layernorm); + + q_b_proj = Linear( + q_lora_rank, num_heads * q_head_dim, false, - base_name + names._q_proj_name); + base_name + names._q_b_proj_name); + kv_a_proj_with_mqa = Linear( - config.hidden_size, + hidden_size, kv_lora_rank + qk_rope_head_dim, false, base_name + names._kv_a_proj_with_mqa_name); - kv_a_layernorm = RMSNorm(kv_lora_rank, config.rms_norm_eps, base_name + names._kv_a_layernorm_name); + + kv_a_layernorm = RMSNorm( + kv_lora_rank, + config.rms_norm_eps, + base_name + names._kv_a_layernorm_name); + kv_b_proj = Linear( kv_lora_rank, num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim), false, base_name + names._kv_b_proj_name); + o_proj = Linear( num_heads * v_head_dim, - config.hidden_size, + hidden_size, false, base_name + names._o_proj_name); - q_rope = RoPE(RoPEType::MLAROPE, base_name + "q_rope"); - k_rope = RoPE(RoPEType::MLAROPE, base_name + "k_rope"); + + q_rope = NTKRoPE( + rope_theta, + max_position_embeddings, + config.rope_original_max_position_embeddings, + config.rope_long_factor, + config.rope_short_factor, + base_name + "q_rope"); + + k_rope = NTKRoPE( + rope_theta, + max_position_embeddings, + config.rope_original_max_position_embeddings, + config.rope_long_factor, + config.rope_short_factor, + base_name + "k_rope"); + + // TODO num_heads. may error. 
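+        // Note: num_heads / num_heads below evaluates to 1, i.e. the KV cache is built
+        // without head replication; whether a different repeat factor is needed here is
+        // still an open question (see the TODO above).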
if (config.cache_limit > 0) { k_cache = KVCache(num_heads / num_heads, config.cache_limit, base_name + "k_cache"); v_cache = KVCache(num_heads / num_heads, config.cache_limit, base_name + "v_cache"); } + softmax = Softmax(DIMENSION, config.do_mask, base_name + "softmax"); - softmax_scale = 1 / std::sqrt(q_head_dim); + + softmax_scale = 1.f / (float)std::sqrt(q_head_dim); } + vector Forward(vector inputs, vector args) override { auto hidden_states = inputs[0]; - auto q = q_proj(hidden_states); + auto bsz = hidden_states.batch(); + auto q_len = hidden_states.sequence(); + + // q: [bs, len, 1, num_heads * q_head_dim] + auto q = q_b_proj(q_a_layernorm(q_a_proj(hidden_states))); + + // q: [bs, num_heads, len, q_head_dim] + q = q.view(bsz, num_heads, q_len, q_head_dim).transpose(SEQUENCE, HEAD); + + // q_nope: [bs, num_heads, len, qk_nope_head_dim] + // q_pe: [bs, num_heads, len, qk_rope_head_dim] auto qs = Tensor::split(q, {qk_nope_head_dim, qk_rope_head_dim}, D_HD, num_heads); - q = Tensor::cat({qs[0], q_rope(qs[1])}, DIMENSION); + auto q_nope = qs[0]; + auto q_pe = qs[1]; + + // compressed_kv: [bs, len, 1, kv_lora_rank + qk_rope_head_dim] + auto compressed_kv = kv_a_proj_with_mqa(hidden_states); + + // compressed_kv: [bs, len, 1, kv_lora_rank] + // k_pe: [bs, len, 1, qk_rope_head_dim] + auto kvs = Tensor::split(compressed_kv, {kv_lora_rank, qk_rope_head_dim}, DIMENSION); + compressed_kv = kvs[0]; + auto k_pe = kvs[1]; - Tensor compressed_kv = kv_a_proj_with_mqa(hidden_states); - auto kvs = Tensor::split(compressed_kv, - {kv_lora_rank, qk_rope_head_dim}, DIMENSION); - auto k_pe = k_rope(kvs[1]); - auto kv = kv_b_proj(kv_a_layernorm(kvs[0])); //.view(-1, head_size_, -1, qk_nope_head_dim_ + v_head_dim_); + // k_pe: [bs, 1, len, qk_rope_head_dim] + k_pe = k_pe.get().view(bsz, 1, q_len, qk_rope_head_dim).transpose(SEQUENCE, HEAD); + + // kv: [bs, num_heads, len, q_head_dim - qk_rope_head_dim + v_head_dim] + auto kv = kv_b_proj(kv_a_layernorm(compressed_kv)).view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim).transpose(SEQUENCE, HEAD); + + // k_nope: [bs, num_heads, len, qk_nope_head_dim] + // value_states: [bs, num_heads, len , v_head_dim] kvs = Tensor::split(kv, {qk_nope_head_dim, v_head_dim}, D_HD, num_heads); - auto v = kvs[1]; - auto k = Tensor::cat({kvs[0], k_pe}, DIMENSION); - if (k_cache.ready() && v_cache.ready()) { - k = k_cache(k); - v = v_cache(v); - } - k = k.transpose(SEQUENCE, DIMENSION); - auto qk = Tensor::mm(q, k); - qk = qk * softmax_scale; - qk = softmax(qk, k_cache.getCacheSeqLen()); - auto o = Tensor::mm(qk, v); - o = o.view(-1, 1, -1, v_head_dim * num_heads); - o = o_proj(o); - return {o}; + auto k_nope = kvs[0]; + auto value_states = kvs[1]; + + q_pe = q_rope(q_pe); + k_pe = k_rope(k_pe); + + auto query_states = Tensor::cat({q_pe, q_nope}, DIMENSION); + + // TODO k_pe should broad cast k_pe to num_heads first. 
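+        // k_pe currently carries a single head ([bs, 1, len, qk_rope_head_dim]) while
+        // k_nope carries num_heads ([bs, num_heads, len, qk_nope_head_dim]); the head
+        // counts must match before the two can be concatenated along DIMENSION, so
+        // k_pe has to be replicated across the HEAD axis first.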
+ auto key_states = Tensor::cat({k_pe, k_nope}, DIMENSION); + + key_states = k_cache(key_states); + value_states = v_cache(value_states); + + auto attn_weight = Tensor::mm(query_states, key_states.transpose(DIMENSION, HEAD)); + attn_weight = attn_weight * softmax_scale; + attn_weight = softmax(attn_weight, k_cache.getCacheSeqLen()); + auto attn_output = Tensor::mm(attn_weight, value_states); + attn_output = attn_output.view(-1, 1, -1, v_head_dim * num_heads); + attn_output = o_proj(attn_output); + + return {attn_output}; } }; @@ -104,31 +184,44 @@ class MiniCPM3MLP final : public Module { Layer gate_proj; Layer up_proj; Layer down_proj; - Layer gelu; + Layer silu; public: MiniCPM3MLP() = default; MiniCPM3MLP(const MiniCPM3Config &config, const MiniCPM3NameConfig &names, const std::string &base_name) { int hidden_size = config.hidden_size; int intermediate_size = config.intermediate_size; - gate_proj = Linear(hidden_size, intermediate_size, false, base_name + names._gate_proj_name); - gelu = SiLU(base_name + "act"); - up_proj = Linear(hidden_size, intermediate_size, false, base_name + names._up_proj_name); - down_proj = Linear(intermediate_size, hidden_size, false, base_name + names._down_proj_name); + + gate_proj = Linear( + hidden_size, + intermediate_size, + false, + base_name + names._gate_proj_name); + + silu = SiLU(base_name + "act"); + + up_proj = Linear( + hidden_size, + intermediate_size, + false, + base_name + names._up_proj_name); + + down_proj = Linear( + intermediate_size, + hidden_size, + false, + base_name + names._down_proj_name); } + std::vector Forward(std::vector inputs, std::vector args) override { - auto x = gate_proj(inputs[0]); - x = gelu(x); - auto y = up_proj(inputs[0]); - x = x * y; - x = down_proj(x); - return {x}; + auto x = inputs[0]; + return {down_proj(silu(gate_proj(x)) * up_proj(x))}; } }; class MiniCPM3Decoder final : public Module { private: - MiniCPM3MultiHeadLatentAttention self_atten; + MiniCPM3MultiHeadLatentAttention self_attn; MiniCPM3MLP mlp; Layer input_layernorm; Layer post_attention_layernorm; @@ -136,15 +229,30 @@ class MiniCPM3Decoder final : public Module { public: MiniCPM3Decoder() = default; MiniCPM3Decoder(const MiniCPM3Config &config, const MiniCPM3NameConfig &names, const string &base_name) { - self_atten = MiniCPM3MultiHeadLatentAttention(config, names, base_name + names._attn_base_name); - mlp = MiniCPM3MLP(config, names, base_name + names._ffn_base_name); - input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); - post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); + self_attn = MiniCPM3MultiHeadLatentAttention( + config, + names, + base_name + names._attn_base_name); + + mlp = MiniCPM3MLP( + config, + names, + base_name + names._ffn_base_name); + + input_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + base_name + names._attn_norm_name); + + post_attention_layernorm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + base_name + names._ffn_norm_name); } std::vector Forward(std::vector inputs, std::vector args) override { auto x = input_layernorm(inputs[0]); - x = self_atten({x, x, x})[0]; + x = self_attn({x, x, x})[0]; auto tmp = x + inputs[0]; x = post_attention_layernorm(tmp); x = mlp({x})[0]; @@ -161,8 +269,17 @@ class MiniCPM3Model final : public Module { public: MiniCPM3Model() = default; MiniCPM3Model(const MiniCPM3Config &config, const MiniCPM3NameConfig &names, const string &base_name) { - blocks = 
List(config.num_hidden_layers, config, names, base_name); - norm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, names.post_norm_name); + blocks = List( + config.num_hidden_layers, + config, + names, + base_name); + + norm = RMSNorm( + config.hidden_size, + config.rms_norm_eps, + true, + names.post_norm_name); } std::vector Forward(std::vector inputs, std::vector args) override { @@ -183,15 +300,28 @@ class MiniCPM3ForCausalLM final : public Module { MiniCPM3Model model; public: - MiniCPM3ForCausalLM(MiniCPM3Config &config) { + explicit MiniCPM3ForCausalLM(MiniCPM3Config &config) { auto names = config.names_config; hidden_size = config.hidden_size; - embedding = Embedding(config.vocab_size, config.hidden_size, names.token_embd_name); - model = MiniCPM3Model(config, names, names.blk_name); + + embedding = Embedding( + config.vocab_size, + config.hidden_size, + names.token_embd_name); + + model = MiniCPM3Model( + config, + names, + names.blk_name); // lm_head and tok_embedding is tied together. // They share same parameters. Use a Transpose to do the lm_head instead. - lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size, names.lm_head_name + ".weight"); + lm_head = Parameter( + 1, + config.vocab_size, + 1, + config.hidden_size, + names.lm_head_name + ".weight"); } std::vector Forward(std::vector inputs, std::vector args) override { auto x = embedding(inputs[0]); From 4e2251331b4ff49a5ddd920d7fedfcae0a616f41 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 10 Jan 2025 08:23:51 +0000 Subject: [PATCH 2/2] fix: view and transpose bug --- src/Layer.hpp | 3 ++- src/models/minicpm3/modeling_minicpm3.hpp | 30 +++++++++++++++++------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/Layer.hpp b/src/Layer.hpp index 282f48c5..4f057955 100644 --- a/src/Layer.hpp +++ b/src/Layer.hpp @@ -1011,8 +1011,9 @@ class ScaledDotProductAttention final : public Layer { class NTKRoPE final : public Layer { public: - NTKRoPE(float theta, int max_position_embeddings, int original_max_position_embeddings, const std::vector &long_factor, const std::vector &short_factor, std::string name) { + NTKRoPE(RoPEType type, float theta, int max_position_embeddings, int original_max_position_embeddings, const std::vector &long_factor, const std::vector &short_factor, std::string name) { init(std::move(name), OpType::NTKROPE); + param_["pose_type"] = (float)type; param_["theta"] = theta; param_["max_position_embeddings"] = (float)max_position_embeddings; param_["original_max_position_embeddings"] = (float)original_max_position_embeddings; diff --git a/src/models/minicpm3/modeling_minicpm3.hpp b/src/models/minicpm3/modeling_minicpm3.hpp index 591e3f0b..0ca2e277 100644 --- a/src/models/minicpm3/modeling_minicpm3.hpp +++ b/src/models/minicpm3/modeling_minicpm3.hpp @@ -92,6 +92,7 @@ class MiniCPM3MultiHeadLatentAttention final : public Module { base_name + names._o_proj_name); q_rope = NTKRoPE( + HFHUBROPE, rope_theta, max_position_embeddings, config.rope_original_max_position_embeddings, @@ -100,6 +101,7 @@ class MiniCPM3MultiHeadLatentAttention final : public Module { base_name + "q_rope"); k_rope = NTKRoPE( + HFHUBROPE, rope_theta, max_position_embeddings, config.rope_original_max_position_embeddings, @@ -128,11 +130,11 @@ class MiniCPM3MultiHeadLatentAttention final : public Module { auto q = q_b_proj(q_a_layernorm(q_a_proj(hidden_states))); // q: [bs, num_heads, len, q_head_dim] - q = q.view(bsz, num_heads, q_len, q_head_dim).transpose(SEQUENCE, HEAD); + q = 
q.view(-1, num_heads, -1, q_head_dim).transpose(SEQUENCE, HEAD); // q_nope: [bs, num_heads, len, qk_nope_head_dim] // q_pe: [bs, num_heads, len, qk_rope_head_dim] - auto qs = Tensor::split(q, {qk_nope_head_dim, qk_rope_head_dim}, D_HD, num_heads); + auto qs = Tensor::split(q, {qk_nope_head_dim, qk_rope_head_dim}, D_HD, 1); auto q_nope = qs[0]; auto q_pe = qs[1]; @@ -141,19 +143,19 @@ class MiniCPM3MultiHeadLatentAttention final : public Module { // compressed_kv: [bs, len, 1, kv_lora_rank] // k_pe: [bs, len, 1, qk_rope_head_dim] - auto kvs = Tensor::split(compressed_kv, {kv_lora_rank, qk_rope_head_dim}, DIMENSION); + auto kvs = Tensor::split(compressed_kv, {kv_lora_rank, qk_rope_head_dim}, DIMENSION, 1); compressed_kv = kvs[0]; - auto k_pe = kvs[1]; + Tensor k_pe = kvs[1]; // k_pe: [bs, 1, len, qk_rope_head_dim] - k_pe = k_pe.get().view(bsz, 1, q_len, qk_rope_head_dim).transpose(SEQUENCE, HEAD); + k_pe = k_pe.transpose(SEQUENCE, HEAD); // kv: [bs, num_heads, len, q_head_dim - qk_rope_head_dim + v_head_dim] - auto kv = kv_b_proj(kv_a_layernorm(compressed_kv)).view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim).transpose(SEQUENCE, HEAD); + auto kv = kv_b_proj(kv_a_layernorm(compressed_kv)).view(-1, num_heads, -1, qk_nope_head_dim + v_head_dim).transpose(SEQUENCE, HEAD); // k_nope: [bs, num_heads, len, qk_nope_head_dim] // value_states: [bs, num_heads, len , v_head_dim] - kvs = Tensor::split(kv, {qk_nope_head_dim, v_head_dim}, D_HD, num_heads); + kvs = Tensor::split(kv, {qk_nope_head_dim, v_head_dim}, D_HD, 1); auto k_nope = kvs[0]; auto value_states = kvs[1]; @@ -162,9 +164,21 @@ class MiniCPM3MultiHeadLatentAttention final : public Module { auto query_states = Tensor::cat({q_pe, q_nope}, DIMENSION); - // TODO k_pe should broad cast k_pe to num_heads first. + std::vector broad_casted_k_pe_list; + broad_casted_k_pe_list.reserve(num_heads); + for (int i = 0; i < num_heads; i++) { + broad_casted_k_pe_list.push_back(k_pe); + } + k_pe = Tensor::cat(broad_casted_k_pe_list, HEAD); + + // TODO error below. auto key_states = Tensor::cat({k_pe, k_nope}, DIMENSION); + // original + // value_states: [bs, num_heads, len , v_head_dim] + // k_nope: [bs, num_heads, len, qk_nope_head_dim + qk_rope_head_dim] + // after kvcache + // ... key_states = k_cache(key_states); value_states = v_cache(value_states);
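+        // Shape/scale bookkeeping with the default MiniCPM3Config values:
+        //   v_head_dim    = hidden_size / num_attention_heads = 2560 / 40 = 64
+        //   q_head_dim    = qk_nope_head_dim + qk_rope_head_dim = 64 + 32 = 96
+        //   softmax_scale = 1 / sqrt(96) ~ 0.102
+        //   kv_b_proj width per head = qk_nope_head_dim + v_head_dim = 64 + 64 = 128
+        // For the NTK ("longrope") tables: scale = max_position_embeddings /
+        // original_max_position_embeddings = 32768 / 32768 = 1, hence
+        // scaling_factor = sqrt(1 + log(1)/log(32768)) = 1, and since the cached
+        // seq_len never exceeds original_max_position_embeddings the short_factor
+        // branch is taken (long_factor and short_factor are identical in this config).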