Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow device to configure conversion to numpy and use of pure_callback #6788

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 8, 2025

Context:

While we have logic for sampling with jax, it does not really integrate very well into the workflow. While you can technically set diff_method=None right now and jit the execution end-to-end, trying to jit diff_method=None will cause incomprehensible error messages on non-DQ devices.

We want to forbid differentiation diff_method=None, but keep a way to jit a finite shot execution.

Description of the Change:

In order to allow jitting finite shot executions, we need a way for the device to be able to configure whether or not the data is converted to numpy. To do so, we simply add another property to the ExecutionConfig, convert_to_numpy. If False, then we will not use a pure_callback to convert the parameters to numpy. If True, we use a pure_callback and convert the parameters to numpy.

Benefits:

Speed ups due to being able to jit the entire execution.

image

Possible Drawbacks:

ExecutionConfig gets an addtional property, making it more complicated.

Related GitHub Issues:

Fixes #6054 Fixes #3259 Blocks #6770

Copy link
Contributor

github-actions bot commented Jan 8, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

Copy link

codecov bot commented Jan 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (5efeffb) to head (2110c9b).
Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6788   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         476      476           
  Lines       45232    45242   +10     
=======================================
+ Hits        45055    45065   +10     
  Misses        177      177           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +710 to +711
out = jitted_qnode1(0.123)
print(out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out = jitted_qnode1(0.123)
print(out)
jitted_qnode1(0.123)

Comment on lines +9 to +13
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Devices can now configure whether or not the data is converted to numpy and `jax.pure_callback`
is used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions
on `default.qubit` can now be jitted end-to-end, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
* Devices can now configure whether or not the data is converted to numpy enabling `jax.pure_callback` to be used by the new `ExecutionConfig.convert_to_numpy` property. Finite shot executions on `default.qubit` can now be jitted end-to-end leading to performance improvements, even with parameter shift.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)

Copy link
Contributor

@andrijapau andrijapau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some comments on my initial pass through.

@@ -581,6 +591,12 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}

updated_values["convert_to_numpy"] = (
execution_config.interface.value not in {"jax", "jax-jit"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to continue using the Interface enum instead?

Suggested change
execution_config.interface.value not in {"jax", "jax-jit"}
execution_config.interface not in {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}

Copy link
Contributor

@astralcai astralcai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants