Skip to content

Commit

Permalink
refactor: remove subscription and use internal interval (#31)
Browse files Browse the repository at this point in the history
Subscription is really a bad way to interact with the agent. This PR changes the messaging structure to send all prices as a batch update in an configured interval to minimize agent <> publisher interactions. The pythd has undergone a rewrite because there's no async jsonrpc client library in python which supports batch requests.
  • Loading branch information
ali-bahjati authored Sep 13, 2024
1 parent 0f95b2d commit 5733f74
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 211 deletions.
7 changes: 4 additions & 3 deletions config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# the configuration for the chosen engine as described below.
provider_engine = 'pyth_replicator'

product_update_interval_secs = 10
price_update_interval_secs = 1.0
product_update_interval_secs = 60
health_check_port = 8000

# The health check will return a failure status if no price data has been published within the specified time frame.
Expand All @@ -22,7 +23,7 @@ endpoint = 'ws://127.0.0.1:8910'
# coin_gecko_id = 'bitcoin'

[publisher.pyth_replicator]
http_endpoint = 'https://pythnet.rpcpool.com'
ws_endpoint = 'wss://pythnet.rpcpool.com'
http_endpoint = 'https://api2.pythnet.pyth.network'
ws_endpoint = 'wss://api2.pythnet.pyth.network'
first_mapping = 'AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J'
program_key = 'FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH'
2 changes: 1 addition & 1 deletion example_publisher/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_DEFAULT_CONFIG_PATH = os.path.join("config", "config.toml")


log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "DEBUG").upper()]
log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "INFO").upper()]
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))

log = structlog.get_logger()
Expand Down
1 change: 1 addition & 0 deletions example_publisher/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Config:
pythd: Pythd
health_check_port: int
health_check_threshold_secs: int
price_update_interval_secs: float = ts.option(default=1.0)
product_update_interval_secs: int = ts.option(default=60)
coin_gecko: Optional[CoinGeckoConfig] = ts.option(default=None)
pyth_replicator: Optional[PythReplicatorConfig] = ts.option(default=None)
4 changes: 2 additions & 2 deletions example_publisher/providers/pyth_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def _update_loop(self) -> None:
update.timestamp,
)

log.info(
log.debug(
"Received a price update", symbol=symbol, price=self._prices[symbol]
)

Expand All @@ -118,7 +118,7 @@ async def _update_accounts_loop(self) -> None:

await asyncio.sleep(self._config.account_update_interval_secs)

def upd_products(self, *args) -> None:
def upd_products(self, product_symbols: List[Symbol]) -> None:
# This provider stores all the possible feeds and
# does not care about the desired products as knowing
# them does not improve the performance of the replicator
Expand Down
103 changes: 47 additions & 56 deletions example_publisher/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from example_publisher.providers.coin_gecko import CoinGecko
from example_publisher.config import Config
from example_publisher.providers.pyth_replicator import PythReplicator
from example_publisher.pythd import Pythd, SubscriptionId
from example_publisher.pythd import PriceUpdate, Pythd, SubscriptionId


log = get_logger()
Expand Down Expand Up @@ -50,7 +50,6 @@ def __init__(self, config: Config) -> None:

self.pythd: Pythd = Pythd(
address=config.pythd.endpoint,
on_notify_price_sched=self.on_notify_price_sched,
)
self.subscriptions: Dict[SubscriptionId, Product] = {}
self.products: List[Product] = []
Expand All @@ -66,18 +65,17 @@ def is_healthy(self) -> bool:
async def start(self):
await self.pythd.connect()

self._product_update_task = asyncio.create_task(
self._start_product_update_loop()
)

async def _start_product_update_loop(self):
await self._upd_products()

self._product_update_task = asyncio.create_task(self._product_update_loop())
self._price_update_task = asyncio.create_task(self._price_update_loop())

self.provider.start()

async def _product_update_loop(self):
while True:
await self._upd_products()
await self._subscribe_notify_price_sched()
await asyncio.sleep(self.config.product_update_interval_secs)
await self._upd_products()

async def _upd_products(self):
log.debug("fetching product accounts from Pythd")
Expand Down Expand Up @@ -114,58 +112,51 @@ async def _upd_products(self):

self.provider.upd_products([product.symbol for product in self.products])

async def _subscribe_notify_price_sched(self):
# Subscribe to Pythd's notify_price_sched for each product that
# is not subscribed yet. Unfortunately there is no way to unsubscribe
# to the prices that are no longer available.
log.debug("subscribing to notify_price_sched")

subscriptions = {}
for product in self.products:
if not product.subscription_id:
subscription_id = await self.pythd.subscribe_price_sched(
product.price_account
async def _price_update_loop(self):
while True:
price_updates = []
for product in self.products:
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
continue

scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)

price_updates.append(
PriceUpdate(
account=product.price_account,
price=scaled_price,
conf=scaled_conf,
status=TRADING,
)
)
log.debug(
"sending price update",
symbol=product.symbol,
price_account=product.price_account,
price=price.price,
conf=price.conf,
scaled_price=scaled_price,
scaled_conf=scaled_conf,
)
product.subscription_id = subscription_id

subscriptions[product.subscription_id] = product

self.subscriptions = subscriptions

async def on_notify_price_sched(self, subscription: int) -> None:

log.debug("received notify_price_sched", subscription=subscription)
if subscription not in self.subscriptions:
return
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)

