Skip to content

Commit

Permalink
refactor(target_chains/starknet): remove Result from merkle_tree and …
Browse files Browse the repository at this point in the history
…pyth setters (#1548)

* refactor(target_chains/starknet): remove Result from merkle_tree

* refactor(target_chains/starknet): remove Result from pyth contract setters
  • Loading branch information
Riateche authored May 6, 2024
1 parent 55cbe62 commit 42b64ac
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 75 deletions.
26 changes: 15 additions & 11 deletions target_chains/starknet/contracts/src/merkle_tree.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::reader::{Reader, ReaderImpl};
use super::byte_array::ByteArray;
use super::util::ONE_SHIFT_96;
use core::cmp::{min, max};
use core::panic_with_felt252;

const MERKLE_LEAF_PREFIX: u8 = 0;
const MERKLE_NODE_PREFIX: u8 = 1;
Expand All @@ -14,6 +15,15 @@ pub enum MerkleVerificationError {
DigestMismatch,
}

impl MerkleVerificationErrorIntoFelt252 of Into<MerkleVerificationError, felt252> {
fn into(self: MerkleVerificationError) -> felt252 {
match self {
MerkleVerificationError::Reader(err) => err.into(),
MerkleVerificationError::DigestMismatch => 'digest mismatch',
}
}
}

#[generate_trait]
impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrait<T> {
fn map_err(self: Result<T, pyth::reader::Error>) -> Result<T, MerkleVerificationError> {
Expand All @@ -24,12 +34,11 @@ impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrai
}
}

fn leaf_hash(mut reader: Reader) -> Result<u256, super::reader::Error> {
fn leaf_hash(mut reader: Reader) -> u256 {
let mut hasher = HasherImpl::new();
hasher.push_u8(MERKLE_LEAF_PREFIX);
hasher.push_reader(ref reader);
let hash = hasher.finalize() / ONE_SHIFT_96;
Result::Ok(hash)
hasher.finalize() / ONE_SHIFT_96
}

fn node_hash(a: u256, b: u256) -> u256 {
Expand All @@ -40,25 +49,20 @@ fn node_hash(a: u256, b: u256) -> u256 {
hasher.finalize() / ONE_SHIFT_96
}

pub fn read_and_verify_proof(
root_digest: u256, message: @ByteArray, ref reader: Reader
) -> Result<(), MerkleVerificationError> {
pub fn read_and_verify_proof(root_digest: u256, message: @ByteArray, ref reader: Reader) {
let mut message_reader = ReaderImpl::new(message.clone());
let mut current_hash = leaf_hash(message_reader.clone()).map_err()?;
let mut current_hash = leaf_hash(message_reader.clone());

let proof_size = reader.read_u8();
let mut i = 0;

let mut result = Result::Ok(());
while i < proof_size {
let sibling_digest = reader.read_u160();
current_hash = node_hash(current_hash, sibling_digest);
i += 1;
};
result?;

if root_digest != current_hash {
return Result::Err(MerkleVerificationError::DigestMismatch);
panic_with_felt252(MerkleVerificationError::DigestMismatch.into());
}
Result::Ok(())
}
106 changes: 43 additions & 63 deletions target_chains/starknet/contracts/src/pyth.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ pub use pyth::{Event, PriceFeedUpdateEvent};
pub trait IPyth<T> {
fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
fn set_data_sources(
ref self: T, sources: Array<DataSource>
) -> Result<(), GovernanceActionError>;
fn set_fee(ref self: T, single_update_fee: u256) -> Result<(), GovernanceActionError>;
fn update_price_feeds(ref self: T, data: ByteArray) -> Result<(), UpdatePriceFeedsError>;
fn set_data_sources(ref self: T, sources: Array<DataSource>);
fn set_fee(ref self: T, single_update_fee: u256);
fn update_price_feeds(ref self: T, data: ByteArray);
}

#[derive(Copy, Drop, Debug, Serde, PartialEq)]
Expand Down Expand Up @@ -333,51 +331,44 @@ mod pyth {
Result::Ok(price)
}

fn set_data_sources(
ref self: ContractState, sources: Array<DataSource>
) -> Result<(), GovernanceActionError> {
fn set_data_sources(ref self: ContractState, sources: Array<DataSource>) {
if self.owner.read() != get_caller_address() {
return Result::Err(GovernanceActionError::AccessDenied);
panic_with_felt252(GovernanceActionError::AccessDenied.into());
}
write_data_sources(ref self, sources);
Result::Ok(())
}

fn set_fee(
ref self: ContractState, single_update_fee: u256
) -> Result<(), GovernanceActionError> {
fn set_fee(ref self: ContractState, single_update_fee: u256) {
if self.owner.read() != get_caller_address() {
return Result::Err(GovernanceActionError::AccessDenied);
panic_with_felt252(GovernanceActionError::AccessDenied.into());
}
self.single_update_fee.write(single_update_fee);
Result::Ok(())
}

fn update_price_feeds(
ref self: ContractState, data: ByteArray
) -> Result<(), UpdatePriceFeedsError> {
fn update_price_feeds(ref self: ContractState, data: ByteArray) {
let mut reader = ReaderImpl::new(data);
let x = reader.read_u32();
if x != ACCUMULATOR_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
if reader.read_u8() != MAJOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}

let trailing_header_size = reader.read_u8();
reader.skip(trailing_header_size);

let update_type: Option<UpdateType> = reader.read_u8().try_into();
let update_type: UpdateType = reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());

match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
UpdateType::WormholeMerkle => {}
}

let wh_proof_size = reader.read_u16();
let wh_proof = reader.read_byte_array(wh_proof_size.into());
Expand All @@ -388,22 +379,23 @@ mod pyth {
emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address
};
if !self.is_valid_data_source.read(source) {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateDataSource);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateDataSource.into());
}

let mut payload_reader = ReaderImpl::new(vm.payload);
let x = payload_reader.read_u32();
if x != ACCUMULATOR_WORMHOLE_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}

let update_type: Option<UpdateType> = payload_reader.read_u8().try_into();
let update_type: UpdateType = payload_reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());

match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
UpdateType::WormholeMerkle => {}
}

