diff --git a/docs/riscv.md b/docs/riscv.md index 54740ac..40bff5f 100644 --- a/docs/riscv.md +++ b/docs/riscv.md @@ -3,7 +3,7 @@ ## Helpful learning resources - rv32 instruction set cheat sheet: http://blog.translusion.com/images/posts/RISC-V-cheatsheet-RV32I-4-3.pdf -- rv32: reference card: https://github.com/jameslzhu/riscv-card/blob/master/riscv-card.pdf +- rv32: reference card: https://github.com/jameslzhu/riscv-card/releases/download/latest/riscv-card.pdf - online riscv32 interpreter: https://www.cs.cornell.edu/courses/cs3410/2019sp/riscv/interpreter/# - specs: https://riscv.org/technical/specifications/ - Berkely riscv card: https://inst.eecs.berkeley.edu/~cs61c/fa18/img/riscvcard.pdf diff --git a/rvgo/fast/memory.go b/rvgo/fast/memory.go index e239fcd..0b65aa7 100644 --- a/rvgo/fast/memory.go +++ b/rvgo/fast/memory.go @@ -128,7 +128,7 @@ func (m *Memory) SetUnaligned(addr uint64, dat []byte) { m.Invalidate(addr) // invalidate this branch of memory, now that the value changed } - copy(p.Data[pageAddr:], dat) + copy(p.Data[pageAddr:], dat[d:]) } func (m *Memory) GetUnaligned(addr uint64, dest []byte) { diff --git a/rvgo/fast/memory_test.go b/rvgo/fast/memory_test.go index 8653a6f..9829ab9 100644 --- a/rvgo/fast/memory_test.go +++ b/rvgo/fast/memory_test.go @@ -412,3 +412,12 @@ func TestMemoryBinary(t *testing.T) { m.GetUnaligned(8, dest[:]) require.Equal(t, uint8(123), dest[0]) } + +func TestMemoryInvalidSetUnaligned(t *testing.T) { + t.Run("SetUnaligned incorrectly writes to next page", func(t *testing.T) { + m := NewMemory() + m.SetUnaligned(0x0FFE, []byte{0xaa, 0xbb, 0xcc, 0xdd}) + require.Equal(t, m.pages[0].Data[4094:], []byte{0xaa, 0xbb}) + require.Equal(t, m.pages[1].Data[0:2], []byte{0xcc, 0xdd}) + }) +} diff --git a/rvgo/fast/vm.go b/rvgo/fast/vm.go index bf9e60b..98cdc76 100644 --- a/rvgo/fast/vm.go +++ b/rvgo/fast/vm.go @@ -839,13 +839,23 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { imm := parseImmTypeJ(instr) rdValue := add64(pc, toU64(4)) setRegister(rd, rdValue) - setPC(add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2 bytes (last bit is there, but ignored) + + newPC := add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20))) + if newPC&3 != 0 { // quick target alignment check + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC)) + } + setPC(newPC) // signed offset in multiples of 2 bytes (last bit is there, but ignored) case 0x67: // 110_0111: JALR = Jump and link register rs1Value := getRegister(rs1) imm := parseImmTypeI(instr) rdValue := add64(pc, toU64(4)) setRegister(rd, rdValue) - setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least significant bit is set to 0 + + newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1))) + if newPC&3 != 0 { // quick addr alignment check + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC)) + } + setPC(newPC) // least significant bit is set to 0 case 0x73: // 111_0011: environment things switch funct3 { case 0: // 000 = ECALL/EBREAK @@ -873,7 +883,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { // 0b010 == RV32A W variants // 0b011 == RV64A D variants size := shl64(funct3, toU64(1)) - if lt64(size, toU64(4)) != 0 { + if lt64(size, toU64(4)) != 0 || gt64(size, toU64(8)) != 0 { revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size)) } addr := getRegister(rs1) diff --git a/rvgo/slow/vm.go b/rvgo/slow/vm.go index 935e135..c8294b7 100644 --- a/rvgo/slow/vm.go +++ b/rvgo/slow/vm.go @@ -1,6 +1,7 @@ package slow import ( + "bytes" "encoding/binary" "fmt" @@ -121,6 +122,12 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err return } + // First 4 bytes of keccak256("step(bytes,bytes,bytes32)") + expectedSelector := []byte{0xe1, 0x4c, 0xed, 0x32} + if len(calldata) < 4 || !bytes.Equal(calldata[:4], expectedSelector) { + panic("invalid function selector") + } + stateContentOffset := uint8(4 + 32 + 32 + 32 + 32) if iszero(eq(b32asBEWord(calldataload(toU64(4+32*3))), shortToU256(stateSize))) { // user-provided state size must match expected state size @@ -453,6 +460,10 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err setMemoryB32(rightAddr, beWordAsB32(right), proofIndexR) } storeMem := func(addr U64, size U64, value U64, proofIndexL uint8, proofIndexR uint8) { + if size.val() > 8 { + revertWithCode(riscv.ErrStoreExceeds8Bytes, fmt.Errorf("cannot store more than 8 bytes: %d", size)) + } + storeMemUnaligned(addr, size, u64ToU256(value), proofIndexL, proofIndexR) } @@ -1012,13 +1023,23 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err imm := parseImmTypeJ(instr) rdValue := add64(pc, toU64(4)) setRegister(rd, rdValue) - setPC(add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2 bytes (last bit is there, but ignored) + + newPC := add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20))) + if and64(newPC, toU64(3)) != (U64{}) { // quick target alignment check + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC)) + } + setPC(newPC) // signed offset in multiples of 2 bytes (last bit is there, but ignored) case 0x67: // 110_0111: JALR = Jump and link register rs1Value := getRegister(rs1) imm := parseImmTypeI(instr) rdValue := add64(pc, toU64(4)) setRegister(rd, rdValue) - setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least significant bit is set to 0 + + newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1))) + if and64(newPC, toU64(3)) != (U64{}) { // quick target alignment check + revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC)) + } + setPC(newPC) // least significant bit is set to 0 case 0x73: // 111_0011: environment things switch funct3.val() { case 0: // 000 = ECALL/EBREAK @@ -1046,7 +1067,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err // 0b010 == RV32A W variants // 0b011 == RV64A D variants size := shl64(funct3, toU64(1)) - if lt64(size, toU64(4)) != (U64{}) { + if or64(lt64(size, toU64(4)), gt64(size, toU64(8))) != (U64{}) { revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size)) } addr := getRegister(rs1) diff --git a/rvsol/README.md b/rvsol/README.md index 24c8abc..f6cf70a 100644 --- a/rvsol/README.md +++ b/rvsol/README.md @@ -40,7 +40,7 @@ forge test -vvv --ffi - There are few issues with Foundry. - Run script directly without manual build does not work with the current version of Foundry (2024-03-15 `3fa0270`). You **must run** `make build` **before** running the deploy script. ([issue](https://github.com/foundry-rs/foundry/issues/6572)) - - Some older version(2024-02-01 `2f4b5db`) of Foundry makes a dependency error reproted above issue. + - Some older version(2024-02-01 `2f4b5db`) of Foundry makes a dependency error reported above issue. Use the **latest version** of Foundry! - The deploy script can be run only once on the devnet because of the `create2` salt. - To rerun the script for dev purpose, you must restart the devnet with `make devnet-clean && make devnet-up` command on the monorepo. \ No newline at end of file + To rerun the script for dev purpose, you must restart the devnet with `make devnet-clean && make devnet-up` command on the monorepo. diff --git a/rvsol/src/RISCV.sol b/rvsol/src/RISCV.sol index a5c7609..16a409e 100644 --- a/rvsol/src/RISCV.sol +++ b/rvsol/src/RISCV.sol @@ -738,6 +738,8 @@ contract RISCV is IBigStepper { } function storeMem(addr, size, value, proofIndexL, proofIndexR) { + if gt(size, 8) { revertWithCode(0xbad512e8) } // cannot store more than 8 bytes + storeMemUnaligned(addr, size, u64ToU256(value), proofIndexL, proofIndexR) } @@ -1500,7 +1502,13 @@ contract RISCV is IBigStepper { let imm := parseImmTypeJ(instr) let rdValue := add64(_pc, toU64(4)) setRegister(rd, rdValue) - setPC(add64(_pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2 + + let newPC := add64(_pc, signExtend64(shl64(toU64(1), imm), toU64(20))) + if and64(newPC, toU64(3)) { + // quick target alignment check + revertWithCode(0xbad10ad0) // target not aligned with 4 bytes + } + setPC(newPC) // signed offset in multiples of 2 // bytes (last bit is there, but ignored) } case 0x67 { @@ -1509,8 +1517,13 @@ contract RISCV is IBigStepper { let imm := parseImmTypeI(instr) let rdValue := add64(_pc, toU64(4)) setRegister(rd, rdValue) - setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least - // significant bit is set to 0 + + let newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1))) + if and64(newPC, toU64(3)) { + // quick target alignment check + revertWithCode(0xbad10ad0) // target not aligned with 4 bytes + } + setPC(newPC) // least significant bit is set to 0 } case 0x73 { // 111_0011: environment things diff --git a/rvsol/test/RISCV.t.sol b/rvsol/test/RISCV.t.sol index 24c3f6e..5fc0396 100644 --- a/rvsol/test/RISCV.t.sol +++ b/rvsol/test/RISCV.t.sol @@ -2246,7 +2246,7 @@ contract RISCV_Test is CommonTest { /* J Type instructions */ function test_jal_succeeds() public { - uint32 imm = 0xbef054ae; + uint32 imm = 0xbef054ac; uint32 insn = encodeJType(0x6f, 5, imm); // jal x5, imm (State memory state, bytes memory proof) = constructRISCVState(0, insn); bytes memory encodedState = encodeState(state); @@ -2472,6 +2472,18 @@ contract RISCV_Test is CommonTest { vm.expectRevert(hex"00000000000000000000000000000000000000000000000000000000f001ca11"); riscv.step(encodedState, proof, 0); } + + function test_revert_unaligned_jal_instruction() public { + // 0xbef054ae % 4 != 0 + uint32 imm = 0xbef054ae; + uint32 insn = encodeJType(0x6f, 5, imm); // jal x5, imm + (State memory state, bytes memory proof) = constructRISCVState(0, insn); + bytes memory encodedState = encodeState(state); + + vm.expectRevert(hex"00000000000000000000000000000000000000000000000000000000bad10ad0"); + riscv.step(encodedState, proof, 0); + } + /* Helper methods */ function encodeState(State memory state) internal pure returns (bytes memory) {