diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 72abf946de..f7b091c40f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -167,3 +167,37 @@ jobs: cargo remove sp1-sdk cargo add sp1-sdk --path $GITHUB_WORKSPACE/crates/sdk SP1_DEV=1 RUST_LOG=info cargo run --release + test-cuda: + name: Test CUDA + runs-on: nvidia-gpu-linux-x64 + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: rust-cache + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + ~/.rustup/ + key: rust-1.79.0-${{ hashFiles('**/Cargo.toml') }} + restore-keys: rust-1.79.0- + + - name: Setup toolchain + id: rustc-toolchain + shell: bash + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.79.0 -y + + - name: Run script + run: | + . "$HOME/.cargo/env" + curl -L https://sp1.succinct.xyz | bash + /home/runner/.sp1/bin/sp1up + sudo apt install libssl-dev pkg-config + cd examples/fibonacci + RUST_LOG=info cargo run --release --features cuda \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index c6657d2653..c7a66c047e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,7 +137,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccb3ead547f4532bc8af961649942f0b9c16ee9226e26caa3f38420651cc0bf4" dependencies = [ "alloy-rlp", - "bytes 1.7.1", + "bytes 1.7.2", "cfg-if", "const-hex", "derive_more", @@ -159,7 +159,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26154390b1d205a4a7ac7352aa2eb4f81f391399d4e2f546fb81a2f8bb383f62" dependencies = [ "arrayvec", - "bytes 1.7.1", + "bytes 1.7.2", ] [[package]] @@ -648,7 +648,7 @@ checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e" dependencies = [ "async-trait", "axum-core", - "bytes 1.7.1", + "bytes 1.7.2", "futures-util", "http 1.1.0", "http-body 1.0.1", @@ -681,7 +681,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" dependencies = [ "async-trait", - "bytes 1.7.1", + "bytes 1.7.2", "futures-util", "http 1.1.0", "http-body 1.0.1", @@ -944,9 +944,9 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" dependencies = [ "serde", ] @@ -1011,9 +1011,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.19" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d74707dde2ba56f86ae90effb3b43ddd369504387e718014de010cec7959800" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "jobserver", "libc", @@ -1898,7 +1898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a3d8dc56e02f954cac8eb489772c552c473346fc34f67412bb6244fd647f7e4" dependencies = [ "base64 0.21.7", - "bytes 1.7.1", + "bytes 1.7.2", "hex", "k256", "log", @@ -2106,7 +2106,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "82d80cc6ad30b14a48ab786523af33b37f28a8623fc06afd55324816ef18fb1f" dependencies = [ "arrayvec", - "bytes 1.7.1", + "bytes 1.7.2", "cargo_metadata", "chrono", "const-hex", @@ -2164,7 +2164,7 @@ dependencies = [ "async-trait", "auto_impl", "base64 0.21.7", - "bytes 1.7.1", + "bytes 1.7.2", "const-hex", "enr", "ethers-core", @@ -2270,7 +2270,7 @@ checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" dependencies = [ "arrayvec", "auto_impl", - "bytes 1.7.1", + "bytes 1.7.2", ] [[package]] @@ -2680,7 +2680,7 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "fnv", "futures-core", "futures-sink", @@ -2700,7 +2700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", - "bytes 1.7.1", + "bytes 1.7.2", "fnv", "futures-core", "futures-sink", @@ -2845,7 +2845,7 @@ version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "fnv", "itoa", ] @@ -2856,7 +2856,7 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "fnv", "itoa", ] @@ -2867,7 +2867,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "http 0.2.12", "pin-project-lite", ] @@ -2878,7 +2878,7 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "http 1.1.0", ] @@ -2888,7 +2888,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "futures-util", "http 1.1.0", "http-body 1.0.1", @@ -2949,7 +2949,7 @@ version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "futures-channel", "futures-core", "futures-util", @@ -2973,7 +2973,7 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "futures-channel", "futures-util", "h2 0.4.6", @@ -3017,7 +3017,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.0", "tower-service", - "webpki-roots 0.26.5", + "webpki-roots 0.26.6", ] [[package]] @@ -3026,7 +3026,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "http-body-util", "hyper 1.4.1", "hyper-util", @@ -3042,7 +3042,7 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "futures-channel", "futures-util", "http 1.1.0", @@ -3836,7 +3836,7 @@ checksum = "786393f80485445794f6043fd3138854dd109cc6c4bd1a6383db304c9ce9b9ce" dependencies = [ "arrayvec", "auto_impl", - "bytes 1.7.1", + "bytes 1.7.2", "ethereum-types", "open-fastrlp-derive", ] @@ -3847,7 +3847,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "003b2be5c6c53c1cfeb0a238b8a1c3915cd410feb684457a36c10038f764bb1c" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "proc-macro2", "quote", "syn 1.0.109", @@ -4619,7 +4619,7 @@ version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "prost-derive", ] @@ -4629,7 +4629,7 @@ version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "heck", "itertools 0.12.1", "log", @@ -4687,7 +4687,7 @@ version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -4705,7 +4705,7 @@ version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "rand 0.8.5", "ring 0.17.8", "rustc-hash 2.0.0", @@ -4924,7 +4924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64 0.21.7", - "bytes 1.7.1", + "bytes 1.7.2", "encoding_rs", "futures-core", "futures-util", @@ -4965,7 +4965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" dependencies = [ "base64 0.22.1", - "bytes 1.7.1", + "bytes 1.7.2", "encoding_rs", "futures-core", "futures-util", @@ -5004,7 +5004,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 0.26.5", + "webpki-roots 0.26.6", "windows-registry", ] @@ -5078,7 +5078,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb919243f34364b6bd2fc10ef797edbfa75f33c252e7998527479c6d6b47e1ec" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "rlp-derive", "rustc-hex", ] @@ -5114,7 +5114,7 @@ dependencies = [ "alloy-rlp", "ark-ff 0.3.0", "ark-ff 0.4.2", - "bytes 1.7.1", + "bytes 1.7.2", "fastrlp", "num-bigint 0.4.6", "num-traits", @@ -6924,7 +6924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", - "bytes 1.7.1", + "bytes 1.7.2", "libc", "mio", "parking_lot", @@ -6983,7 +6983,7 @@ version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ - "bytes 1.7.1", + "bytes 1.7.2", "futures-core", "futures-sink", "pin-project-lite", @@ -7162,7 +7162,7 @@ checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" dependencies = [ "base64 0.13.1", "byteorder", - "bytes 1.7.1", + "bytes 1.7.2", "http 0.2.12", "httparse", "log", @@ -7190,7 +7190,7 @@ checksum = "dfa3161d8eee0abcad4e762f4215381a430cc1281870d575b0f1e4fbfc74b8ce" dependencies = [ "async-trait", "axum", - "bytes 1.7.1", + "bytes 1.7.2", "futures", "http 1.1.0", "http-body-util", @@ -7264,9 +7264,9 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] @@ -7558,9 +7558,9 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webpki-roots" -version = "0.26.5" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] diff --git a/crates/core/machine/src/memory/local.rs b/crates/core/machine/src/memory/local.rs index bc272d8cac..0d8ac0b21a 100644 --- a/crates/core/machine/src/memory/local.rs +++ b/crates/core/machine/src/memory/local.rs @@ -15,7 +15,7 @@ use sp1_stark::{ InteractionKind, Word, }; -const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 2; +pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 3; pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::>(); diff --git a/crates/core/machine/src/riscv/mod.rs b/crates/core/machine/src/riscv/mod.rs index 814852c5e2..b788f51d41 100644 --- a/crates/core/machine/src/riscv/mod.rs +++ b/crates/core/machine/src/riscv/mod.rs @@ -2,11 +2,14 @@ pub mod cost; mod shape; +use itertools::Itertools; pub use shape::*; use sp1_core_executor::{ExecutionRecord, Program}; use crate::{ - memory::{MemoryChipType, MemoryLocalChip, MemoryProgramChip}, + memory::{ + MemoryChipType, MemoryLocalChip, MemoryProgramChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, + }, riscv::MemoryChipType::{Finalize, Initialize}, syscall::{ memcpy::{self, MemCopy32Chip, MemCopy64Chip, MemCopyChip}, @@ -362,8 +365,6 @@ impl RiscvAir { /// Get the heights of the preprocessed chips for a given program. pub(crate) fn preprocessed_heights(program: &Program) -> Vec<(Self, usize)> { - println!("Program instructions: {}", program.instructions.len()); - println!("Program memory: {}", program.memory_image.len()); vec![ (RiscvAir::Program(ProgramChip::default()), program.instructions.len()), (RiscvAir::ProgramMemory(MemoryProgramChip::default()), program.memory_image.len()), @@ -384,6 +385,15 @@ impl RiscvAir { (RiscvAir::ShiftRight(ShiftRightChip::default()), record.shift_right_events.len()), (RiscvAir::ShiftLeft(ShiftLeft::default()), record.shift_left_events.len()), (RiscvAir::Lt(LtChip::default()), record.lt_events.len()), + ( + RiscvAir::MemoryLocal(MemoryLocalChip::new()), + record + .get_local_mem_events() + .chunks(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW) + .into_iter() + .count(), + ), + (RiscvAir::Syscall(SyscallChip::default()), record.syscall_events.len()), ] } @@ -397,6 +407,8 @@ impl RiscvAir { RiscvAir::Lt(LtChip::default()), RiscvAir::ShiftLeft(ShiftLeft::default()), RiscvAir::ShiftRight(ShiftRightChip::default()), + RiscvAir::MemoryLocal(MemoryLocalChip::new()), + RiscvAir::Syscall(SyscallChip::default()), ] } diff --git a/crates/core/machine/src/riscv/shape.rs b/crates/core/machine/src/riscv/shape.rs index c20dfbb7e9..56f415fdb0 100644 --- a/crates/core/machine/src/riscv/shape.rs +++ b/crates/core/machine/src/riscv/shape.rs @@ -7,13 +7,13 @@ use sp1_stark::{air::MachineAir, ProofShape}; use thiserror::Error; use crate::{ - memory::MemoryProgramChip, + memory::{MemoryLocalChip, MemoryProgramChip}, riscv::MemoryChipType::{Finalize, Initialize}, }; use super::{ AddSubChip, BitwiseChip, CpuChip, DivRemChip, LtChip, MemoryGlobalChip, MulChip, ProgramChip, - RiscvAir, ShiftLeft, ShiftRightChip, + RiscvAir, ShiftLeft, ShiftRightChip, SyscallChip, }; #[derive(Debug, Error)] @@ -79,6 +79,7 @@ impl CoreShapeConfig { }) .collect(); + tracing::info!("Found shape"); let shape = CoreShape { inner: shape? }; Some(shape) } @@ -232,10 +233,13 @@ impl CoreShapeConfig { self.included_shapes.iter().map(ProofShape::from_map).collect::>(); let cpu_name = || RiscvAir::::Cpu(CpuChip::default()).name(); + let memory_local_name = || RiscvAir::::MemoryLocal(MemoryLocalChip::new()).name(); + let syscall_name = || RiscvAir::::Syscall(SyscallChip::default()).name(); let core_filter = move |shape: &ProofShape| { let core_airs = RiscvAir::::get_all_core_airs() .into_iter() .map(|air| air.name()) + .filter(|name| name != &memory_local_name() && name != &syscall_name()) .collect::>(); let core_chips_and_heights = shape .chip_information @@ -251,12 +255,6 @@ impl CoreShapeConfig { } let cpu_height = core_chips_and_heights.first().unwrap().1; - let num_core_chips_at_cpu_height = - core_chips_and_heights.iter().filter(|(_, height)| *height == cpu_height).count(); - - if num_core_chips_at_cpu_height > 2 { - return false; - } let sum_of_heights = core_chips_and_heights.iter().map(|(_, height)| *height).sum::(); @@ -308,14 +306,16 @@ impl Default for CoreShapeConfig { ]); // Get the heights for the short shape cluster (for small shards). - let cpu_heights = vec![Some(10), Some(16)]; - let divrem_heights = vec![None, Some(10), Some(16)]; - let add_sub_heights = vec![None, Some(10), Some(16)]; - let bitwise_heights = vec![None, Some(10), Some(16)]; - let mul_heights = vec![None, Some(10), Some(16)]; - let shift_right_heights = vec![None, Some(10), Some(16)]; - let shift_left_heights = vec![None, Some(10), Some(16)]; - let lt_heights = vec![None, Some(10), Some(16)]; + let cpu_heights = vec![Some(16), Some(19)]; + let divrem_heights = vec![None, Some(10), Some(16), Some(19)]; + let add_sub_heights = vec![None, Some(10), Some(16), Some(19)]; + let bitwise_heights = vec![None, Some(10), Some(16), Some(19)]; + let mul_heights = vec![None, Some(10), Some(16), Some(19)]; + let shift_right_heights = vec![None, Some(10), Some(16), Some(19)]; + let shift_left_heights = vec![None, Some(10), Some(16), Some(19)]; + let lt_heights = vec![None, Some(10), Some(16), Some(19)]; + let memory_local_heights = vec![Some(16), Some(20)]; + let syscall_heights = vec![None, Some(19)]; let short_allowed_log_heights = HashMap::from([ (RiscvAir::Cpu(CpuChip::default()), cpu_heights), @@ -326,6 +326,8 @@ impl Default for CoreShapeConfig { (RiscvAir::ShiftRight(ShiftRightChip::default()), shift_right_heights), (RiscvAir::ShiftLeft(ShiftLeft::default()), shift_left_heights), (RiscvAir::Lt(LtChip::default()), lt_heights), + (RiscvAir::MemoryLocal(MemoryLocalChip::new()), memory_local_heights), + (RiscvAir::Syscall(SyscallChip::default()), syscall_heights), ]); // Get the heights for the medium shape cluster. @@ -337,6 +339,8 @@ impl Default for CoreShapeConfig { let shift_right_heights = vec![None, Some(19), Some(20), Some(21)]; let shift_left_heights = vec![None, Some(19), Some(20), Some(21)]; let lt_heights = vec![None, Some(19), Some(20), Some(21)]; + let memory_local_heights = vec![Some(20), Some(21)]; + let syscall_heights = vec![None, Some(19)]; let medium_allowed_log_heights = HashMap::from([ (RiscvAir::Cpu(CpuChip::default()), cpu_heights), @@ -347,6 +351,8 @@ impl Default for CoreShapeConfig { (RiscvAir::ShiftRight(ShiftRightChip::default()), shift_right_heights), (RiscvAir::ShiftLeft(ShiftLeft::default()), shift_left_heights), (RiscvAir::Lt(LtChip::default()), lt_heights), + (RiscvAir::MemoryLocal(MemoryLocalChip::new()), memory_local_heights), + (RiscvAir::Syscall(SyscallChip::default()), syscall_heights), ]); // Core chip heights for the long shape cluster. @@ -358,6 +364,8 @@ impl Default for CoreShapeConfig { let shift_right_heights = vec![None, Some(20), Some(21), Some(22)]; let shift_left_heights = vec![None, Some(20), Some(21), Some(22)]; let lt_heights = vec![None, Some(20), Some(21), Some(22)]; + let memory_local_heights = vec![Some(21), Some(22)]; + let syscall_heights = vec![None, Some(20)]; let long_allowed_log_heights = HashMap::from([ (RiscvAir::Cpu(CpuChip::default()), cpu_heights), @@ -368,6 +376,8 @@ impl Default for CoreShapeConfig { (RiscvAir::ShiftRight(ShiftRightChip::default()), shift_right_heights), (RiscvAir::ShiftLeft(ShiftLeft::default()), shift_left_heights), (RiscvAir::Lt(LtChip::default()), lt_heights), + (RiscvAir::MemoryLocal(MemoryLocalChip::new()), memory_local_heights), + (RiscvAir::Syscall(SyscallChip::default()), syscall_heights), ]); // Set the memory init and finalize heights. diff --git a/crates/core/machine/src/utils/prove.rs b/crates/core/machine/src/utils/prove.rs index 925ed1fd9b..30460abc51 100644 --- a/crates/core/machine/src/utils/prove.rs +++ b/crates/core/machine/src/utils/prove.rs @@ -9,8 +9,7 @@ use std::{ use web_time::Instant; use crate::riscv::{CoreShapeConfig, RiscvAir}; -use p3_challenger::{CanObserve, FieldChallenger}; -use p3_field::AbstractField; +use p3_challenger::FieldChallenger; use p3_maybe_rayon::prelude::*; use serde::{de::DeserializeOwned, Serialize}; use size::Size; @@ -327,11 +326,12 @@ where // Create the challenger and observe the verifying key. let mut challenger = prover.config().challenger(); - challenger.observe(pk.commit()); - challenger.observe(pk.pc_start()); - for _ in 0..7 { - challenger.observe(Val::::zero()); - } + pk.observe_into(&mut challenger); + // challenger.observe(pk.preprocessed_commit()); + // challenger.observe(pk.pc_start()); + // for _ in 0..7 { + // challenger.observe(Val::::zero()); + // } // Spawn the phase 1 prover thread. let phase_1_prover_span = tracing::Span::current().clone(); diff --git a/crates/cuda/src/lib.rs b/crates/cuda/src/lib.rs index bac2ab791f..5bb38ca874 100644 --- a/crates/cuda/src/lib.rs +++ b/crates/cuda/src/lib.rs @@ -1,10 +1,5 @@ -#[rustfmt::skip] -pub mod proto { - pub mod api; -} - -use core::time::Duration; use std::{ + error::Error as StdError, future::Future, io::{BufReader, Read, Write}, process::{Command, Stdio}, @@ -12,6 +7,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::{Duration, Instant}, }; use crate::proto::api::ProverServiceClient; @@ -25,6 +21,11 @@ use sp1_prover::{ use tokio::task::block_in_place; use twirp::{url::Url, Client}; +#[rustfmt::skip] +pub mod proto { + pub mod api; +} + /// A remote client to [sp1_prover::SP1Prover] that runs inside a container. /// /// This is currently used to provide experimental support for GPU hardware acceleration. @@ -82,32 +83,34 @@ pub struct WrapRequestPayload { impl SP1CudaProver { /// Creates a new [SP1Prover] that runs inside a Docker container and returns a /// [SP1ProverClient] that can be used to communicate with the container. - pub fn new() -> Self { + pub fn new() -> Result> { let container_name = "sp1-gpu"; - let image_name = "jtguibas/sp1-gpu:v1.3.0-rc1"; + let image_name = "jtguibas/sp1-gpu:v3.0.0-rc2"; let cleaned_up = Arc::new(AtomicBool::new(false)); let cleanup_name = container_name; let cleanup_flag = cleaned_up.clone(); - // Pull the docker image if it's not present. - Command::new("sudo") - .args(["docker", "pull", image_name]) - .output() - .expect("failed to pull docker image"); + // Check if Docker is available and the user has necessary permissions + if !Self::check_docker_availability()? { + return Err("Docker is not available or you don't have the necessary permissions. Please ensure Docker is installed and you are part of the docker group.".into()); + } - // Start the docker container. - let rust_log_level = std::env::var("RUST_LOG").unwrap_or("none".to_string()); - let mut child = Command::new("sudo") + // Pull the docker image if it's not present + if let Err(e) = Command::new("docker").args(["pull", image_name]).output() { + return Err(format!("Failed to pull Docker image: {}. Please check your internet connection and Docker permissions.", e).into()); + } + + // Start the docker container + let rust_log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| "none".to_string()); + let mut child = Command::new("docker") .args([ - "docker", "run", "-e", - format!("RUST_LOG={}", rust_log_level).as_str(), + &format!("RUST_LOG={}", rust_log_level), "-p", "3000:3000", "--rm", - "--runtime=nvidia", "--gpus", "all", "--name", @@ -117,7 +120,7 @@ impl SP1CudaProver { .stdout(Stdio::piped()) .stderr(Stdio::piped()) .spawn() - .expect("failed to start Docker container"); + .map_err(|e| format!("Failed to start Docker container: {}. Please check your Docker installation and permissions.", e))?; let stdout = child.stdout.take().unwrap(); std::thread::spawn(move || { @@ -135,7 +138,7 @@ impl SP1CudaProver { } }); - // Kill the container on control-c. + // Kill the container on control-c ctrlc::set_handler(move || { tracing::debug!("received Ctrl+C, cleaning up..."); if !cleanup_flag.load(Ordering::SeqCst) { @@ -146,37 +149,57 @@ impl SP1CudaProver { }) .unwrap(); - // Wait a few seconds for the container to start. + // Wait a few seconds for the container to start std::thread::sleep(Duration::from_secs(2)); - // Check if the container is ready. + // Check if the container is ready let client = Client::from_base_url( Url::parse("http://localhost:3000/twirp/").expect("failed to parse url"), ) .expect("failed to create client"); + + let timeout = Duration::from_secs(60); // Set a 60-second timeout + let start_time = Instant::now(); + block_on(async { tracing::info!("waiting for proving server to be ready"); loop { + if start_time.elapsed() > timeout { + return Err("Timeout: proving server did not become ready within 60 seconds. Please check your Docker container and network settings.".to_string()); + } + let request = ReadyRequest {}; - let response = client.ready(request).await; - if let Ok(response) = response { - if response.ready { + match client.ready(request).await { + Ok(response) if response.ready => { tracing::info!("proving server is ready"); break; } + Ok(_) => { + tracing::info!("proving server is not ready, retrying..."); + } + Err(e) => { + tracing::warn!("Error checking server readiness: {}", e); + } } - tracing::info!("proving server is not ready, retrying..."); - std::thread::sleep(Duration::from_secs(2)); + tokio::time::sleep(Duration::from_secs(2)).await; } - }); + Ok(()) + })?; - SP1CudaProver { + Ok(SP1CudaProver { client: Client::from_base_url( Url::parse("http://localhost:3000/twirp/").expect("failed to parse url"), ) .expect("failed to create client"), container_name: container_name.to_string(), cleaned_up: cleaned_up.clone(), + }) + } + + fn check_docker_availability() -> Result> { + match Command::new("docker").arg("version").output() { + Ok(output) => Ok(output.status.success()), + Err(_) => Ok(false), } } @@ -257,7 +280,7 @@ impl SP1CudaProver { impl Default for SP1CudaProver { fn default() -> Self { - Self::new() + Self::new().expect("Failed to create SP1CudaProver") } } @@ -273,8 +296,8 @@ impl Drop for SP1CudaProver { /// Cleans up the a docker container with the given name. fn cleanup_container(container_name: &str) { - if let Err(e) = Command::new("sudo").args(["docker", "rm", "-f", container_name]).output() { - eprintln!("failed to remove container: {}", e); + if let Err(e) = Command::new("docker").args(["rm", "-f", container_name]).output() { + eprintln!("Failed to remove container: {}. You may need to manually remove it using 'docker rm -f {}'", e, container_name); } } @@ -313,7 +336,7 @@ mod tests { setup_logger(); let prover = SP1Prover::::new(); - let client = SP1CudaProver::new(); + let client = SP1CudaProver::new().expect("Failed to create SP1CudaProver"); let (pk, vk) = prover.setup(FIBONACCI_ELF); println!("proving core"); diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index cacfbb172c..8558e0cd57 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -55,6 +55,10 @@ path = "scripts/build_plonk_bn254.rs" name = "build_groth16_bn254" path = "scripts/build_groth16_bn254.rs" +[[bin]] +name = "build_compress_vks" +path = "scripts/build_compress_vks.rs" + [[bin]] name = "e2e" path = "scripts/e2e.rs" diff --git a/crates/prover/merkle_tree.bin b/crates/prover/merkle_tree.bin new file mode 100644 index 0000000000..2d833832a1 Binary files /dev/null and b/crates/prover/merkle_tree.bin differ diff --git a/crates/prover/scripts/build_compress_vks.rs b/crates/prover/scripts/build_compress_vks.rs new file mode 100644 index 0000000000..6528642d00 --- /dev/null +++ b/crates/prover/scripts/build_compress_vks.rs @@ -0,0 +1,55 @@ +use std::{fs::File, path::PathBuf}; + +use clap::Parser; +use p3_baby_bear::BabyBear; +use sp1_core_machine::{riscv::CoreShapeConfig, utils::setup_logger}; +use sp1_prover::{utils::get_all_vk_digests, InnerSC, REDUCE_BATCH_SIZE}; +use sp1_recursion_circuit_v2::merkle_tree::MerkleTree; +use sp1_recursion_core_v2::shape::RecursionShapeConfig; + +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(short, long)] + build_dir: PathBuf, + #[clap(short, long)] + dummy: bool, + #[clap(short, long, default_value_t = REDUCE_BATCH_SIZE)] + reduce_batch_size: usize, +} + +fn main() { + setup_logger(); + let args = Args::parse(); + + let reduce_batch_size = args.reduce_batch_size; + let build_dir = args.build_dir; + + let core_shape_config = CoreShapeConfig::default(); + let recursion_shape_config = RecursionShapeConfig::default(); + + std::fs::create_dir_all(&build_dir).expect("failed to create build directory"); + + tracing::info!("building compress vk map"); + let vk_map = get_all_vk_digests(&core_shape_config, &recursion_shape_config, reduce_batch_size); + tracing::info!("compress vks generated, number of keys: {}", vk_map.len()); + + // Save the vk map to a file. + tracing::info!("saving vk map to file"); + let vk_map_path = build_dir.join("vk_map.bin"); + let mut vk_map_file = File::create(vk_map_path).unwrap(); + bincode::serialize_into(&mut vk_map_file, &vk_map).unwrap(); + tracing::info!("File saved successfully."); + + // Build a merkle tree from the vk map. + tracing::info!("building merkle tree"); + let (root, merkle_tree) = + MerkleTree::::commit(vk_map.keys().cloned().collect()); + + // Saving merkle tree data to file. + tracing::info!("saving merkle tree to file"); + let merkle_tree_path = build_dir.join("merkle_tree.bin"); + let mut merkle_tree_file = File::create(merkle_tree_path).unwrap(); + bincode::serialize_into(&mut merkle_tree_file, &(root, merkle_tree)).unwrap(); + tracing::info!("File saved successfully."); +} diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 02e3dc812c..c9d3edc185 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -84,9 +84,7 @@ use sp1_recursion_circuit_v2::{ }; pub use types::*; -use utils::{ - get_all_vk_digests, sp1_commited_values_digest_bn254, sp1_vkey_digest_bn254, words_to_bytes, -}; +use utils::{sp1_commited_values_digest_bn254, sp1_vkey_digest_bn254, words_to_bytes}; use components::{DefaultProverComponents, SP1ProverComponents}; @@ -108,7 +106,10 @@ const WRAP_DEGREE: usize = 17; const CORE_CACHE_SIZE: usize = 5; const COMPRESS_CACHE_SIZE: usize = 3; -const REDUCE_BATCH_SIZE: usize = 2; +pub const REDUCE_BATCH_SIZE: usize = 2; + +const VK_MAP_BYTES: &[u8] = include_bytes!("../vk_map.bin"); +const MERKLE_TREE_BYTES: &[u8] = include_bytes!("../merkle_tree.bin"); pub type CompressAir = RecursionAir; pub type ShrinkAir = RecursionAir; @@ -190,9 +191,10 @@ impl SP1Prover { let recursion_shape_config = RecursionShapeConfig::default(); let allowed_vk_map = - get_all_vk_digests(&core_shape_config, &recursion_shape_config, REDUCE_BATCH_SIZE); + bincode::deserialize(VK_MAP_BYTES).expect("failed to deserialize vk map"); - let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().cloned().collect()); + let (root, merkle_tree) = + bincode::deserialize(MERKLE_TREE_BYTES).expect("failed to deserialize merkle tree"); let core_shape_config = env::var("FIX_CORE_SHAPES") .map(|v| v.eq_ignore_ascii_case("true")) diff --git a/crates/prover/vk_map.bin b/crates/prover/vk_map.bin new file mode 100644 index 0000000000..cd1806170f Binary files /dev/null and b/crates/prover/vk_map.bin differ diff --git a/crates/recursion/circuit-v2/src/fri.rs b/crates/recursion/circuit-v2/src/fri.rs index 01890cb3f7..74d50b9327 100644 --- a/crates/recursion/circuit-v2/src/fri.rs +++ b/crates/recursion/circuit-v2/src/fri.rs @@ -21,6 +21,18 @@ use crate::{ FriProofVariable, FriQueryProofVariable, TwoAdicPcsProofVariable, TwoAdicPcsRoundVariable, }; +#[derive(Debug, Clone, Copy)] +pub struct PolynomialShape { + pub width: usize, + pub log_degree: usize, +} + +#[derive(Debug, Clone)] + +pub struct PolynomialBatchShape { + pub shapes: Vec, +} + pub fn verify_shape_and_sample_challenges< C: CircuitConfig, SC: BabyBearFriConfigVariable, @@ -401,12 +413,12 @@ pub fn dummy_query_proof(height: usize) -> QueryProof>, + batch_shapes: &[PolynomialBatchShape], log_blowup: usize, ) -> InnerPcsProof { - let &max_height = batch_shapes + let max_height = batch_shapes .iter() - .map(|shapes| shapes.iter().map(|(_, x)| x).max().unwrap()) + .map(|shape| shape.shapes.iter().map(|shape| shape.log_degree).max().unwrap()) .max() .unwrap(); let fri_proof = FriProof { @@ -423,13 +435,15 @@ pub fn dummy_pcs_proof( batch_shapes .iter() .map(|shapes| { - let batch_max_height = shapes.iter().map(|(_, x)| x).max().unwrap(); + let batch_max_height = + shapes.shapes.iter().map(|shape| shape.log_degree).max().unwrap(); BatchOpening { opened_values: shapes + .shapes .iter() - .map(|(width, _)| vec![BabyBear::zero(); *width]) + .map(|shape| vec![BabyBear::zero(); shape.width]) .collect(), - opening_proof: vec![dummy_hash().into(); *batch_max_height + log_blowup], + opening_proof: vec![dummy_hash().into(); batch_max_height + log_blowup], } }) .collect::>() @@ -698,9 +712,17 @@ mod tests { .collect::>(); pcs.verify(vec![(commit, os.clone())], &proof, &mut challenger).unwrap(); + let batch_shapes = vec![PolynomialBatchShape { + shapes: log_degrees + .iter() + .copied() + .map(|d| PolynomialShape { width: 100, log_degree: d }) + .collect(), + }]; + let dummy_proof = dummy_pcs_proof( inner_fri_config().num_queries, - vec![log_degrees.iter().copied().map(|d| (100, d)).collect()], + &batch_shapes, inner_fri_config().log_blowup, ); diff --git a/crates/recursion/circuit-v2/src/machine/core.rs b/crates/recursion/circuit-v2/src/machine/core.rs index 8bc1f72675..21473f3d08 100644 --- a/crates/recursion/circuit-v2/src/machine/core.rs +++ b/crates/recursion/circuit-v2/src/machine/core.rs @@ -52,6 +52,7 @@ pub struct SP1RecursionWitnessVariable< pub is_first_shard: Felt, } +#[derive(Debug, Clone)] pub struct SP1RecursionWitnessValues { pub vk: StarkVerifyingKey, pub shard_proofs: Vec>, diff --git a/crates/recursion/circuit-v2/src/merkle_tree.rs b/crates/recursion/circuit-v2/src/merkle_tree.rs index 4769ec5949..57493bf524 100644 --- a/crates/recursion/circuit-v2/src/merkle_tree.rs +++ b/crates/recursion/circuit-v2/src/merkle_tree.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use itertools::Itertools; use p3_field::Field; use p3_util::{reverse_bits_len, reverse_slice_index_bits}; +use serde::{Deserialize, Serialize}; use sp1_core_machine::utils::log2_strict_usize; use sp1_recursion_compiler::ir::Builder; @@ -12,7 +13,9 @@ use crate::{ CircuitConfig, }; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = "HV::Digest: Serialize"))] +#[serde(bound(deserialize = "HV::Digest: Deserialize<'de>"))] pub struct MerkleTree> { /// The height of the tree, not counting the root layer. This is the same as the logarithm of the /// number of leaves. diff --git a/crates/recursion/circuit-v2/src/stark.rs b/crates/recursion/circuit-v2/src/stark.rs index a7ffc18568..9306005dc5 100644 --- a/crates/recursion/circuit-v2/src/stark.rs +++ b/crates/recursion/circuit-v2/src/stark.rs @@ -1,10 +1,12 @@ use hashbrown::HashMap; use itertools::{izip, Itertools}; + use num_traits::cast::ToPrimitive; -use p3_air::Air; + +use p3_air::{Air, BaseAir}; use p3_baby_bear::BabyBear; use p3_commit::{Mmcs, Pcs, PolynomialSpace, TwoAdicMultiplicativeCoset}; -use p3_field::{Field, TwoAdicField}; +use p3_field::{AbstractField, ExtensionField, Field, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; use sp1_recursion_compiler::{ @@ -12,12 +14,18 @@ use sp1_recursion_compiler::{ ir::{Builder, Config, Ext}, prelude::Felt, }; -use sp1_stark::{air::InteractionScope, ShardCommitment, ShardOpenedValues, Val}; +use sp1_stark::{ + air::InteractionScope, baby_bear_poseidon2::BabyBearPoseidon2, AirOpenedValues, Chip, + ChipOpenedValues, InnerChallenge, ProofShape, ShardCommitment, ShardOpenedValues, ShardProof, + Val, PROOF_MAX_NUM_PVS, +}; use sp1_stark::{air::MachineAir, StarkGenericConfig, StarkMachine, StarkVerifyingKey}; use crate::{ - challenger::CanObserveVariable, hash::FieldHasherVariable, CircuitConfig, - TwoAdicPcsMatsVariable, TwoAdicPcsProofVariable, + challenger::CanObserveVariable, + fri::{dummy_hash, dummy_pcs_proof, PolynomialBatchShape, PolynomialShape}, + hash::FieldHasherVariable, + BabyBearFriConfig, CircuitConfig, TwoAdicPcsMatsVariable, TwoAdicPcsProofVariable, }; use crate::{ @@ -37,6 +45,145 @@ pub struct ShardProofVariable, SC: BabyBearFriConf pub public_values: Vec>, } +/// Make a dummy shard proof for a given proof shape. +pub fn dummy_shard_proof>( + machine: &StarkMachine, + vk: &StarkVerifyingKey, + shape: &ProofShape, +) -> ShardProof { + // Make a dummy commitment. + let commitment = ShardCommitment { + global_main_commit: dummy_hash(), + local_main_commit: dummy_hash(), + permutation_commit: dummy_hash(), + quotient_commit: dummy_hash(), + }; + + // Get dummy opened values by reading the chip ordering from the shape. + let chip_ordering = shape + .chip_information + .iter() + .enumerate() + .map(|(i, (name, _))| (name.clone(), i)) + .collect::>(); + let shard_chips = machine.shard_chips_ordered(&chip_ordering).collect::>(); + let chip_scopes = shard_chips.iter().map(|chip| chip.commit_scope()).collect::>(); + let has_global_main_commit = chip_scopes.contains(&InteractionScope::Global); + let opened_values = ShardOpenedValues { + chips: shard_chips + .iter() + .zip_eq(shape.chip_information.iter()) + .map(|(chip, (_, log_degree))| { + dummy_opened_values::<_, InnerChallenge, _>(chip, *log_degree) + }) + .collect(), + }; + + let mut preprocessed_batch_shape = vec![]; + let mut global_main_batch_shape = vec![]; + let mut local_main_batch_shape = vec![]; + let mut permutation_batch_shape = vec![]; + let mut quotient_batch_shape = vec![]; + + for info in vk.chip_information.iter() { + let name = &info.0; + let i = chip_ordering[name]; + let opened_values = &opened_values.chips[i]; + let prep_shape = PolynomialShape { + width: opened_values.preprocessed.local.len(), + log_degree: opened_values.log_degree, + }; + preprocessed_batch_shape.push(prep_shape); + } + + for (chip_opening, scope) in opened_values.chips.iter().zip_eq(chip_scopes.iter()) { + let main_shape = PolynomialShape { + width: chip_opening.main.local.len(), + log_degree: chip_opening.log_degree, + }; + match scope { + InteractionScope::Global => global_main_batch_shape.push(main_shape), + InteractionScope::Local => local_main_batch_shape.push(main_shape), + } + let permutation_shape = PolynomialShape { + width: chip_opening.permutation.local.len(), + log_degree: chip_opening.log_degree, + }; + permutation_batch_shape.push(permutation_shape); + for quot_chunk in chip_opening.quotient.iter() { + assert_eq!(quot_chunk.len(), 4); + quotient_batch_shape.push(PolynomialShape { + width: quot_chunk.len(), + log_degree: chip_opening.log_degree, + }); + } + } + + let batch_shapes = if has_global_main_commit { + vec![ + PolynomialBatchShape { shapes: preprocessed_batch_shape }, + PolynomialBatchShape { shapes: global_main_batch_shape }, + PolynomialBatchShape { shapes: local_main_batch_shape }, + PolynomialBatchShape { shapes: permutation_batch_shape }, + PolynomialBatchShape { shapes: quotient_batch_shape }, + ] + } else { + vec![ + PolynomialBatchShape { shapes: preprocessed_batch_shape }, + PolynomialBatchShape { shapes: local_main_batch_shape }, + PolynomialBatchShape { shapes: permutation_batch_shape }, + PolynomialBatchShape { shapes: quotient_batch_shape }, + ] + }; + + let fri_queries = machine.config().fri_config().num_queries; + let log_blowup = machine.config().fri_config().log_blowup; + let opening_proof = dummy_pcs_proof(fri_queries, &batch_shapes, log_blowup); + + let public_values = (0..PROOF_MAX_NUM_PVS).map(|_| BabyBear::zero()).collect::>(); + + ShardProof { + commitment, + opened_values, + opening_proof, + chip_ordering, + chip_scopes, + public_values, + } +} + +fn dummy_opened_values, A: MachineAir>( + chip: &Chip, + log_degree: usize, +) -> ChipOpenedValues { + let preprocessed_width = chip.preprocessed_width(); + let preprocessed = AirOpenedValues { + local: vec![EF::zero(); preprocessed_width], + next: vec![EF::zero(); preprocessed_width], + }; + let main_width = chip.width(); + let main = + AirOpenedValues { local: vec![EF::zero(); main_width], next: vec![EF::zero(); main_width] }; + + let permutation_width = chip.permutation_width(); + let permutation = AirOpenedValues { + local: vec![EF::zero(); permutation_width * EF::D], + next: vec![EF::zero(); permutation_width * EF::D], + }; + let quotient_width = chip.quotient_width(); + let quotient = (0..quotient_width).map(|_| vec![EF::zero(); EF::D]).collect::>(); + + ChipOpenedValues { + preprocessed, + main, + permutation, + quotient, + global_cumulative_sum: EF::zero(), + local_cumulative_sum: EF::zero(), + log_degree, + } +} + #[derive(Clone)] pub struct MerkleProofVariable> { pub index: Vec, @@ -348,10 +495,10 @@ pub mod tests { type F = InnerVal; type A = RiscvAir; + type SC = BabyBearPoseidon2; pub fn build_verify_shard_with_provers< - C: CircuitConfig>, - SC: BabyBearFriConfigVariable + Default + Sync + Send, + C: CircuitConfig>, CoreP: MachineProver, RecP: MachineProver>, >( @@ -359,18 +506,7 @@ pub mod tests { elf: &[u8], opts: SP1CoreOpts, num_shards_in_batch: Option, - ) -> (TracedVec>, Vec>) - where - SC::Challenger: Send, - <::ValMmcs as Mmcs>::ProverData< - RowMajorMatrix, - >: Send + Sync, - <::ValMmcs as Mmcs>::Commitment: Send + Sync, - <::ValMmcs as Mmcs>::Proof: Send, - StarkVerifyingKey: Witnessable>, - ShardProof: Witnessable>, - { - // Generate a dummy proof. + ) -> (TracedVec>, Vec>) { setup_logger(); let machine = RiscvAir::::machine(SC::default()); @@ -380,7 +516,6 @@ pub mod tests { .unwrap(); let mut challenger = machine.config().challenger(); machine.verify(&vk, &proof, &mut challenger).unwrap(); - println!("Proof generated successfully"); // Observe all the commitments. let mut builder = Builder::::default(); @@ -390,7 +525,8 @@ pub mod tests { // Add a hash invocation, since the poseidon2 table expects that it's in the first row. let mut challenger = config.challenger_variable(&mut builder); // let vk = VerifyingKeyVariable::from_constant_key_babybear(&mut builder, &vk); - vk.write(&mut witness_stream); + Witnessable::::write(&vk, &mut witness_stream); + let vk_value = vk.clone(); let vk: VerifyingKeyVariable<_, _> = vk.read(&mut builder); vk.observe_into(&mut builder, &mut challenger); @@ -398,8 +534,10 @@ pub mod tests { .shard_proofs .into_iter() .map(|proof| { - proof.write(&mut witness_stream); - proof.read(&mut builder) + let shape = proof.shape(); + let dummy_proof = dummy_shard_proof(&machine, &vk_value, &shape); + Witnessable::::write(&proof, &mut witness_stream); + dummy_proof.read(&mut builder) }) .collect::>(); // Observe all the commitments, and put the proofs into the witness stream. @@ -432,12 +570,12 @@ pub mod tests { #[test] fn test_verify_shard_inner() { let (operations, stream) = - build_verify_shard_with_provers::< - InnerConfig, - BabyBearPoseidon2, - CpuProver<_, _>, - CpuProver<_, _>, - >(BabyBearPoseidon2::new(), FIBONACCI_ELF, SP1CoreOpts::default(), Some(2)); + build_verify_shard_with_provers::, CpuProver<_, _>>( + BabyBearPoseidon2::new(), + FIBONACCI_ELF, + SP1CoreOpts::default(), + Some(2), + ); run_test_recursion_with_prover::>(operations, stream); } } diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index e5fa4289e3..b9b7c461aa 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -81,7 +81,7 @@ impl ProverClient { #[cfg(not(feature = "cuda"))] prover: Box::new(CpuProver::new()), #[cfg(feature = "cuda")] - prover: Box::new(CudaProver::new()), + prover: Box::new(CudaProver::new(SP1Prover::new())), }, "network" => { cfg_if! { diff --git a/crates/sdk/src/network/auth.rs b/crates/sdk/src/network/auth.rs index 45354a66ba..7cb2c2ef7d 100644 --- a/crates/sdk/src/network/auth.rs +++ b/crates/sdk/src/network/auth.rs @@ -34,6 +34,12 @@ sol! { string description; } + struct ModifyCpuCycles { + uint64 nonce; + string proof_id; + uint64 cycles; + } + struct FulfillProof { uint64 nonce; string proof_id; @@ -116,6 +122,18 @@ impl NetworkAuth { self.sign_message(type_struct).await } + /// Signs a message to modify the CPU cycles for a proof. The proof must have been previously + /// claimed by the signer first. + pub async fn sign_modify_cpu_cycles_message( + &self, + nonce: u64, + proof_id: &str, + cycles: u64, + ) -> Result> { + let type_struct = ModifyCpuCycles { nonce, proof_id: proof_id.to_string(), cycles }; + self.sign_message(type_struct).await + } + /// Signs a message to fulfill a proof. The proof must have been previously claimed by the /// signer first. pub async fn sign_fulfill_proof_message(&self, nonce: u64, proof_id: &str) -> Result> { diff --git a/crates/sdk/src/network/client.rs b/crates/sdk/src/network/client.rs index 936d55fdb8..f9bdd94ccd 100644 --- a/crates/sdk/src/network/client.rs +++ b/crates/sdk/src/network/client.rs @@ -2,7 +2,9 @@ use std::{env, time::Duration}; use crate::{ network::auth::NetworkAuth, - proto::network::{UnclaimProofRequest, UnclaimReason}, + proto::network::{ + ModifyCpuCyclesRequest, ModifyCpuCyclesResponse, UnclaimProofRequest, UnclaimReason, + }, }; use anyhow::{Context, Ok, Result}; use futures::{future::join_all, Future}; @@ -219,6 +221,28 @@ impl NetworkClient { Ok(()) } + /// Modifies the CPU cycles for a proof. May be called by the claimer after the proof has been + /// claimed. Returns an error if the proof is not in a PROOF_CLAIMED state or if the caller is + /// not the claimer. + pub async fn modify_cpu_cycles( + &self, + proof_id: &str, + cycles: u64, + ) -> Result { + let nonce = self.get_nonce().await?; + let signature = self.auth.sign_modify_cpu_cycles_message(nonce, proof_id, cycles).await?; + let res = self + .with_error_handling(self.rpc.modify_cpu_cycles(ModifyCpuCyclesRequest { + signature, + nonce, + proof_id: proof_id.to_string(), + cycles, + })) + .await?; + + Ok(res) + } + /// Fulfill a proof. Should only be called after the proof has been uploaded. Returns an error /// if the proof is not in a PROOF_CLAIMED state or if the caller is not the claimer. pub async fn fulfill_proof(&self, proof_id: &str) -> Result { diff --git a/crates/sdk/src/proto/network.rs b/crates/sdk/src/proto/network.rs index a0e0d883e5..d2f2e118e2 100644 --- a/crates/sdk/src/proto/network.rs +++ b/crates/sdk/src/proto/network.rs @@ -116,6 +116,29 @@ pub struct UnclaimProofRequest { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnclaimProofResponse {} +/// The request to update a proof's CPU cycle count. +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ModifyCpuCyclesRequest { + /// The signature of the message. + #[prost(bytes = "vec", tag = "1")] + pub signature: ::prost::alloc::vec::Vec, + /// The nonce for the account. + #[prost(uint64, tag = "2")] + pub nonce: u64, + /// The proof identifier. + #[prost(string, tag = "3")] + pub proof_id: ::prost::alloc::string::String, + /// The number of CPU cycles for this proof. + #[prost(uint64, tag = "4")] + pub cycles: u64, +} +/// The response for updating a proof's CPU cycle count, empty on success. +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ModifyCpuCyclesResponse {} /// The request to fulfill a proof. MUST be called after the proof has been uploaded and MUST be called /// when the proof is in a PROOF_CLAIMED state. #[derive(serde::Serialize, serde::Deserialize)] @@ -523,6 +546,11 @@ pub trait NetworkService { ctx: twirp::Context, req: UnclaimProofRequest, ) -> Result; + async fn modify_cpu_cycles( + &self, + ctx: twirp::Context, + req: ModifyCpuCyclesRequest, + ) -> Result; async fn fulfill_proof( &self, ctx: twirp::Context, @@ -583,6 +611,12 @@ where api.unclaim_proof(ctx, req).await }, ) + .route( + "/ModifyCpuCycles", + |api: std::sync::Arc, ctx: twirp::Context, req: ModifyCpuCyclesRequest| async move { + api.modify_cpu_cycles(ctx, req).await + }, + ) .route( "/FulfillProof", |api: std::sync::Arc, ctx: twirp::Context, req: FulfillProofRequest| async move { @@ -639,6 +673,10 @@ pub trait NetworkServiceClient: Send + Sync + std::fmt::Debug { &self, req: UnclaimProofRequest, ) -> Result; + async fn modify_cpu_cycles( + &self, + req: ModifyCpuCyclesRequest, + ) -> Result; async fn fulfill_proof( &self, req: FulfillProofRequest, @@ -692,6 +730,13 @@ impl NetworkServiceClient for twirp::client::Client { let url = self.base_url.join("network.NetworkService/UnclaimProof")?; self.request(url, req).await } + async fn modify_cpu_cycles( + &self, + req: ModifyCpuCyclesRequest, + ) -> Result { + let url = self.base_url.join("network.NetworkService/ModifyCpuCycles")?; + self.request(url, req).await + } async fn fulfill_proof( &self, req: FulfillProofRequest, diff --git a/crates/sdk/src/provers/cuda.rs b/crates/sdk/src/provers/cuda.rs index 78ec8c6fa6..a10bdb33f3 100644 --- a/crates/sdk/src/provers/cuda.rs +++ b/crates/sdk/src/provers/cuda.rs @@ -18,10 +18,9 @@ pub struct CudaProver { impl CudaProver { /// Creates a new [CudaProver]. - pub fn new() -> Self { - let prover = SP1Prover::new(); + pub fn new(prover: SP1Prover) -> Self { let cuda_prover = SP1CudaProver::new(); - Self { prover, cuda_prover } + Self { prover, cuda_prover: cuda_prover.expect("Failed to initialize CUDA prover") } } } @@ -121,6 +120,6 @@ impl Prover for CudaProver { impl Default for CudaProver { fn default() -> Self { - Self::new() + Self::new(SP1Prover::new()) } } diff --git a/crates/stark/src/prover.rs b/crates/stark/src/prover.rs index 35c98eaa30..89851370f8 100644 --- a/crates/stark/src/prover.rs +++ b/crates/stark/src/prover.rs @@ -266,7 +266,7 @@ pub trait MachineProver>: /// A proving key for any [`MachineAir`] that is agnostic to hardware. pub trait MachineProvingKey: Send + Sync { /// The main commitment. - fn commit(&self) -> Com; + fn preprocessed_commit(&self) -> Com; /// The start pc. fn pc_start(&self) -> Val; @@ -853,7 +853,7 @@ where PcsProverData: Send + Sync + Serialize + DeserializeOwned, Com: Send + Sync, { - fn commit(&self) -> Com { + fn preprocessed_commit(&self) -> Com { self.commit.clone() }