From cbe1f20d68271c11a2c70e11c4aa246f24b26441 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 13:57:26 -0500 Subject: [PATCH 1/6] allow device to configure conversion to numpy --- pennylane/devices/default_qubit.py | 6 +++++- pennylane/devices/execution_config.py | 13 ++++++++++--- pennylane/workflow/_setup_transform_program.py | 3 ++- pennylane/workflow/run.py | 2 +- tests/param_shift_dev.py | 1 + 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index fcdb25d2783..f1ac48d6461 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -581,6 +581,11 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio """ updated_values = {} + updated_values["convert_to_numpy"] = ( + execution_config.interface.value not in {"jax", "jax-jit"} + or execution_config.gradient_method == "adjoint" + ) + for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") @@ -616,7 +621,6 @@ def execute( execution_config: ExecutionConfig = DefaultExecutionConfig, ) -> Union[Result, ResultBatch]: self.reset_prng_key() - max_workers = execution_config.device_options.get("max_workers", self._max_workers) self._state_cache = {} if execution_config.use_device_jacobian_product else None interface = ( diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index fc953a2dc22..60d3fc6482d 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from typing import Optional, Union -from pennylane.math import get_canonical_interface_name +from pennylane.math import Interface, get_canonical_interface_name from pennylane.transforms.core import TransformDispatcher @@ -87,7 +87,7 @@ class ExecutionConfig: device_options: Optional[dict] = None """Various options for the device executing a quantum circuit""" - interface: Optional[str] = None + interface: Interface = Interface.NUMPY """The machine learning framework to use""" derivative_order: int = 1 @@ -96,6 +96,13 @@ class ExecutionConfig: mcm_config: MCMConfig = field(default_factory=MCMConfig) """Configuration options for handling mid-circuit measurements""" + convert_to_numpy: bool = True + """Whether or not to convert parameters to numpy before execution. + + If ``False`` and using the jax-jit, no pure callback will occur and the device + execution itself will be jitted. + """ + def __post_init__(self): """ Validate the configured execution options. @@ -124,7 +131,7 @@ def __post_init__(self): ) if isinstance(self.mcm_config, dict): - self.mcm_config = MCMConfig(**self.mcm_config) + self.mcm_config = MCMConfig(**self.mcm_config) # pylint: disable=not-a-mapping elif not isinstance(self.mcm_config, MCMConfig): raise ValueError(f"Got invalid type {type(self.mcm_config)} for 'mcm_config'") diff --git a/pennylane/workflow/_setup_transform_program.py b/pennylane/workflow/_setup_transform_program.py index 5866c8f7bf4..67cfab372b9 100644 --- a/pennylane/workflow/_setup_transform_program.py +++ b/pennylane/workflow/_setup_transform_program.py @@ -117,7 +117,8 @@ def _setup_transform_program( # changing this set of conditions causes a bunch of tests to break. interface_data_supported = ( - resolved_execution_config.interface is Interface.NUMPY + (not resolved_execution_config.convert_to_numpy) + or resolved_execution_config.interface is Interface.NUMPY or resolved_execution_config.gradient_method == "backprop" or ( getattr(device, "short_name", "") == "default.mixed" diff --git a/pennylane/workflow/run.py b/pennylane/workflow/run.py index 456d04f1ad7..9f1096ab32b 100644 --- a/pennylane/workflow/run.py +++ b/pennylane/workflow/run.py @@ -204,7 +204,7 @@ def _get_ml_boundary_execute( elif interface == Interface.TORCH: from .interfaces.torch import execute as ml_boundary - elif interface == Interface.JAX_JIT: + elif interface == Interface.JAX_JIT and resolved_execution_config.convert_to_numpy: from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary else: # interface is jax diff --git a/tests/param_shift_dev.py b/tests/param_shift_dev.py index 12fd11eea16..7c7442161e2 100644 --- a/tests/param_shift_dev.py +++ b/tests/param_shift_dev.py @@ -34,6 +34,7 @@ def preprocess(self, execution_config=qml.devices.DefaultExecutionConfig): execution_config, use_device_jacobian_product=True ) program, config = super().preprocess(execution_config) + config = dataclasses.replace(config, convert_to_numpy=True) program.add_transform(qml.transform(qml.gradients.param_shift.expand_transform)) return program, config From cef0e31f010b3ac2d71472c888f2fb69faef7346 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 15:17:19 -0500 Subject: [PATCH 2/6] testing and changelog --- doc/releases/changelog-dev.md | 5 ++++ pennylane/devices/execution_config.py | 2 +- pennylane/devices/qubit/sampling.py | 4 +-- .../test_default_qubit_native_mcm.py | 6 +---- .../test_default_qubit_preprocessing.py | 26 +++++++++++++++++++ .../interfaces/test_jacobian_products.py | 4 +-- .../workflow/test_setup_transform_program.py | 23 ++++++++-------- 7 files changed, 48 insertions(+), 22 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0f9e46664b9..9cd3f6c51ff 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,11 @@

