[skip ci] Update llms.md (#16745)
More edits to LLMs.md

### Ticket
Link to GitHub Issue

### Problem description
Provide context for the problem.

### What's changed
Describe the approach used to solve the problem.
Summarize the changes made and their impact.

### Checklist
- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
bbeggsTT authored Jan 15, 2025
1 parent fce5276 commit 1600fd6
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions tech_reports/LLMs/llms.md
[…]

Attention in TT-NN is implemented with custom TT-NN kernels. In PyTorch, the attention OP is usually implemented in the following six steps:

1. QKV projection matmuls.
2. Reshape Q, K, V to match the expected input shape for the attention OP.
3. Apply RoPE to Q and K.
4. Cache K and V.
5. Scaled dot product attention.
6. Output reshape and output matmul.
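Before looking at the TT-NN version, here is a minimal PyTorch-style sketch of a single decode step following these six steps. The names, shapes, fused QKV weight, and single shared `cur_pos` are illustrative simplifications (no grouped-query attention, no paging); this is not the Llama implementation shown below.

```python
import torch
import torch.nn.functional as F

def attention_decode_step(x, wqkv, wo, rope_cos, rope_sin, k_cache, v_cache, cur_pos, n_heads, head_dim):
    # x: (bsz, 1, n_heads * head_dim), one new token per user in decode mode
    bsz = x.shape[0]

    # 1. QKV projection matmul (a single fused weight for brevity)
    qkv = x @ wqkv  # (bsz, 1, 3 * n_heads * head_dim)

    # 2. Reshape Q, K, V to (bsz, n_heads, 1, head_dim)
    q, k, v = (t.reshape(bsz, 1, n_heads, head_dim).transpose(1, 2) for t in qkv.chunk(3, dim=-1))

    # 3. Apply RoPE to Q and K (rotate-half formulation)
    def rope(t):
        t1, t2 = t[..., : head_dim // 2], t[..., head_dim // 2 :]
        return t * rope_cos[cur_pos] + torch.cat((-t2, t1), dim=-1) * rope_sin[cur_pos]
    q, k = rope(q), rope(k)

    # 4. Cache K and V at the current position
    k_cache[:, :, cur_pos] = k.squeeze(2)
    v_cache[:, :, cur_pos] = v.squeeze(2)

    # 5. Scaled dot product attention over the cached sequence so far
    attn = F.scaled_dot_product_attention(q, k_cache[:, :, : cur_pos + 1], v_cache[:, :, : cur_pos + 1])

    # 6. Reshape heads back and apply the output matmul
    return attn.transpose(1, 2).reshape(bsz, 1, n_heads * head_dim) @ wo
```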

For example, the Llama model is implemented as follows:
```python
    # ... (Llama attention implementation collapsed in this diff view)
)
)
```
- **Input/Output Shapes**: The output is height sharded across the batch dimension on `bsz` cores.
```python
(1, 1, bsz, (n_q_heads+2*n_kv_heads)*head_dim) -> (1, bsz, n_q_heads, head_dim), (1, bsz, n_kv_heads, head_dim), (1, bsz, n_kv_heads, head_dim)
```
3. Apply RoPE to Q and K
- Again, apply the RoPE transformation to Q and K using the rotary embedding op outlined in [2.2 RoPE](#22-rope). The input/output shapes remain the same as in step 2.

4. Cache K and V
- Populate the KV cache at `cur_pos` for all batches with the current K and V tensors using the `ttnn.experimental.paged_update_cache` OP. This OP takes an optional `page_table` argument to support paged KV cache updates. Example:
```python
ttnn.experimental.paged_update_cache(keys, K, update_idxs=cur_pos, page_table=page_table)
ttnn.experimental.paged_update_cache(values, V, update_idxs=cur_pos, page_table=page_table)
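# For intuition only (hypothetical layout, not taken from the Llama code): with paged
# attention the KV cache is typically stored as [max_num_blocks, n_kv_heads, block_size, head_dim],
# and `page_table` is an integer tensor of shape [bsz, max_num_blocks_per_seq] that maps each
# user's logical block index to a physical block, so the update at `update_idxs=cur_pos`
# lands in the correct physical block and offset for every user.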
```

[…]

```py
# ... (MLP reference pseudocode collapsed in this diff view)
y = FF2(w2_in)
```
FF2 is a row-parallel matmul, meaning that the weights are fractured across devices in the inner dim. The inputs of FF2, produced by FF1/FF3, are fractured across devices in the same dimension; as a result, FF2 produces partial outputs across all devices.
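To see why a row-parallel matmul yields partial outputs, here is a small self-contained PyTorch sketch with two simulated devices (illustrative only, plain tensors rather than TT-NN):

```python
import torch

# Fracture the inner (K) dim across two simulated devices: each device multiplies its
# activation shard by the matching weight shard and produces a full-shaped but partial
# output. Summing the partials (an all-reduce) recovers the un-fractured matmul result.
M, K, N = 32, 128, 64
x = torch.randn(M, K)     # FF2 input
w2 = torch.randn(K, N)    # FF2 weight

x_shards = torch.chunk(x, 2, dim=1)
w_shards = torch.chunk(w2, 2, dim=0)
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

assert torch.allclose(sum(partials), x @ w2, atol=1e-3)
```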

Here's what the call for the FF2 matmul looks like.
> [!NOTE]
> Once the matmul operations are complete, we can undo the reshape operation we performed on the MLP inputs to fit the matmuls on device in `prefill`.

```py
w2_out = ttnn.linear(
w2_in,
    # ... (remaining matmul arguments collapsed in this diff view)
)
if seq_len >= 1024:  # Reshape back to intended shape
    w2_out = ttnn.reshape(w2_out, [1, 1, seq_len, -1])
```

###### 5.1 Accumulating the partial outputs of FF2

Since the output of FF2 has the correct shape but is only a partial sum on each device, the output of the MLP module must be fractured so that each device has fully accumulated the inner dim of the matmul but holds only a fraction of the outer dim. There are two ways to handle this, depending on whether the WH system has a 1D or 2D device mesh.
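In plain PyTorch terms, the two collectives used below differ as follows (an illustrative sketch with simulated devices, not the TT-NN CCL implementation):

```python
import torch

# Four simulated devices each hold a partial [M, N] FF2 output.
M, N, n_devices = 8, 16, 4
partials = [torch.randn(M, N) for _ in range(n_devices)]

# All-reduce: every device ends up with the fully accumulated [M, N] tensor.
all_reduced = sum(partials)

# Reduce-scatter: the accumulated result is split along the outer dim, so each
# device keeps only its [M, N // n_devices] shard, i.e. fractured outputs.
reduce_scattered = torch.chunk(sum(partials), n_devices, dim=1)

assert all_reduced.shape == (M, N)
assert reduce_scattered[0].shape == (M, N // n_devices)
```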

1. 1D Device Mesh (n300, T3000): Use a reduce-scatter operation across all devices, resulting in outputs fractured in the outer dim.
```py
w2_out_reduced = ttnn.reduce_scatter(
w2_out,
    # ... (remaining reduce_scatter arguments collapsed in this diff view)
memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG,
)
```
2. 2D Device Mesh (TG): Use an all-reduce operation along the cluster axis on which the inner dimension is fractured. The FF2 matmul inner dim is fractured across cluster axis 0 (row-parallel across 8 devices), and the outer dim is fractured across cluster axis 1 (4 devices). An all-reduce performed on cluster axis 0 accumulates the partials across the inner dim of the matmul and replicates them along all devices in that axis, while keeping them fractured across cluster axis 1 (4 devices).
```py
w2_out_reduced = tt_all_reduce(
w2_out,
    # ... (remaining tt_all_reduce arguments collapsed in this diff view)
)
```

[…]
<div align="center">
<img src="images/2.6-decoder.png" alt="Decoder Diagram" title="Decoder Title" width="350" height="400">
</div>
When the components explained in previous sections (MLP, Attention, RMSNorm) are implemented, bringing up the decoder is relatively straightforward. According to the diagram (based on the Llama3.1 example), the components are stacked sequentially during the forward pass. The only consideration is whether the addition of the MLP and Attention outputs should be stored in L1 or in DRAM.
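Schematically, the forward pass implied by the diagram looks like the sketch below (plain Python with hypothetical module callables, not the TT-NN implementation shown further down):

```python
def decoder_forward(x, attn_norm, attention, ffn_norm, mlp, cur_pos, rot_mats):
    # Attention block with residual connection
    attn_out = attention(attn_norm(x), cur_pos, rot_mats)
    h = x + attn_out    # this addition can be kept in L1 or moved to DRAM
    # MLP block with residual connection
    ffn_out = mlp(ffn_norm(h))
    return h + ffn_out  # same L1 vs. DRAM consideration for the second addition
```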

<br>

The Decode forward pass implementation below follows the diagram above. To optimize memory usage, we recommend deallocating tensors after use, which is crucial under tighter memory constraints.
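For example, intermediates can be freed as soon as they are consumed (a sketch assuming the standard `ttnn.deallocate` and `ttnn.add` calls; the module handles are hypothetical):

```python
x_norm = attention_norm(x)
attn_out = attention(x_norm, cur_pos, rot_mats)
ttnn.deallocate(x_norm)    # free L1 as soon as the tensor is no longer needed
h = ttnn.add(x, attn_out)
ttnn.deallocate(attn_out)
```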
<br>

To optimize performance in decode mode, we maintain the residual stream in L1 and shard it across cores and devices. However, determining the optimal number of cores for sharding can be challenging, especially for operations like DRAM-sharded matmuls. Here is the [code](https://github.com/tenstorrent/tt-metal/blob/53c32c0c0da926f97bd0eb042e70fd54c2866f44/models/demos/llama3/tt/model_config.py#L931) in the Llama model config that produces a core grid dividing the N and K dims of a matmul evenly.
When it is not feasible to keep the streams sharded, we use the TT-NN OP `interleaved_to_sharded`, and conversely switch back as needed.
In our implementation of Llama3.1, there are some OPs that require interleaved tensors and resharding.
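As a rough illustration of that core-grid selection, here is a simplified standalone sketch (not the logic in the linked `model_config.py`; `find_even_core_grid` is a hypothetical helper):

```python
def find_even_core_grid(k_dim: int, n_dim: int, tile_size: int = 32, max_cores: int = 64) -> int:
    """Pick the largest core count that divides both K and N (counted in 32x32 tiles) evenly."""
    k_tiles, n_tiles = k_dim // tile_size, n_dim // tile_size
    for cores in range(max_cores, 0, -1):
        if k_tiles % cores == 0 and n_tiles % cores == 0:
            return cores
    return 1

# Example: a 4096 x 14336 matmul has 128 x 448 tiles; 64 cores divide both evenly.
print(find_even_core_grid(4096, 14336))  # 64
```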

<br>

[…]

```py
# ... (LMHead setup collapsed in this diff view)
for i, split_size in enumerate(split_sizes):
    # ... (matmul call for this split collapsed)
)
```

We use DRAM-sharded matmuls for the LMHead, with `program_config` and `memory_config` generated by the code below.
For more information, see [Section: Op Configs](#44-op-configs).
The primary reason for having multiple `program_configs` is that the weight shapes may result in unequal split sizes. This variability means the same configuration cannot be used for every matrix multiplication.
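As a toy illustration of why the splits can end up unequal (the vocab size and maximum split width below are hypothetical):

```python
vocab_size, max_split = 128256, 50304  # hypothetical values for illustration
split_sizes = [min(max_split, vocab_size - i) for i in range(0, vocab_size, max_split)]
print(split_sizes)  # [50304, 50304, 27648]; the smaller last split needs its own program_config
```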

[…]
