
Llama3 model family - list of required ops for blackhole #16013

Open
mtairum opened this issue Dec 13, 2024 · 30 comments


mtairum commented Dec 13, 2024

This issue lists the ops required for the Llama3-8B model (and the rest of the Llama3 model family).

Looking at the current list of supported Blackhole ops, the following seem to be the ops we'll require to properly support the Llama3 family on Blackhole:

Prefill only ops:

  • ttnn.transformer.scaled_dot_product_attention (if not using chunks)
  • ttnn.transformer.chunked_scaled_dot_product_attention (if using chunks; not in the traces. This is the same op as the previous one, but with a page table and a chunk start index. A conceptual sketch follows this list.)
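
Concretely, the non-chunked version corresponds to a generic causal SDPA over the prompt; a minimal PyTorch reference (this is not the ttnn kernel, and the shapes below are illustrative rather than taken from the traces):

import torch
import torch.nn.functional as F

# Illustrative prefill shapes: [batch, n_heads, seq_len, head_dim] (not taken from the traces).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Causal self-attention over the whole prompt, which is what the prefill SDPA op computes.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])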

Below are the graph trace and the perf trace with extra info on the ops (including memory configs and shapes).

Updated traces [14 Jan 2025]

Please use these new traces for the 1B, 8B, and 70B Llama3 models. These include both prefill and decode and were taken by running the demo.py script with 1L (a single layer) for 10 iterations.

[OLD] Graph Trace

The list of ops was generated with the ttnn graph trace:

import ttnn

ttnn.graph.begin_graph_capture(ttnn.graph.RunMode.NORMAL)

# (...) Llama3-8B model run

captured_graph = ttnn.graph.end_graph_capture()  # End capturing the graph
ttnn.graph.pretty_print(captured_graph)
ttnn.graph.visualize(captured_graph, file_name="graph.svg")

llama8b-1L-op_graph.txt

Ops Perf report

Generated with tracy, the ops perf report includes the memory configs and input shapes of the required ops.
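
For quick inspection, the report can be loaded with pandas; a minimal sketch below (the file name, the "OP CODE" column, and the op names match the snippet and value counts further down in this thread):

import pandas as pd

# Load the tracy ops perf report attached above.
df = pd.read_csv("llama8B-1L-model-ops-perf.csv")

# Ops of particular interest for Blackhole bring-up, named as they appear in the report.
ops_of_interest = [
    "ScaledDotProductAttentionDecode",
    "RotaryEmbeddingLlama",
    "PagedUpdateCacheDeviceOperation",
    "LayerNorm",
]
print(df[df["OP CODE"].isin(ops_of_interest)]["OP CODE"].value_counts())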

Llama3-70B

We'll also want support for the Llama3-70B ops, which are mostly the same but with different input sizes.
In this section I list the new ops separately and provide the ops perf report.

Additional ops:

  • ttnn.all_gather
  • ttnn.reduce_scatter

Ops Perf report

mtairum added the "bug: Something isn't working" label Dec 13, 2024
mtairum self-assigned this Dec 13, 2024
mtairum added the "blackhole" label and removed the "bug: Something isn't working" label Dec 13, 2024
mtairum removed their assignment Dec 13, 2024
mtairum changed the title from "Llama3-8b - blackhole ops" to "Llama3-8b - list of required ops for blackhole" Dec 13, 2024

mtairum commented Dec 13, 2024

@prajaramanTT Are you the right person to tag on this issue?

In the model team we want to understand what's the current op support in blackhole and what's missing for us to support Llama3.

For now, this issue is listing Llama3-8B, which will run on a single device. We want to provide the list of ops + shapes required so those can be added to the ttnn op sweep tests soon.

Let me know of next steps and please tag other relevant people on this 🙇

abhullar-tt added this to the BHLD milestone Dec 13, 2024

mtairum commented Dec 16, 2024

FYI @uaydonat

mtairum changed the title from "Llama3-8b - list of required ops for blackhole" to "Llama3 model family - list of required ops for blackhole" Dec 16, 2024

mtairum commented Dec 16, 2024

Added ops for Llama3-70B as well.
It's mostly the same as 8B, but since it runs exclusively multichip, it also needs the CCL ops: all-gather and reduce-scatter.
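
For reference, the semantics of the two collectives as a small NumPy sketch (just the math they implement, not the ttnn multichip API):

import numpy as np

# One local tensor per device (4 "devices" here, purely illustrative).
per_device = [np.full((2, 4), fill_value=d, dtype=np.float32) for d in range(4)]

# all_gather: every device ends up with the concatenation of all local tensors along a dim.
all_gather_out = np.concatenate(per_device, axis=-1)    # shape (2, 16) on every device

# reduce_scatter: element-wise reduce (sum) across devices, then each device keeps one slice.
reduced = np.sum(np.stack(per_device, axis=0), axis=0)  # shape (2, 4)
reduce_scatter_out = np.split(reduced, 4, axis=-1)      # one (2, 1) slice per device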

ntarafdar commented:

@yugi957 has done great work: most TMs are accounted for and work on BH (embedding, slice, transpose, sharded_to_interleaved). He will test the remaining ones tomorrow (interleaved_to_sharded, concat).

ntarafdar commented:

@yugi957 has confirmed that all the TMs that worked on WH for Llama also work on BH.


mtairum commented Jan 8, 2025

@ntarafdar that's great to hear.

What about the other ops, such as the experimental ones (rotary embedding, paged scaled dot product attention, etc.)? Any chance of adding these to the BH sweeps and testing them there?

These will be crucial for supporting transformer-based LLMs on BH.

bbradelTT commented:

@vsureshTT will look at

  • ttnn.layer_norm
  • ttnn.arg_max

I created #16525 to track that effort.


uaydonat commented Jan 8, 2025

@cmaryanTT mentioned that someone will be assigned to custom ops (e.g. nlp_create_qkv_heads_decode, rotary_embedding_llama, paged_scaled_dot_product_attention_decode, etc.).

cmaryanTT commented:

@ntarafdar will be assigning someone in his group to look at the custom ops.

cmaryanTT commented:

@yugi957 and @amorrisonTT will be looking at the custom transformer ops. ETA Monday.

  • ttnn.experimental.nlp_create_qkv_heads_decode
  • ttnn.experimental.rotary_embedding_llama (a generic RoPE reference sketch follows this list)
  • ttnn.experimental.paged_update_cache
  • ttnn.transformer.paged_scaled_dot_product_attention_decode
  • ttnn.experimental.nlp_concat_heads_decode
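
For reference, a minimal PyTorch sketch of what a generic RoPE computes; this is not the ttnn kernel, and the head_dim, theta, and interleaved-pair convention below are illustrative assumptions (the llama op may use the half-split convention instead):

import torch

def apply_rope(x, theta=10000.0):
    # x: [batch, n_heads, seq_len, head_dim]; rotate interleaved channel pairs by position-dependent angles.
    _, _, s, d = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = torch.outer(torch.arange(s, dtype=torch.float32), inv_freq)  # [seq_len, head_dim // 2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(1, 8, 32, 64)
q_rot = apply_rope(q)  # same shape, [1, 8, 32, 64]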

amorrisonTT commented:

These traces don't appear to contain:

  • ttnn.transformer.paged_scaled_dot_product_attention_decode (although there is scaled_dot_product_attention_decode)
  • ttnn.add
  • ttnn.mul

Is this expected?

cmaryanTT commented:

@amorrisonTT add and mult show up as "BinaryDeviceOperation" with the MULT or ADD type.

cmaryanTT commented:

I think "paged" in the other op is just a typo

amorrisonTT commented:

@amorrisonTT add and mult show up as "BinaryDeviceOperation" with the MULT or ADD type.

Thanks!

bbradelTT commented:

@mtairum the only reference to argmax is on row 1960 of llama-70b-1L-ops-pers.csv; it's for torch argmax and does not specify any inputs.

What do we need to check for ttnn.argmax?

uaydonat commented:

@amorrisonTT I think paged_scaled_dot_product_attention_decode is the right op. It is contained in 8B trace, not sure why it is not in 70B. @mtairum maybe you did not have Salar's vllm changes?

Also, we run argmax on device for batch=1, and on host if batch>1. I am guessing the traces are for batch=32, which is why you only see torch.argmax. It would be good to verify ttnn.argmax since we will ultimately extend it to batch=32 as well, but it is lower priority.


mtairum commented Jan 10, 2025

@amorrisonTT add and mult show up as "BinaryDeviceOperation" with the MULT or ADD type.

This is correct.

the only reference to argmax is on row 1960 of llama-70b-1L-ops-pers.csv and it's for torch argmax and does not specify any inputs.

Re: argmax, it's what Utku mentioned.
We use ttnn.argmax with multicore if batch_size==1 (otherwise it's single core), and the inputs are of shape [1, 1, 32, 128256] for either model.

tt_out_tok = ttnn.argmax(tt_out_rm, dim=3, use_multicore=False if batch_size > 1 else True, output_tensor=tt_out_tok)

And from a new tracy trace I just generated, this is the line for argmax:

ArgMax,tt_dnn_device,18972,1,{'dim': '3'; 'output_dtype': 'DataType::UINT32'; 'output_mem_config': 'MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED;buffer_type=BufferType::L1;shard_spec=std::nullopt)'; 'use_multicore': 'true'},,47,,52725463868,53033342829,307878961,656089241697,656089887765,303135332,646068,644470,,644470,,,,,,,1,1,32,128256,ROW_MAJOR,BFLOAT16,DEV_1_L1_INTERLEAVED,1,1,1,32,ROW_MAJOR,UINT32,DEV_1_DRAM_INTERLEAVED,,,,,,,,,,,,,,,,,,,,,,1,1,1,32,ROW_MAJOR,UINT32,DEV_1_DRAM_INTERLEAVED,,,,,,,,,,,,,,,,,[],[],['ttnn/cpp/ttnn/operations/reduction/argmax/device/kernels/reader_argmax_interleaved_multicore.cpp'],['reader_argmax_interleaved_multicore/7829460339933776415/'],0,1344,0,0,0,0,1,1,1,[8208384.0],[128.0],307448723,3593
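
For completeness, a host-side reference for that call, with shapes taken from the trace line above (torch is only a stand-in here; the device op is ttnn.argmax):

import torch

# Decode logits for 32 users over the 128256-entry vocab, per the trace line above.
logits = torch.randn(1, 1, 32, 128256, dtype=torch.bfloat16)

# Reference for ttnn.argmax(..., dim=3): one token id per user.
# The device op emits [1, 1, 1, 32] UINT32; int32 is used here since torch uint32 support is limited.
ref_tokens = torch.argmax(logits, dim=3).reshape(1, 1, 1, 32).to(torch.int32)
print(ref_tokens.shape)  # torch.Size([1, 1, 1, 32])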

I think paged_scaled_dot_product_attention_decode is the right op. It is contained in 8B trace, not sure why it is not in 70B. @mtairum maybe you did not have Salar's vllm changes?

We care about both:

  • scaled_dot_product_attention_decode
  • paged_scaled_dot_product_attention_decode

It's not clear to me why the 70B trace above is not using the paged implementation; I'm pretty sure I ran the same config on both. In either case we support both versions, and both have the same input shapes.

amorrisonTT commented:

I think paged_scaled_dot_product_attention_decode is the right op. It is contained in 8B trace, not sure why it is not in 70B. @mtairum maybe you did not have Salar's vllm changes?

We care about both:

  • scaled_dot_product_attention_decode
  • paged_scaled_dot_product_attention_decode

It's not clear to me why the 70B trace above is not using the paged implementation; I'm pretty sure I ran the same config on both. In either case we support both versions, and both have the same input shapes.

I didn't see paged_scaled_dot_product_attention_decode in either attached trace:

import pandas as pd

PERF_FILES = ["llama8B-1L-model-ops-perf.csv", "llama-70b-1L-ops-pers.csv"]
df = pd.concat([pd.read_csv(f) for f in PERF_FILES])
df["OP CODE"].value_counts()

OP CODE
Matmul                                 249
InterleavedToShardedDeviceOperation    147
LayerNorm                              147
BinaryDeviceOperation                  147
ReshardDeviceOperation                 145
AllGather                              144
ShardedToInterleavedDeviceOperation    101
Embeddings                              98
Transpose                               98
SliceDeviceOperation                    98
RotaryEmbeddingLlama                    98
PagedUpdateCacheDeviceOperation         98
(torch) __getitem__                     55
NLPCreateHeadsDecodeDeviceOperation     49
ScaledDotProductAttentionDecode         49
NLPConcatHeadsDecodeDeviceOperation     49
ReduceScatter                           48
AllGatherMatmul                         48
(torch) cat                             36
(torch) reshape                         24
(torch) abs                             18
(torch) transpose                       12
(torch) item                            12
(torch) max                             12
(torch) sub                             12
(torch) __get__                          7
(torch) squeeze                          7
(torch) div                              6
(torch) allclose                         6
(torch) permute                          6
ConcatDeviceOperation                    1
(torch) argmax                           1
(torch) embedding                        1
(torch) tolist                           1
Name: count, dtype: int64

uaydonat commented:

Hmm, for 8B, the op graph has the paged_scaled_dot_product_attention_decode but not the trace.

It might be some wrong naming, because ScaledDotProductAttentionDecode has paged_attention: true in its arguments for both 8B and 70B.


mtairum commented Jan 13, 2025

Good point. @uaydonat is right.

I double-checked the op kernel and the paged version of the op is indeed set by an argument:
it executes ScaledDotProductAttentionDecode with paged_attention=True.
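
For whoever picks this up, a rough PyTorch sketch of what the paged decode attention does conceptually: gather the user's KV blocks through a page table, then run single-token attention. Block size, head counts, and cache layout below are illustrative assumptions, not the ttnn kernel's actual layout:

import torch
import torch.nn.functional as F

n_heads, head_dim, block_size, n_blocks = 8, 64, 32, 16

# Paged KV cache: [n_blocks, n_heads, block_size, head_dim], plus a page table mapping
# this user's logical blocks to physical blocks (illustrative layout).
k_cache = torch.randn(n_blocks, n_heads, block_size, head_dim)
v_cache = torch.randn(n_blocks, n_heads, block_size, head_dim)
page_table = torch.tensor([3, 7, 1])  # this user's sequence occupies 3 physical blocks
seq_len = 70                          # current length of this user's sequence

# Gather the user's KV from the paged cache and trim to the real sequence length.
k = k_cache[page_table].permute(1, 0, 2, 3).reshape(n_heads, -1, head_dim)[:, :seq_len]
v = v_cache[page_table].permute(1, 0, 2, 3).reshape(n_heads, -1, head_dim)[:, :seq_len]

# Decode: a single query token attends over the gathered cache (no causal mask needed).
q = torch.randn(n_heads, 1, head_dim)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([8, 1, 64])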


amorrisonTT commented Jan 13, 2025

ttnn.experimental.nlp_create_qkv_heads_decode is failing with:

PCC value: Max ATOL Delta: 0.9921875, Max RTOL Delta: 79360.0, PCC: 0.49738503615123153, PCC check failed
2025-01-10 23:28:15.657 | INFO     | tests.tt_eager.python_api_testing.unit_testing.misc.test_nlp_create_qkv_heads_decode:run_test_create_head_interleaved:66 - PCC value: Max ATOL Delta: 0.98828125, Max RTOL Delta: 18944.0, PCC: 0.5012214626471699, PCC check failed
2025-01-10 23:28:15.660 | INFO     | tests.tt_eager.python_api_testing.unit_testing.misc.test_nlp_create_qkv_heads_decode:run_test_create_head_interleaved:70 - PCC value: Max ATOL Delta: 0.9921875, Max RTOL Delta: 7264.0, PCC: 0.4905412229394511, PCC check failed

See #16667


amorrisonTT commented Jan 13, 2025

ttnn.experimental.paged_update_cache consistently causes the machine to hang. See #16674

cmaryanTT commented:

Per @eyonland, ADD works; MULT has a non-determinism issue (#16662).

cmaryanTT commented:

@amorrisonTT can you please open issues for the problems you found?


ntarafdar commented Jan 13, 2025

@amorrisonTT when you create the update_cache and create_qkv_heads_decode issues, please assign them to @cglagovich.


amorrisonTT commented Jan 13, 2025

ttnn.transformer.scaled_dot_product_attention_decode (with and without the paged_attention flag) is failing, see #16673.


vsureshTT commented Jan 13, 2025

@mtairum
The shapes required for layer norm are:

  • input 0: [1,1,32,8192], input 1: [1,1,256,32]
  • input 0: [1,1,32,4096], input 1: [1,1,128,32]

I also linked a truncated version of the spreadsheet below with the layer norm rows isolated. (A PyTorch reference for these shapes is sketched after the attachment.)

LayerNorms1.csv
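
A quick functional reference for those shapes, assuming input 0 is the activation and input 1 is the 8192-element gamma packed as [1,1,256,32] (that packing, and the flattening order below, are my reading of the shapes rather than something confirmed in this thread):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 32, 8192)            # input 0: activations
gamma_packed = torch.randn(1, 1, 256, 32)  # input 1: weight, 256 * 32 == 8192 elements

# Reference layer norm over the hidden dim; the packed weight is flattened back to [8192]
# (row-major flattening is an assumption about the packing).
ref = F.layer_norm(x, normalized_shape=(8192,), weight=gamma_packed.reshape(8192))
print(ref.shape)  # torch.Size([1, 1, 32, 8192])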


mtairum commented Jan 14, 2025

I've added new traces and two new prefill only ops:

  • ttnn.transformer.scaled_dot_product_attention (if not using chunks)
  • ttnn.transformer.chunked_scaled_dot_product_attention (if using chunks; not in the traces. This is the same op as the previous one, but with a page table and a chunk start index)

The chunked version was not used in the trace, but it's basically a variation of the main op. This op is separate from the scaled dot product decode that's already being tested. A rough sketch of what the chunked variant computes is below.
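
Roughly, the chunked variant computes the same causal attention, but only for the queries of one prompt chunk starting at a chunk start index, attending to all keys up to each query's absolute position (the page table and the exact argument names are omitted here; shapes are illustrative):

import torch
import torch.nn.functional as F

n_heads, head_dim, seq_len, chunk, chunk_start = 8, 64, 1024, 256, 512

k = torch.randn(1, n_heads, seq_len, head_dim)  # full prompt K/V (fetched via the page table in the real op)
v = torch.randn(1, n_heads, seq_len, head_dim)
q = torch.randn(1, n_heads, chunk, head_dim)    # queries for one prefill chunk only

# A query at absolute position chunk_start + i may attend to keys 0 .. chunk_start + i.
q_pos = chunk_start + torch.arange(chunk)
k_pos = torch.arange(seq_len)
mask = k_pos[None, :] <= q_pos[:, None]         # [chunk, seq_len] boolean causal mask

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 8, 256, 64])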

cglagovichTT commented:

^ Note that the traces above also include these prefill-specific ops (a generic sketch of the two head ops follows the list):

  • ttnn.experimental.nlp_create_qkv_heads
  • ttnn.experimental.paged_fill_cache
  • ttnn.experimental.nlp_concat_heads
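
For context, the two head ops are essentially layout transforms; a generic sketch with illustrative sizes and a simple fused-QKV layout assumption (not the ttnn kernels themselves):

import torch

batch, seq, n_q_heads, n_kv_heads, head_dim = 1, 128, 32, 8, 128

# A fused QKV projection output: Q heads followed by K and V heads along the last dim (layout assumed).
fused = torch.randn(batch, seq, (n_q_heads + 2 * n_kv_heads) * head_dim)

# nlp_create_qkv_heads (conceptually): split the fused tensor and lay it out per head.
q, k, v = fused.split([n_q_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim], dim=-1)
q = q.reshape(batch, seq, n_q_heads, head_dim).transpose(1, 2)   # [batch, n_q_heads, seq, head_dim]
k = k.reshape(batch, seq, n_kv_heads, head_dim).transpose(1, 2)
v = v.reshape(batch, seq, n_kv_heads, head_dim).transpose(1, 2)

# nlp_concat_heads (conceptually): undo the head split after attention.
attn_out = torch.randn(batch, n_q_heads, seq, head_dim)
concat = attn_out.transpose(1, 2).reshape(batch, seq, n_q_heads * head_dim)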

bbradelTT commented:

@vsureshTT ran the tests. The test scenarios failed.

I'll update the description with issue numbers.

I tried on WH and it seems that the behaviour is the same as on BH, which means that this may not block the model.
