Skip to content

Commit

Permalink
Merge branch 'master' into feature/mininny/audit-5
Browse files Browse the repository at this point in the history
  • Loading branch information
mininny authored Jan 8, 2025
2 parents f35812f + cdddf9e commit 66d0362
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/riscv.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Helpful learning resources

- rv32 instruction set cheat sheet: http://blog.translusion.com/images/posts/RISC-V-cheatsheet-RV32I-4-3.pdf
- rv32: reference card: https://github.com/jameslzhu/riscv-card/blob/master/riscv-card.pdf
- rv32: reference card: https://github.com/jameslzhu/riscv-card/releases/download/latest/riscv-card.pdf
- online riscv32 interpreter: https://www.cs.cornell.edu/courses/cs3410/2019sp/riscv/interpreter/#
- specs: https://riscv.org/technical/specifications/
- Berkely riscv card: https://inst.eecs.berkeley.edu/~cs61c/fa18/img/riscvcard.pdf
Expand Down
2 changes: 1 addition & 1 deletion rvgo/fast/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (m *Memory) SetUnaligned(addr uint64, dat []byte) {
m.Invalidate(addr) // invalidate this branch of memory, now that the value changed
}

copy(p.Data[pageAddr:], dat)
copy(p.Data[pageAddr:], dat[d:])
}

func (m *Memory) GetUnaligned(addr uint64, dest []byte) {
Expand Down
9 changes: 9 additions & 0 deletions rvgo/fast/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,12 @@ func TestMemoryBinary(t *testing.T) {
m.GetUnaligned(8, dest[:])
require.Equal(t, uint8(123), dest[0])
}

func TestMemoryInvalidSetUnaligned(t *testing.T) {
t.Run("SetUnaligned incorrectly writes to next page", func(t *testing.T) {
m := NewMemory()
m.SetUnaligned(0x0FFE, []byte{0xaa, 0xbb, 0xcc, 0xdd})
require.Equal(t, m.pages[0].Data[4094:], []byte{0xaa, 0xbb})
require.Equal(t, m.pages[1].Data[0:2], []byte{0xcc, 0xdd})
})
}
16 changes: 13 additions & 3 deletions rvgo/fast/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,13 +839,23 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
imm := parseImmTypeJ(instr)
rdValue := add64(pc, toU64(4))
setRegister(rd, rdValue)
setPC(add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2 bytes (last bit is there, but ignored)

newPC := add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))
if newPC&3 != 0 { // quick target alignment check
revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC))
}
setPC(newPC) // signed offset in multiples of 2 bytes (last bit is there, but ignored)
case 0x67: // 110_0111: JALR = Jump and link register
rs1Value := getRegister(rs1)
imm := parseImmTypeI(instr)
rdValue := add64(pc, toU64(4))
setRegister(rd, rdValue)
setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least significant bit is set to 0

newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))
if newPC&3 != 0 { // quick addr alignment check
revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC))
}
setPC(newPC) // least significant bit is set to 0
case 0x73: // 111_0011: environment things
switch funct3 {
case 0: // 000 = ECALL/EBREAK
Expand Down Expand Up @@ -873,7 +883,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) {
// 0b010 == RV32A W variants
// 0b011 == RV64A D variants
size := shl64(funct3, toU64(1))
if lt64(size, toU64(4)) != 0 {
if lt64(size, toU64(4)) != 0 || gt64(size, toU64(8)) != 0 {
revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size))
}
addr := getRegister(rs1)
Expand Down
27 changes: 24 additions & 3 deletions rvgo/slow/vm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package slow

import (
"bytes"
"encoding/binary"
"fmt"

Expand Down Expand Up @@ -121,6 +122,12 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err
return
}

// First 4 bytes of keccak256("step(bytes,bytes,bytes32)")
expectedSelector := []byte{0xe1, 0x4c, 0xed, 0x32}
if len(calldata) < 4 || !bytes.Equal(calldata[:4], expectedSelector) {
panic("invalid function selector")
}

stateContentOffset := uint8(4 + 32 + 32 + 32 + 32)
if iszero(eq(b32asBEWord(calldataload(toU64(4+32*3))), shortToU256(stateSize))) {
// user-provided state size must match expected state size
Expand Down Expand Up @@ -453,6 +460,10 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err
setMemoryB32(rightAddr, beWordAsB32(right), proofIndexR)
}
storeMem := func(addr U64, size U64, value U64, proofIndexL uint8, proofIndexR uint8) {
if size.val() > 8 {
revertWithCode(riscv.ErrStoreExceeds8Bytes, fmt.Errorf("cannot store more than 8 bytes: %d", size))
}

storeMemUnaligned(addr, size, u64ToU256(value), proofIndexL, proofIndexR)
}

Expand Down Expand Up @@ -1012,13 +1023,23 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err
imm := parseImmTypeJ(instr)
rdValue := add64(pc, toU64(4))
setRegister(rd, rdValue)
setPC(add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2 bytes (last bit is there, but ignored)

newPC := add64(pc, signExtend64(shl64(toU64(1), imm), toU64(20)))
if and64(newPC, toU64(3)) != (U64{}) { // quick target alignment check
revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC))
}
setPC(newPC) // signed offset in multiples of 2 bytes (last bit is there, but ignored)
case 0x67: // 110_0111: JALR = Jump and link register
rs1Value := getRegister(rs1)
imm := parseImmTypeI(instr)
rdValue := add64(pc, toU64(4))
setRegister(rd, rdValue)
setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least significant bit is set to 0

newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))
if and64(newPC, toU64(3)) != (U64{}) { // quick target alignment check
revertWithCode(riscv.ErrNotAlignedAddr, fmt.Errorf("pc %d not aligned with 4 bytes", newPC))
}
setPC(newPC) // least significant bit is set to 0
case 0x73: // 111_0011: environment things
switch funct3.val() {
case 0: // 000 = ECALL/EBREAK
Expand Down Expand Up @@ -1046,7 +1067,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err
// 0b010 == RV32A W variants
// 0b011 == RV64A D variants
size := shl64(funct3, toU64(1))
if lt64(size, toU64(4)) != (U64{}) {
if or64(lt64(size, toU64(4)), gt64(size, toU64(8))) != (U64{}) {
revertWithCode(riscv.ErrBadAMOSize, fmt.Errorf("bad AMO size: %d", size))
}
addr := getRegister(rs1)
Expand Down
4 changes: 2 additions & 2 deletions rvsol/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ forge test -vvv --ffi
- There are few issues with Foundry.
- Run script directly without manual build does not work with the current version of Foundry (2024-03-15 `3fa0270`).
You **must run** `make build` **before** running the deploy script. ([issue](https://github.com/foundry-rs/foundry/issues/6572))
- Some older version(2024-02-01 `2f4b5db`) of Foundry makes a dependency error reproted above issue.
- Some older version(2024-02-01 `2f4b5db`) of Foundry makes a dependency error reported above issue.
Use the **latest version** of Foundry!
- The deploy script can be run only once on the devnet because of the `create2` salt.
To rerun the script for dev purpose, you must restart the devnet with `make devnet-clean && make devnet-up` command on the monorepo.
To rerun the script for dev purpose, you must restart the devnet with `make devnet-clean && make devnet-up` command on the monorepo.
19 changes: 16 additions & 3 deletions rvsol/src/RISCV.sol
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ contract RISCV is IBigStepper {
}

function storeMem(addr, size, value, proofIndexL, proofIndexR) {
if gt(size, 8) { revertWithCode(0xbad512e8) } // cannot store more than 8 bytes

storeMemUnaligned(addr, size, u64ToU256(value), proofIndexL, proofIndexR)
}

Expand Down Expand Up @@ -1500,7 +1502,13 @@ contract RISCV is IBigStepper {
let imm := parseImmTypeJ(instr)
let rdValue := add64(_pc, toU64(4))
setRegister(rd, rdValue)
setPC(add64(_pc, signExtend64(shl64(toU64(1), imm), toU64(20)))) // signed offset in multiples of 2

let newPC := add64(_pc, signExtend64(shl64(toU64(1), imm), toU64(20)))
if and64(newPC, toU64(3)) {
// quick target alignment check
revertWithCode(0xbad10ad0) // target not aligned with 4 bytes
}
setPC(newPC) // signed offset in multiples of 2
// bytes (last bit is there, but ignored)
}
case 0x67 {
Expand All @@ -1509,8 +1517,13 @@ contract RISCV is IBigStepper {
let imm := parseImmTypeI(instr)
let rdValue := add64(_pc, toU64(4))
setRegister(rd, rdValue)
setPC(and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))) // least
// significant bit is set to 0

let newPC := and64(add64(rs1Value, signExtend64(imm, toU64(11))), xor64(u64Mask(), toU64(1)))
if and64(newPC, toU64(3)) {
// quick target alignment check
revertWithCode(0xbad10ad0) // target not aligned with 4 bytes
}
setPC(newPC) // least significant bit is set to 0
}
case 0x73 {
// 111_0011: environment things
Expand Down
14 changes: 13 additions & 1 deletion rvsol/test/RISCV.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2246,7 +2246,7 @@ contract RISCV_Test is CommonTest {
/* J Type instructions */

function test_jal_succeeds() public {
uint32 imm = 0xbef054ae;
uint32 imm = 0xbef054ac;
uint32 insn = encodeJType(0x6f, 5, imm); // jal x5, imm
(State memory state, bytes memory proof) = constructRISCVState(0, insn);
bytes memory encodedState = encodeState(state);
Expand Down Expand Up @@ -2472,6 +2472,18 @@ contract RISCV_Test is CommonTest {
vm.expectRevert(hex"00000000000000000000000000000000000000000000000000000000f001ca11");
riscv.step(encodedState, proof, 0);
}

function test_revert_unaligned_jal_instruction() public {
// 0xbef054ae % 4 != 0
uint32 imm = 0xbef054ae;
uint32 insn = encodeJType(0x6f, 5, imm); // jal x5, imm
(State memory state, bytes memory proof) = constructRISCVState(0, insn);
bytes memory encodedState = encodeState(state);

vm.expectRevert(hex"00000000000000000000000000000000000000000000000000000000bad10ad0");
riscv.step(encodedState, proof, 0);
}

/* Helper methods */

function encodeState(State memory state) internal pure returns (bytes memory) {
Expand Down

0 comments on commit 66d0362

Please sign in to comment.