Improvements 🛠

+* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback` + is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions + on `default.qubit` can now be jitted end-to-end, even with parameter shift. + [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) +

Breaking changes 💔

Deprecations 👋

diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 60d3fc6482d..81d63ee79a7 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -98,7 +98,7 @@ class ExecutionConfig: convert_to_numpy: bool = True """Whether or not to convert parameters to numpy before execution. - + If ``False`` and using the jax-jit, no pure callback will occur and the device execution itself will be jitted. """ diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index 06ae78b5708..527e3296a5c 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -580,6 +580,6 @@ def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None, _, key = jax_random_split(prng_key) samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs) - powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1] + powers_of_two = 1 << jnp.arange(num_wires, dtype=int)[::-1] states_sampled_base_ten = samples[..., None] & powers_of_two - return (states_sampled_base_ten > 0).astype(jnp.int64) + return (states_sampled_base_ten > 0).astype(int) diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index dc6dd8fb67e..42faaa10809 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -389,11 +389,7 @@ def func(x, y, z): results1 = func1(*params) jaxpr = str(jax.make_jaxpr(func)(*params)) - if diff_method == "best": - assert "pure_callback" in jaxpr - pytest.xfail("QNode with diff_method='best' cannot be compiled with jax.jit.") - else: - assert "pure_callback" not in jaxpr + assert "pure_callback" not in jaxpr func2 = jax.jit(func) results2 = func2(*params) diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py index 40f45a9383e..59a9098f7ce 100644 --- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py +++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py @@ -141,6 +141,32 @@ def circuit(x): assert dev.tracker.totals["execute_and_derivative_batches"] == 1 + @pytest.mark.parametrize("interface", ("jax", "jax-jit")) + def test_not_convert_to_numpy_with_jax(self, interface): + """Test that we will not convert to numpy when working with jax.""" + + dev = qml.device("default.qubit") + config = qml.devices.ExecutionConfig( + gradient_method=qml.gradients.param_shift, interface=interface + ) + processed = dev.setup_execution_config(config) + assert not processed.convert_to_numpy + + def test_convert_to_numpy_with_adjoint(self): + """Test that we will convert to numpy with adjoint.""" + config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit") + dev = qml.device("default.qubit") + processed = dev.setup_execution_config(config) + assert processed.convert_to_numpy + + @pytest.mark.parametrize("interface", ("autograd", "torch", "tf")) + def test_convert_to_numpy_non_jax(self, interface): + """Test that other interfaces are still converted to numpy.""" + config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface) + dev = qml.device("default.qubit") + processed = dev.setup_execution_config(config) + assert processed.convert_to_numpy + # pylint: disable=too-few-public-methods class TestPreprocessing: diff --git a/tests/workflow/interfaces/test_jacobian_products.py b/tests/workflow/interfaces/test_jacobian_products.py index 1991ba2d523..4d90f2e5012 100644 --- a/tests/workflow/interfaces/test_jacobian_products.py +++ b/tests/workflow/interfaces/test_jacobian_products.py @@ -136,7 +136,7 @@ def test_device_jacobians_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}," r" device_options={}, interface=, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>" ) assert repr(jpc) == expected @@ -155,7 +155,7 @@ def test_device_jacobian_products_repr(self): r" use_device_jacobian_product=None," r" gradient_method='adjoint', gradient_keyword_arguments={}, device_options={}," r" interface=, derivative_order=1," - r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None))>" + r" mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True)>" ) assert repr(jpc) == expected diff --git a/tests/workflow/test_setup_transform_program.py b/tests/workflow/test_setup_transform_program.py index c81ed75ae2f..79207491014 100644 --- a/tests/workflow/test_setup_transform_program.py +++ b/tests/workflow/test_setup_transform_program.py @@ -140,9 +140,7 @@ def test_prune_dynamic_transform_warning_raised(): def test_interface_data_not_supported(): """Test that convert_to_numpy_parameters transform is correctly added.""" - config = ExecutionConfig() - config.interface = "autograd" - config.gradient_method = "adjoint" + config = ExecutionConfig(interface="autograd", gradient_method="adjoint") device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -154,10 +152,8 @@ def test_interface_data_not_supported(): def test_interface_data_supported(): """Test that convert_to_numpy_parameters transform is not added for these cases.""" - config = ExecutionConfig() + config = ExecutionConfig(interface="autograd", gradient_method=None) - config.interface = "autograd" - config.gradient_method = None device = qml.device("default.mixed", wires=1) user_transform_program = TransformProgram() @@ -165,10 +161,8 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp - config = ExecutionConfig() + config = ExecutionConfig(interface="autograd", gradient_method="backprop") - config.interface = "autograd" - config.gradient_method = "backprop" device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -176,10 +170,8 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp - config = ExecutionConfig() + config = ExecutionConfig(interface=None, gradient_method="backprop") - config.interface = None - config.gradient_method = "backprop" device = qml.device("default.qubit") user_transform_program = TransformProgram() @@ -187,6 +179,13 @@ def test_interface_data_supported(): assert qml.transforms.convert_to_numpy_parameters not in inner_tp + config = ExecutionConfig( + convert_to_numpy=False, interface="jax", gradient_method=qml.gradients.param_shift + ) + + _, inner_tp = _setup_transform_program(TransformProgram(), device, config) + assert qml.transforms.convert_to_numpy_parameters not in inner_tp + def test_cache_handling(): """Test that caching is handled correctly.""" From f6949e1d687fbe852a66d16b32d0410245ad0779 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 15:39:04 -0500 Subject: [PATCH 3/6] updating xfailing pulse gradient tests --- pennylane/devices/default_mixed.py | 4 ++-- tests/gradients/core/test_pulse_gradient.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pennylane/devices/default_mixed.py b/pennylane/devices/default_mixed.py index 9e18ba09249..d55f9eec623 100644 --- a/pennylane/devices/default_mixed.py +++ b/pennylane/devices/default_mixed.py @@ -1008,8 +1008,8 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio "best", } updated_values["grad_on_execution"] = False - if not execution_config.gradient_method in {"best", "backprop", None}: - execution_config.interface = None + if execution_config.gradient_method not in {"best", "backprop", None}: + updated_values["interface"] = None # Add device options updated_values["device_options"] = dict(execution_config.device_options) # copy diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index 5fd9bf34937..d0aba1f4582 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -1485,7 +1485,6 @@ def circuit(params): assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0) jax.clear_caches() - @pytest.mark.xfail @pytest.mark.parametrize("num_split_times", [1, 2]) @pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"]) def test_simple_qnode_jit(self, num_split_times, time_interface): From 1b069b8fd156f529973f3da012bd5ebb928d086d Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 16:15:25 -0500 Subject: [PATCH 4/6] fixing more tests --- pennylane/devices/default_qubit.py | 10 ++++++++++ .../test_mottonen_state_prep.py | 4 ++-- tests/test_qnode.py | 8 +++++--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index f1ac48d6461..4392782c094 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape): return (tape,), null_postprocessing +@qml.transform +def no_counts(tape): + """Throws an error on counts measurements.""" + if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements): + raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.") + return (tape,), null_postprocessing + + @qml.transform def adjoint_state_measurements( tape: QuantumScript, device_vjp=False @@ -535,6 +543,8 @@ def preprocess( config = self._setup_execution_config(execution_config) transform_program = TransformProgram() + if config.interface == qml.math.Interface.JAX_JIT: + transform_program.add_transform(no_counts) transform_program.add_transform(validate_device_wires, self.wires, name=self.name) transform_program.add_transform( mid_circuit_measurements, device=self, mcm_config=config.mcm_config diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py index 1a9f0aa6219..494e750fe2e 100644 --- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py +++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py @@ -417,7 +417,7 @@ def circuit(state): @pytest.mark.jax -@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)]) +@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)]) def test_jacobians_with_and_without_jit_match(shots, atol, seed): """Test that the Jacobian of the circuit is the same with and without jit.""" import jax @@ -433,7 +433,7 @@ def circuit(coeffs): circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift") circuit_exact = qml.QNode(circuit, dev_no_shots) - params = jax.numpy.array([0.5, 0.5, 0.5, 0.5]) + params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64) jac_exact_fn = jax.jacobian(circuit_exact) jac_fd_fn = jax.jacobian(circuit_fd) jac_fd_fn_jit = jax.jit(jac_fd_fn) diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 5ac318333a3..f4075375653 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -687,12 +687,13 @@ def func(x, y): assert tape.measurements == contents[3:] @pytest.mark.jax - def test_jit_counts_raises_error(self): + @pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit")) + def test_jit_counts_raises_error(self, dev_name): """Test that returning counts in a quantum function with trainable parameters while jitting raises an error.""" import jax - dev = qml.device("default.qubit", wires=2, shots=5) + dev = qml.device(dev_name, wires=2, shots=5) def circuit1(param): qml.Hadamard(0) @@ -706,7 +707,8 @@ def circuit1(param): with pytest.raises( NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts." ): - jitted_qnode1(0.123) + out = jitted_qnode1(0.123) + print(out) # Test with qnode decorator syntax @qml.qnode(dev) From fbce8cc5c7d0a62f5685d2ea512b3a5590943d74 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 9 Jan 2025 09:14:30 -0500 Subject: [PATCH 5/6] fixing more tests --- pennylane/devices/default_qubit.py | 3 ++- pennylane/workflow/interfaces/jax.py | 2 +- pennylane/workflow/interfaces/jax_jit.py | 2 +- .../interfaces/execute/test_jax_jit.py | 19 ++++++++++++++----- .../interfaces/qnode/test_jax_jit_qnode.py | 3 ++- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 4392782c094..dc4b116fb29 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -594,8 +594,9 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio updated_values["convert_to_numpy"] = ( execution_config.interface.value not in {"jax", "jax-jit"} or execution_config.gradient_method == "adjoint" + # need numpy to use caching, and need caching higher order derivatives + or execution_config.derivative_order > 1 ) - for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") diff --git a/pennylane/workflow/interfaces/jax.py b/pennylane/workflow/interfaces/jax.py index aaa2c55dd1e..a8f5b0bea32 100644 --- a/pennylane/workflow/interfaces/jax.py +++ b/pennylane/workflow/interfaces/jax.py @@ -186,7 +186,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch: return result if isinstance(result, (list, tuple)): return tuple(_to_jax(r) for r in result) - return jnp.array(result) + return result if qml.math.get_interface(result) == "jax" else jnp.array(result) def _execute_wrapper(params, tapes, execute_fn, jpc) -> ResultBatch: diff --git a/pennylane/workflow/interfaces/jax_jit.py b/pennylane/workflow/interfaces/jax_jit.py index 3cd5779a5de..296afc4408f 100644 --- a/pennylane/workflow/interfaces/jax_jit.py +++ b/pennylane/workflow/interfaces/jax_jit.py @@ -59,7 +59,7 @@ def _to_jax(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch: """ if isinstance(result, dict): - return {key: jnp.array(value) for key, value in result.items()} + return {key: _to_jax(value) for key, value in result.items()} if isinstance(result, (list, tuple)): return tuple(_to_jax(r) for r in result) return jnp.array(result) diff --git a/tests/workflow/interfaces/execute/test_jax_jit.py b/tests/workflow/interfaces/execute/test_jax_jit.py index ce6a9ef27f4..95b056cf476 100644 --- a/tests/workflow/interfaces/execute/test_jax_jit.py +++ b/tests/workflow/interfaces/execute/test_jax_jit.py @@ -886,14 +886,17 @@ def cost(x, y, device, interface, ek): class TestJitAllCounts: + @pytest.mark.parametrize( + "device_name", (pytest.param("default.qubit", marks=pytest.mark.xfail), "reference.qubit") + ) @pytest.mark.parametrize("counts_wires", (None, (0, 1))) - def test_jit_allcounts(self, counts_wires): + def test_jit_allcounts(self, device_name, counts_wires): """Test jitting with counts with all_outcomes == True.""" tape = qml.tape.QuantumScript( [qml.RX(0, 0), qml.I(1)], [qml.counts(wires=counts_wires, all_outcomes=True)], shots=50 ) - device = qml.device("default.qubit") + device = qml.device(device_name, wires=2) res = jax.jit(qml.execute, static_argnums=(1, 2))( (tape,), device, qml.gradients.param_shift @@ -904,7 +907,14 @@ def test_jit_allcounts(self, counts_wires): for val in ["01", "10", "11"]: assert qml.math.allclose(res[val], 0) - def test_jit_allcounts_broadcasting(self): + @pytest.mark.parametrize( + "device_name", + ( + pytest.param("default.qubit", marks=pytest.mark.xfail), + pytest.param("reference.qubit", marks=pytest.mark.xfail), + ), + ) + def test_jit_allcounts_broadcasting(self, device_name): """Test jitting with counts with all_outcomes == True.""" tape = qml.tape.QuantumScript( @@ -912,7 +922,7 @@ def test_jit_allcounts_broadcasting(self): [qml.counts(wires=(0, 1), all_outcomes=True)], shots=50, ) - device = qml.device("default.qubit") + device = qml.device(device_name, wires=2) res = jax.jit(qml.execute, static_argnums=(1, 2))( (tape,), device, qml.gradients.param_shift @@ -927,7 +937,6 @@ def test_jit_allcounts_broadcasting(self): assert qml.math.allclose(ri[val], 0) -@pytest.mark.xfail(reason="Need to figure out how to handle this case in a less ambiguous manner") def test_diff_method_None_jit(): """Test that jitted execution works when `diff_method=None`.""" diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py index 544da89f73d..862806999c6 100644 --- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py @@ -3185,7 +3185,8 @@ def test_complex64_return(self, diff_method): jax.config.update("jax_enable_x64", False) try: - tol = 2e-2 if diff_method == "finite-diff" else 1e-6 + # finite diff with float32 ... + tol = 5e-2 if diff_method == "finite-diff" else 1e-6 @jax.jit @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) From 2110c9b416e9742c2a1481dba8091449ab08269f Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 9 Jan 2025 10:03:08 -0500 Subject: [PATCH 6/6] final test --- tests/workflow/interfaces/execute/test_jax.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/workflow/interfaces/execute/test_jax.py b/tests/workflow/interfaces/execute/test_jax.py index e9d07ccabf7..3fd688788b3 100644 --- a/tests/workflow/interfaces/execute/test_jax.py +++ b/tests/workflow/interfaces/execute/test_jax.py @@ -693,10 +693,12 @@ def test_max_diff(self, tol): def cost_fn(x): ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))] - tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))]) + tape1 = qml.tape.QuantumScript( + ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000 + ) ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))] - tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)]) + tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000) result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1) return result[0] + result[1][0] @@ -704,13 +706,13 @@ def cost_fn(x): res = cost_fn(params) x, y = params expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y)) - assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(res, expected, atol=2e-2, rtol=0) res = jax.grad(cost_fn)(params) expected = jnp.array( [-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)] ) - assert np.allclose(res, expected, atol=tol, rtol=0) + assert np.allclose(res, expected, atol=2e-2, rtol=0) res = jax.jacobian(jax.grad(cost_fn))(params) expected = jnp.zeros([2, 2])