# Look up the current price and confidence interval of the product
product = self.subscriptions[subscription]
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
return
log.info(
"sending batch update_price",
num_price_updates=len(price_updates),
total_products=len(self.products),
)

# Scale the price and confidence interval using the Pyth exponent
scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)
await self.pythd.update_price_batch(price_updates)

# Send the price update
log.info(
"sending update_price",
product_account=product.product_account,
price_account=product.price_account,
price=scaled_price,
conf=scaled_conf,
symbol=product.symbol,
)
await self.pythd.update_price(
product.price_account, scaled_price, scaled_conf, TRADING
)
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)
await asyncio.sleep(self.config.price_update_interval_secs)

def apply_exponent(self, x: float, exp: int) -> int:
return int(x * (10 ** (-exp)))
125 changes: 78 additions & 47 deletions example_publisher/pythd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
from dataclasses import dataclass, field
import sys
import traceback
from dataclasses_json import config, DataClassJsonMixin
from typing import Callable, Coroutine, List
import json
from dataclasses_json import config, DataClassJsonMixin, dataclass_json
from dataclasses_json.undefined import Undefined
from typing import List, Any, Optional
from structlog import get_logger
from jsonrpc_websocket import Server
from websockets.client import connect, WebSocketClientProtocol
from asyncio import Lock

log = get_logger()

Expand All @@ -15,12 +15,22 @@
TRADING = "trading"


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Price(DataClassJsonMixin):
account: str
exponent: int = field(metadata=config(field_name="price_exponent"))


@dataclass
class PriceUpdate(DataClassJsonMixin):
account: str
price: int
conf: int
status: str


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Metadata(DataClassJsonMixin):
symbol: str
Expand All @@ -34,56 +44,77 @@ class Product(DataClassJsonMixin):
prices: List[Price] = field(metadata=config(field_name="price"))


@dataclass
class JSONRPCRequest(DataClassJsonMixin):
id: int
method: str
params: List[Any] | Any
jsonrpc: str = "2.0"


@dataclass
class JSONRPCResponse(DataClassJsonMixin):
id: int
result: Optional[Any] = None
error: Optional[Any] = None
jsonrpc: str = "2.0"


class Pythd:
def __init__(
self,
address: str,
on_notify_price_sched: Callable[[SubscriptionId], Coroutine[None, None, None]],
) -> None:
self.address = address
self.server: Server
self.on_notify_price_sched = on_notify_price_sched
self._tasks = set()
self.client: WebSocketClientProtocol
self.id_counter = 0
self.lock = Lock()

async def connect(self):
self.server = Server(self.address)
self.server.notify_price_sched = self._notify_price_sched
task = await self.server.ws_connect()
task.add_done_callback(Pythd._on_connection_done)
self._tasks.add(task)

@staticmethod
def _on_connection_done(task):
log.error("pythd connection closed")
if not task.cancelled() and task.exception() is not None:
e = task.exception()
traceback.print_exception(None, e, e.__traceback__)
sys.exit(1)

async def subscribe_price_sched(self, account: str) -> int:
subscription = (await self.server.subscribe_price_sched(account=account))[
"subscription"
]
log.debug(
"subscribed to price_sched", account=account, subscription=subscription
self.client = await connect(self.address)

def _create_request(self, method: str, params: List[Any] | Any) -> JSONRPCRequest:
self.id_counter += 1
return JSONRPCRequest(
id=self.id_counter,
method=method,
params=params,
)
return subscription

def _notify_price_sched(self, subscription: int) -> None:
log.debug("notify_price_sched RPC call received", subscription=subscription)
task = asyncio.get_event_loop().create_task(
self.on_notify_price_sched(subscription)
)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
async def send_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
# Using a lock will result in a synchronous execution of the send_request method
# and response retrieval which makes the code easier but is not good for performance.
# It is not recommended to use this behaviour where there are concurrent requests
# being made to the server.
async with self.lock:
await self.client.send(request.to_json())
response = await self.client.recv()
return JSONRPCResponse.from_json(response)

async def send_batch_request(
self, requests: List[JSONRPCRequest]
) -> List[JSONRPCResponse]:
async with self.lock:
await self.client.send(
json.dumps([request.to_dict() for request in requests])
)
response = await self.client.recv()
return JSONRPCResponse.schema().loads(response, many=True)

async def all_products(self) -> List[Product]:
result = await self.server.get_product_list()
return [Product.from_dict(d) for d in result]

async def update_price(
self, account: str, price: int, conf: int, status: str
) -> None:
await self.server.update_price(
account=account, price=price, conf=conf, status=status
)
request = self._create_request("get_product_list", [])
result = await self.send_request(request)
if result.result:
return Product.schema().load(result.result, many=True)
else:
raise ValueError(f"Error fetching products: {result.to_json()}")

async def update_price_batch(self, price_updates: List[PriceUpdate]) -> None:
requests = [
self._create_request("update_price", price_update.to_dict())
for price_update in price_updates
]
results = await self.send_batch_request(requests)
if any(result.error for result in results):
results_json_str = JSONRPCResponse.schema().dumps(results, many=True)
raise ValueError(f"Error updating prices: {results_json_str}")
Loading

0 comments on commit 5733f74

Please sign in to comment.