Metadata agnostic user computation hash #8557

Merged
20 changes: 7 additions & 13 deletions test/neuron/run_tests.sh
@@ -6,6 +6,9 @@ MAX_GRAPH_SIZE=500
GRAPH_CHECK_FREQUENCY=100
VERBOSITY=2

# Utils file
source "${CDIR}/utils/run_tests_utils.sh"

# Note [Keep Going]
#
# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
@@ -93,16 +96,6 @@ function run_eager_debug {
  XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
}

function run_save_tensor_ir {
  echo "Running in save tensor file mode: $@"
  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@"
}

function run_save_tensor_hlo {
  echo "Running in save tensor file mode: $@"
  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
}

function run_pt_xla_debug {
echo "Running in save tensor file mode: $@"
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
@@ -166,16 +159,16 @@ function run_xla_op_tests1 {
  run_test "$CDIR/dynamo/test_num_output.py"
  run_test "$CDIR/dynamo/test_graph_input_matcher.py"
  run_test "$CDIR/dynamo/test_dynamo_config.py"
  run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
  run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
  #run_test "$CDIR/test_data_type.py"
  run_use_bf16 "$CDIR/test_data_type.py"
  run_downcast_bf16 "$CDIR/test_data_type.py"
  #run_test "$CDIR/test_fp8.py"
  run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
  run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
  run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
  run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
}

function run_xla_op_tests2 {
@@ -230,6 +223,7 @@ function run_xla_op_tests3 {
  run_torchrun "$CDIR/pjrt/test_torchrun.py"
  run_test "$CDIR/test_persistent_cache.py"
  run_test "$CDIR/test_devices.py"
  run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"

  #python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error
  #python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
19 changes: 5 additions & 14 deletions test/run_tests.sh
@@ -85,11 +85,6 @@ function run_test_without_functionalization {
  XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
}

function run_xla_ir_debug {
echo "Running with XLA_IR_DEBUG: $@"
XLA_IR_DEBUG=1 run_test "$@"
}

function run_use_bf16 {
echo "Running with XLA_USE_BF16: $@"
XLA_USE_BF16=1 run_test "$@"
@@ -100,11 +95,6 @@ function run_downcast_bf16 {
  XLA_DOWNCAST_BF16=1 run_test "$@"
}

function run_xla_hlo_debug {
echo "Running with XLA_IR_DEBUG: $@"
XLA_HLO_DEBUG=1 run_test "$@"
}

function run_dynamic {
echo "Running in DynamicShape mode: $@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
@@ -191,9 +181,9 @@ function run_xla_op_tests1 {
  run_use_bf16 "$CDIR/test_data_type.py"
  run_downcast_bf16 "$CDIR/test_data_type.py"
  run_test "$CDIR/test_fp8.py"
  run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
  run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"
  run_xla_hlo_debug run_test "$CDIR/stablehlo/test_stablehlo_save_load.py"
  run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
}
@@ -224,7 +214,7 @@ function run_xla_op_tests3 {
  run_test "$CDIR/stablehlo/test_composite.py"
  run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
  run_test "$CDIR/stablehlo/test_stablehlo_custom_call.py"
  run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
  run_xla_hlo_debug run_test "$CDIR/stablehlo/test_stablehlo_inference.py"
  run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
  run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
  run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
@@ -252,6 +242,7 @@ function run_xla_op_tests3 {
  # NOTE: this line below is testing export and don't care about GPU
  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
  run_test "$CDIR/test_pallas.py"
  run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"

  # CUDA tests
  if [ -x "$(command -v nvidia-smi)" ]; then
70 changes: 70 additions & 0 deletions test/test_user_computation_debug_cache.py
@@ -0,0 +1,70 @@
import os
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)


class TestUserComputationDebugCache(unittest.TestCase):

  def setUp(self):
    self.assertTrue(
        os.getenv("XLA_IR_DEBUG") == '1' and os.getenv("XLA_HLO_DEBUG") == '1',
        "XLA_IR_DEBUG and XLA_HLO_DEBUG must be set for this test.",
    )

  def test_user_computation_debug_cache(self):
    """
    Test that user computations with the same IR but different OpMetadata hit
    the compilation cache. The metadata is generated when the environment
    variables `XLA_IR_DEBUG` and `XLA_HLO_DEBUG` are set; they enable the
    Python stack trace for the IR nodes and, subsequently, the XLA HLO
    metadata.
    """

    met.clear_all()

    def fn_op(a, b):
      return xb.Op.tuple([xb.Op.max(a, b) - xb.Op.min(a, b)])

    def input_scope_0(tensor):
      return [torch.sin(tensor), torch.cos(tensor)]

    def input_scope_1(tensor):
      return [torch.sin(tensor), torch.cos(tensor)]

    device = xm.xla_device()
    init_tensor = torch.tensor(10).to(device)

    def create_user_computation(fn):
      inputs = fn(init_tensor)
      comp = xb.create_computation("computation", fn_op,
                                   [xb.tensor_shape(p) for p in inputs])
      _ = torch_xla._XLAC._xla_user_computation("xla::computation", inputs,
                                                comp)
      torch_xla.sync()

    # Create and launch the graph execution with the same IR graph, but with
    # different input tensor scopes. When `XLA_HLO_DEBUG` and `XLA_IR_DEBUG`
    # are enabled, the different input scopes `input_scope_0` and
    # `input_scope_1` generate different OpMetadata, namely `source_line`.
    create_user_computation(input_scope_0)
    create_user_computation(input_scope_1)

    # Ensure that we only compile once and hit the cache the second time.
    # This is expected because the OpMetadata does not affect the hash of the
    # user computation; the compiled executable is semantically the same.
    self.assertEqual(met.counter_value("UncachedCompile"), 1)
    self.assertEqual(met.counter_value("CachedCompile"), 1)


if __name__ == "__main__":
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
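
For reference, a minimal way to run the new test locally (assuming the repository root as the working directory) is to set both debug flags explicitly, which is what the run_xla_ir_hlo_debug wrapper added in test/utils/run_tests_utils.sh does under the hood:

# Both flags are required; setUp() asserts on them before the test body runs.
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 test/test_user_computation_debug_cache.py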
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -44,6 +44,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_back
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"

# run examples, each test should takes <2 minutes
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
21 changes: 21 additions & 0 deletions test/utils/run_tests_utils.sh
@@ -54,3 +54,24 @@ function run_save_tensor_hlo {
  echo "Running in save tensor file mode: $@"
  run_save_tensor "$run_test_func" "hlo" "$@"
}

function run_xla_ir_debug {
  local run_test_func="$1"
  shift
  echo "Running with XLA_IR_DEBUG: $@"
  XLA_IR_DEBUG=1 "$run_test_func" "$@"
}

function run_xla_hlo_debug {
  local run_test_func="$1"
  shift
  echo "Running with XLA_HLO_DEBUG: $@"
  XLA_HLO_DEBUG=1 "$run_test_func" "$@"
}

function run_xla_ir_hlo_debug {
  local run_test_func="$1"
  shift
  echo "Running with XLA_IR_DEBUG and XLA_HLO_DEBUG: $@"
  XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 "$run_test_func" "$@"
}
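
Each of these helpers takes the runner function as its first argument and forwards the remaining arguments to it, so the same wrapper composes with different runners. Both invocations below are taken directly from this change:

# In test/run_tests.sh and test/neuron/run_tests.sh, via the shared run_test helper:
run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"

# In test/tpu/run_tests.sh, with plain python3 as the runner:
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"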
7 changes: 6 additions & 1 deletion torch_xla/csrc/runtime/computation_client.cc
@@ -196,8 +196,13 @@ metrics::Metric* ComputationClient::OutboundDataMetric() {
}

::absl::StatusOr<torch::lazy::hash_t>
ComputationClient::Computation::ComputeHash(const xla::HloModuleProto& proto,
ComputationClient::Computation::ComputeHash(xla::HloModuleProto proto,
const std::string& name) {
for (auto& computation : *proto.mutable_computations()) {
for (auto& instruction : *computation.mutable_instructions()) {
instruction.mutable_metadata()->Clear();
}
}
TF_ASSIGN_OR_RETURN(auto serialized_status,
util::GetDeterministicSerializedModuleProto(proto));
return torch::lazy::MHash(name, serialized_status);
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
@@ -212,7 +212,7 @@ class ComputationClient {
// elements during serialization. The resulting hash combines the
// serialized module with its computation name.
static ::absl::StatusOr<torch::lazy::hash_t> ComputeHash(
const xla::HloModuleProto& proto, const std::string& name);
xla::HloModuleProto proto, const std::string& name);
};

using ComputationPtr = std::shared_ptr<Computation>;