let _slot = payload_reader.read_u64();
let _ring_size = payload_reader.read_u32();
Expand All @@ -419,50 +411,39 @@ mod pyth {
let caller = execution_info.caller_address;
let contract = execution_info.contract_address;
if fee_contract.allowance(caller, contract) < total_fee {
return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
}
if !fee_contract.transferFrom(caller, contract, total_fee) {
return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
}

let mut i = 0;
let mut result = Result::Ok(());
while i < num_updates {
let r = read_and_verify_message(ref reader, root_digest);
match r {
Result::Ok(message) => { update_latest_price_if_necessary(ref self, message); },
Result::Err(err) => {
result = Result::Err(err);
break;
}
}
let message = read_and_verify_message(ref reader, root_digest);
update_latest_price_if_necessary(ref self, message);
i += 1;
};
result?;

if reader.len() != 0 {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}

Result::Ok(())
}
}

fn read_and_verify_message(
ref reader: Reader, root_digest: u256
) -> Result<PriceFeedMessage, UpdatePriceFeedsError> {
fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage {
let message_size = reader.read_u16();
let message = reader.read_byte_array(message_size.into());
read_and_verify_proof(root_digest, @message, ref reader).map_err()?;
read_and_verify_proof(root_digest, @message, ref reader);

let mut message_reader = ReaderImpl::new(message);
let message_type: Option<MessageType> = message_reader.read_u8().try_into();
let message_type: MessageType = message_reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());

match message_type {
Option::Some(v) => match v {
MessageType::PriceFeed => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
MessageType::PriceFeed => {}
}

let price_id = message_reader.read_u256();
let price = u64_as_i64(message_reader.read_u64());
Expand All @@ -473,10 +454,9 @@ mod pyth {
let ema_price = u64_as_i64(message_reader.read_u64());
let ema_conf = message_reader.read_u64();

let message = PriceFeedMessage {
PriceFeedMessage {
price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf,
};
Result::Ok(message)
}
}

fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) {
Expand Down
2 changes: 1 addition & 1 deletion target_chains/starknet/contracts/tests/pyth.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn update_price_feeds_works() {
let mut spy = spy_events(SpyOn::One(pyth.contract_address));

start_prank(CheatTarget::One(pyth.contract_address), user.try_into().unwrap());
pyth.update_price_feeds(good_update1()).unwrap_with_felt252();
pyth.update_price_feeds(good_update1());
stop_prank(CheatTarget::One(pyth.contract_address));

spy.fetch_events();
Expand Down

0 comments on commit 42b64ac

Please sign in to comment.