Fix failed assertions while running Falcon Mamba
jploski committed Aug 25, 2024
1 parent 061e520 commit fae826f
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/norm.cu
@@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
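The relaxed assertion matters because Falcon Mamba feeds rows narrower than a CUDA warp (WARP_SIZE is 32 in ggml-cuda) into the RMS-norm kernel, e.g. the d_state-wide B and C rows (typically 16 in Mamba-style models). A minimal host-side C++ sketch that mirrors the new launch-width logic; rms_norm_block_width is a hypothetical helper for illustration, not a function in the file above:

#include <algorithm>
#include <cassert>

constexpr int WARP_SIZE = 32; // CUDA warp width, as used in ggml-cuda

// Hypothetical mirror of the block-width choice in rms_norm_f32_cuda.
static int rms_norm_block_width(int ncols) {
    // Relaxed check: warp-multiple widths OR sub-warp widths are accepted.
    assert(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
    return ncols < 1024 ? std::min(ncols, WARP_SIZE) : 1024;
}

int main() {
    assert(rms_norm_block_width(16)   == 16);   // sub-warp row: asserted before this fix
    assert(rms_norm_block_width(64)   == 32);   // one warp loops over the columns
    assert(rms_norm_block_width(4096) == 1024); // wide rows use the 1024-thread path
    return 0;
}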
6 changes: 3 additions & 3 deletions src/llama.cpp
@@ -9119,9 +9119,9 @@ static struct ggml_tensor * llm_build_mamba(
 
     // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
     if (ssm_dt_b_c_rms) {
-        dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-        B = ggml_rms_norm(ctx, B, norm_rms_eps);
-        C = ggml_rms_norm(ctx, C, norm_rms_eps);
+        dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
+        B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
+        C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
     }
 
     // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
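In llm_build_mamba, dt, B and C are sliced out of a fused projection via ggml views, so they are generally not contiguous in memory, while the CUDA RMS-norm path expects a contiguous input; ggml_cont materializes each view as a packed copy before normalizing. A minimal sketch of the pattern, assuming a hypothetical fused tensor, slice width, and offset (not the actual variables in llama.cpp):

#include "ggml.h"

// Hypothetical helper: RMS-normalize an ne0-wide slice of a fused tensor.
static struct ggml_tensor * rms_norm_slice(struct ggml_context * ctx,
                                           struct ggml_tensor  * fused,
                                           int64_t ne0, size_t offset, float eps) {
    // A strided view into the fused tensor: its rows are not packed,
    // so the result is non-contiguous.
    struct ggml_tensor * v = ggml_view_2d(ctx, fused, ne0, fused->ne[1],
                                          fused->nb[1], offset);
    // ggml_cont copies the view into a contiguous tensor, which the
    // CUDA RMS-norm kernel can then consume without the contiguity
    // assertion firing.
    return ggml_rms_norm(ctx, ggml_cont(ctx, v), eps);
}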
