Skip to content

Commit

Permalink
#16013: clean OG embedding sweep and create BH embedding sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
yugi957 committed Dec 19, 2024
1 parent fa779b9 commit 7ed63fd
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch
import random
import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# NOTE(review): TIMEOUT is not referenced anywhere in the visible code; presumably the
# sweep framework reads it as a per-test time limit — confirm the consumer and the unit.
TIMEOUT = 10
# seed for random
# This seeds only Python's built-in `random` module. NOTE(review): run() draws its
# indices with torch.randint, which uses torch's own generator — confirm whether a
# torch seed is set elsewhere if reproducible inputs are required.
random.seed(0)

# Sweep parameters, keyed by suite name. Each entry of "embedding_specs" is a dict that
# run() reads: "weight_shape" and "indices_shape" are required, "padding_idx" is optional.
parameters = {
"nightly": {
"embedding_specs": [
# weight_shape[0] is also used by run() as the exclusive upper bound for the random indices.
{"weight_shape": [256, 128], "indices_shape": [1, 32]},
],
}
}


# Invalidate vector is called during the generation phase where each vector will be passed in.
# If invalidated, the vector will still be stored but will be skipped.
# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid.
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    """Decide whether a generated test vector must be skipped.

    Args:
        test_vector: Mapping of parameter name to value for a single vector.

    Returns:
        ``(False, None)`` when the vector is valid, otherwise ``(True, reason)``
        where ``reason`` is a human-readable explanation.
    """
    # The suites defined in this file only provide "embedding_specs", so "layout"
    # and "dtype" may be absent. Indexing them directly would raise KeyError;
    # a vector that does not carry both keys simply cannot violate the rule below.
    if "layout" not in test_vector or "dtype" not in test_vector:
        return False, None

    is_row_major = test_vector["layout"] == ttnn.ROW_MAJOR_LAYOUT
    if is_row_major and test_vector["dtype"] == ttnn.bfloat8_b:
        return True, "bfloat8_b not supported with ROW_MAJOR_LAYOUT"

    return False, None


def run(
    embedding_specs,
    *,
    device,
):
    """Execute one embedding sweep case on ``device`` and check it against PyTorch.

    Args:
        embedding_specs: Dict with "weight_shape" and "indices_shape", and
            optionally "padding_idx".
        device: Device handle the ttnn tensors are placed on.

    Returns:
        A two-element list ``[result, e2e_perf]``: the value returned by
        ``check_with_pcc`` and the value returned by ``stop_measuring_time``.
    """
    device.enable_async(False)

    # Unpack the spec; a missing "padding_idx" means no padding index.
    table_shape = embedding_specs["weight_shape"]
    lookup_shape = embedding_specs["indices_shape"]
    padding_idx = embedding_specs.get("padding_idx")

    # Random inputs: weights are drawn first, then indices bounded by the number of rows.
    torch_weight = torch_random(table_shape, -0.1, 0.1, dtype=torch.bfloat16)
    torch_indices = torch.randint(0, table_shape[0], lookup_shape, dtype=torch.int32)

    # Reference result from a PyTorch embedding layer built on the same weights.
    reference_layer = torch.nn.Embedding.from_pretrained(torch_weight, padding_idx=padding_idx)
    expected = reference_layer(torch_indices)

    # Move both inputs to the device in row-major layout.
    device_weight = ttnn.from_torch(torch_weight, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16)
    device_indices = ttnn.from_torch(torch_indices, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32)

    # Only the ttnn.embedding call itself sits inside the timed region.
    timer = start_measuring_time()
    device_output = ttnn.embedding(
        device_indices,
        device_weight,
        padding_idx=padding_idx,
        layout=ttnn.TILE_LAYOUT,
        embeddings_type=ttnn.EmbeddingsType.GENERIC,
        dtype=ttnn.bfloat16,
        output_tensor=None,  # nothing preallocated
        memory_config=None,
        queue_id=0,
    )
    e2e_perf = stop_measuring_time(timer)

    # Bring the device result back to PyTorch and compare with a 0.999 PCC threshold.
    actual = ttnn.to_torch(device_output)
    return [check_with_pcc(expected, actual, 0.999), e2e_perf]
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,6 @@
# seed for random
random.seed(0)


def extract_brackets_content(line):
    """Collect the text found between square brackets in ``line``.

    Every "[" starts a fresh capture and every "]" that has a matching open
    bracket emits the stripped text captured so far. A "]" with no open
    bracket is ignored, and text after an unclosed "[" is never emitted.

    Args:
        line: The string to scan.

    Returns:
        A list of stripped strings, one per matched closing bracket, in the
        order the closing brackets appear.
    """
    brackets_content = []
    depth = 0
    # Characters are gathered in a list and joined once per bracket group,
    # instead of growing a string one character at a time.
    captured = []

    for char in line:
        if char == "[":
            depth += 1
            # A new opening bracket always restarts the capture.
            captured = []
        elif char == "]":
            if depth > 0:
                brackets_content.append("".join(captured).strip())
                depth -= 1
                # The capture is intentionally not cleared here: for nested
                # input the outer group continues from the inner text.
        elif depth > 0:
            captured.append(char)

    return brackets_content


def parse_md_file_simple_no_regex(file_path):
    """Build a list of view specs from the bracketed lists in a text file.

    Each line is scanned with ``extract_brackets_content``. Lines with fewer
    than three bracket groups are ignored. The first group is taken as the
    shape and the third group as the size; the second group is not used.
    Lines where either group contains the letter "s" are skipped
    (presumably symbolic dimensions — confirm against the input format).

    Args:
        file_path: Path of the file to read.

    Returns:
        A list of ``{"shape": [...], "size": [...]}`` dicts with integer
        entries, one per accepted line, in file order.

    Raises:
        ValueError: If an accepted group holds something that is not a
            comma-separated list of integers (including an empty group).
    """
    view_specs = []
    # Explicit encoding so parsing does not depend on the platform default.
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            groups = extract_brackets_content(line)

            # Index 2 is read below, so at least three groups are required.
            if len(groups) < 3:
                continue

            shape_str = groups[0]
            size_str = groups[2]
            if "s" in shape_str or "s" in size_str:
                continue

            view_specs.append(
                {
                    "shape": [int(dim) for dim in shape_str.split(",")],
                    "size": [int(dim) for dim in size_str.split(",")],
                }
            )

    return view_specs


parameters = {
"nightly": {
"embedding_specs": [
Expand Down

0 comments on commit 7ed63fd

Please sign in to comment.