-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow device to configure conversion to numpy and use of pure_callback
#6788
base: master
Are you sure you want to change the base?
Changes from all commits
cbe1f20
cef0e31
f6949e1
1b069b8
a733cca
fbce8cc
df4bd13
2110c9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -161,6 +161,14 @@ def _conditional_broastcast_expand(tape): | |||||
return (tape,), null_postprocessing | ||||||
|
||||||
|
||||||
@qml.transform | ||||||
def no_counts(tape): | ||||||
"""Throws an error on counts measurements.""" | ||||||
if any(isinstance(mp, qml.measurements.CountsMP) for mp in tape.measurements): | ||||||
raise NotImplementedError("The JAX-JIT interface doesn't support qml.counts.") | ||||||
astralcai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
return (tape,), null_postprocessing | ||||||
|
||||||
|
||||||
@qml.transform | ||||||
def adjoint_state_measurements( | ||||||
tape: QuantumScript, device_vjp=False | ||||||
|
@@ -535,6 +543,8 @@ def preprocess( | |||||
config = self._setup_execution_config(execution_config) | ||||||
transform_program = TransformProgram() | ||||||
|
||||||
if config.interface == qml.math.Interface.JAX_JIT: | ||||||
transform_program.add_transform(no_counts) | ||||||
transform_program.add_transform(validate_device_wires, self.wires, name=self.name) | ||||||
transform_program.add_transform( | ||||||
mid_circuit_measurements, device=self, mcm_config=config.mcm_config | ||||||
|
@@ -581,6 +591,12 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio | |||||
""" | ||||||
updated_values = {} | ||||||
|
||||||
updated_values["convert_to_numpy"] = ( | ||||||
execution_config.interface.value not in {"jax", "jax-jit"} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to continue using the Interface enum instead?
Suggested change
|
||||||
or execution_config.gradient_method == "adjoint" | ||||||
astralcai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# need numpy to use caching, and need caching higher order derivatives | ||||||
or execution_config.derivative_order > 1 | ||||||
) | ||||||
for option in execution_config.device_options: | ||||||
if option not in self._device_options: | ||||||
raise qml.DeviceError(f"device option {option} not present on {self}") | ||||||
|
@@ -616,7 +632,6 @@ def execute( | |||||
execution_config: ExecutionConfig = DefaultExecutionConfig, | ||||||
) -> Union[Result, ResultBatch]: | ||||||
self.reset_prng_key() | ||||||
|
||||||
max_workers = execution_config.device_options.get("max_workers", self._max_workers) | ||||||
self._state_cache = {} if execution_config.use_device_jacobian_product else None | ||||||
interface = ( | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -141,6 +141,32 @@ def circuit(x): | |
|
||
assert dev.tracker.totals["execute_and_derivative_batches"] == 1 | ||
|
||
@pytest.mark.parametrize("interface", ("jax", "jax-jit")) | ||
def test_not_convert_to_numpy_with_jax(self, interface): | ||
"""Test that we will not convert to numpy when working with jax.""" | ||
|
||
dev = qml.device("default.qubit") | ||
config = qml.devices.ExecutionConfig( | ||
gradient_method=qml.gradients.param_shift, interface=interface | ||
) | ||
processed = dev.setup_execution_config(config) | ||
assert not processed.convert_to_numpy | ||
|
||
def test_convert_to_numpy_with_adjoint(self): | ||
"""Test that we will convert to numpy with adjoint.""" | ||
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface="jax-jit") | ||
dev = qml.device("default.qubit") | ||
processed = dev.setup_execution_config(config) | ||
assert processed.convert_to_numpy | ||
|
||
@pytest.mark.parametrize("interface", ("autograd", "torch", "tf")) | ||
def test_convert_to_numpy_non_jax(self, interface): | ||
"""Test that other interfaces are still converted to numpy.""" | ||
config = qml.devices.ExecutionConfig(gradient_method="adjoint", interface=interface) | ||
dev = qml.device("default.qubit") | ||
processed = dev.setup_execution_config(config) | ||
assert processed.convert_to_numpy | ||
Comment on lines
+155
to
+168
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you include "jax" as an interface in the testing? Also, I'm curious if converting to numpy with adjoint has negative effects on performance. |
||
|
||
|
||
# pylint: disable=too-few-public-methods | ||
class TestPreprocessing: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -417,7 +417,7 @@ def circuit(state): | |
|
||
|
||
@pytest.mark.jax | ||
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)]) | ||
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.1)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a lot of shots... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're using a seed, maybe it's worth reducing the number of shots. |
||
def test_jacobians_with_and_without_jit_match(shots, atol, seed): | ||
"""Test that the Jacobian of the circuit is the same with and without jit.""" | ||
import jax | ||
|
@@ -433,7 +433,7 @@ def circuit(coeffs): | |
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift") | ||
circuit_exact = qml.QNode(circuit, dev_no_shots) | ||
|
||
params = jax.numpy.array([0.5, 0.5, 0.5, 0.5]) | ||
params = jax.numpy.array([0.5, 0.5, 0.5, 0.5], dtype=jax.numpy.float64) | ||
jac_exact_fn = jax.jacobian(circuit_exact) | ||
jac_fd_fn = jax.jacobian(circuit_fd) | ||
jac_fd_fn_jit = jax.jit(jac_fd_fn) | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -687,12 +687,13 @@ def func(x, y): | |||||||
assert tape.measurements == contents[3:] | ||||||||
|
||||||||
@pytest.mark.jax | ||||||||
def test_jit_counts_raises_error(self): | ||||||||
@pytest.mark.parametrize("dev_name", ("default.qubit", "reference.qubit")) | ||||||||
def test_jit_counts_raises_error(self, dev_name): | ||||||||
"""Test that returning counts in a quantum function with trainable parameters while | ||||||||
jitting raises an error.""" | ||||||||
import jax | ||||||||
|
||||||||
dev = qml.device("default.qubit", wires=2, shots=5) | ||||||||
dev = qml.device(dev_name, wires=2, shots=5) | ||||||||
|
||||||||
def circuit1(param): | ||||||||
qml.Hadamard(0) | ||||||||
|
@@ -706,7 +707,8 @@ def circuit1(param): | |||||||
with pytest.raises( | ||||||||
NotImplementedError, match="The JAX-JIT interface doesn't support qml.counts." | ||||||||
): | ||||||||
jitted_qnode1(0.123) | ||||||||
out = jitted_qnode1(0.123) | ||||||||
print(out) | ||||||||
Comment on lines
+710
to
+711
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
# Test with qnode decorator syntax | ||||||||
@qml.qnode(dev) | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -693,24 +693,26 @@ def test_max_diff(self, tol): | |
|
||
def cost_fn(x): | ||
ops = [qml.RX(x[0], 0), qml.RY(x[1], 1), qml.CNOT((0, 1))] | ||
tape1 = qml.tape.QuantumScript(ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))]) | ||
tape1 = qml.tape.QuantumScript( | ||
ops, [qml.var(qml.PauliZ(0) @ qml.PauliX(1))], shots=50000 | ||
) | ||
|
||
ops2 = [qml.RX(x[0], 0), qml.RY(x[0], 1), qml.CNOT((0, 1))] | ||
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)]) | ||
tape2 = qml.tape.QuantumScript(ops2, [qml.probs(wires=1)], shots=50000) | ||
|
||
result = execute([tape1, tape2], dev, diff_method=param_shift, max_diff=1) | ||
return result[0] + result[1][0] | ||
|
||
res = cost_fn(params) | ||
x, y = params | ||
expected = 0.5 * (3 + jnp.cos(x) ** 2 * jnp.cos(2 * y)) | ||
assert np.allclose(res, expected, atol=tol, rtol=0) | ||
assert np.allclose(res, expected, atol=2e-2, rtol=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the |
||
|
||
res = jax.grad(cost_fn)(params) | ||
expected = jnp.array( | ||
[-jnp.cos(x) * jnp.cos(2 * y) * jnp.sin(x), -jnp.cos(x) ** 2 * jnp.sin(2 * y)] | ||
) | ||
assert np.allclose(res, expected, atol=tol, rtol=0) | ||
assert np.allclose(res, expected, atol=2e-2, rtol=0) | ||
|
||
res = jax.jacobian(jax.grad(cost_fn))(params) | ||
expected = jnp.zeros([2, 2]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.