-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
259 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
.github/ | ||
.ruff_cache/ | ||
.venv/ | ||
postgres_data/ | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
FROM python:3.10.13-alpine3.18 | ||
FROM python:3.12.2-alpine3.19 | ||
|
||
WORKDIR /app | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
version: "2" | ||
|
||
services: | ||
bot: | ||
build: . | ||
depends_on: | ||
- postgres | ||
environment: | ||
- POSTGRES_HOST=postgres | ||
- POSTGRES_PORT=5432 | ||
- POSTGRES_DB_NAME=blockbot | ||
restart: unless-stopped | ||
|
||
postgres: | ||
image: postgres:16.2-alpine3.19 | ||
environment: | ||
POSTGRES_DB: blockbot | ||
PGDATA: /var/lib/postgresql/data | ||
restart: unless-stopped | ||
volumes: | ||
- ./postgres_data:/var/lib/postgresql/data | ||
|
||
volumes: | ||
postgres_data: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
asyncpg==0.29.0 | ||
hikari==2.0.0.dev122 | ||
hikari-arc==1.1.0 | ||
ruff==0.2.0 | ||
pre-commit==3.6.0 | ||
python-dotenv==1.0.1 | ||
hikari-arc==1.2.1 | ||
pre-commit==3.6.2 | ||
python-dotenv==1.0.1 | ||
ruff==0.2.2 | ||
SQLAlchemy==2.0.27 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,23 @@ | ||
import os | ||
import sys | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
TOKEN = os.environ.get("TOKEN") # required | ||
def get_required_var(var: str) -> str: | ||
env = os.environ.get(var) | ||
if env is None: | ||
print(f"{var} environment variable not set. Exiting.") | ||
sys.exit(1) | ||
return env | ||
|
||
TOKEN = get_required_var("TOKEN") | ||
DEBUG = os.environ.get("DEBUG", False) | ||
POSTGRES_USER = get_required_var("POSTGRES_USER") | ||
POSTGRES_PASSWORD = get_required_var("POSTGRES_PASSWORD") | ||
POSTGRES_HOST = get_required_var("POSTGRES_HOST") | ||
POSTGRES_PORT = get_required_var("POSTGRES_PORT") | ||
POSTGRES_DB_NAME = get_required_var("POSTGRES_DB_NAME") | ||
|
||
CHANNEL_IDS: dict[str, int] = {"lobby": 627542044390457350} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from sqlalchemy import BigInteger, Column, Integer, SmallInteger | ||
from sqlalchemy.ext.asyncio import create_async_engine | ||
from sqlalchemy.orm import declarative_base | ||
|
||
from src.config import POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_PORT, POSTGRES_USER, POSTGRES_DB_NAME | ||
|
||
|
||
Base = declarative_base() | ||
|
||
# TODO: add reprs? | ||
|
||
class StarboardSettings(Base): | ||
__tablename__ = "starboard_settings" | ||
|
||
guild = Column(BigInteger, nullable=False, primary_key=True) | ||
channel = Column(BigInteger, nullable=True) | ||
threshold = Column(SmallInteger, nullable=False, default=3) | ||
|
||
|
||
class Starboard(Base): | ||
__tablename__ = "starboard" | ||
|
||
id = Column(Integer, nullable=False, primary_key=True, autoincrement=True) | ||
channel = Column(BigInteger, nullable=False) | ||
message = Column(BigInteger, nullable=False) | ||
stars = Column(SmallInteger, nullable=False) | ||
starboard_channel = Column(BigInteger, nullable=False) | ||
starboard_message = Column(BigInteger, nullable=False) | ||
starboard_stars = Column(SmallInteger, nullable=False) | ||
|
||
|
||
engine = create_async_engine( | ||
f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB_NAME}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
import arc | ||
import hikari | ||
from sqlalchemy import delete, insert, select, update | ||
from sqlalchemy.ext.asyncio import AsyncEngine | ||
|
||
from src.database import Starboard, StarboardSettings | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
plugin = arc.GatewayPlugin("Starboard") | ||
|
||
@plugin.listen() | ||
@plugin.inject_dependencies | ||
async def on_reaction( | ||
event: hikari.GuildReactionAddEvent, | ||
session: AsyncEngine = arc.inject(), | ||
) -> None: | ||
logger.info("Received guild reaction add event") | ||
|
||
if event.emoji_name != "⭐": | ||
return | ||
|
||
message = await plugin.client.rest.fetch_message(event.channel_id, event.message_id) | ||
star_count = sum(r.emoji == "⭐" for r in message.reactions) | ||
|
||
stmt = select(StarboardSettings).where(StarboardSettings.guild == event.guild_id) | ||
async with session.connect() as conn: | ||
result = await conn.execute(stmt) | ||
|
||
settings = result.first() | ||
|
||
# TODO: remove temporary logging and merge into one if statement | ||
if not settings: | ||
logger.info("Received star but no guild starboard set") | ||
return | ||
if star_count < settings.threshold: | ||
logger.info("Not enough stars to post to starboard") | ||
return | ||
if not settings.channel: | ||
logger.info("No starboard channel set") | ||
return | ||
|
||
async with session.connect() as conn: | ||
stmt = select(Starboard).where(Starboard.message == event.message_id) | ||
result = await conn.execute(stmt) | ||
starboard = result.first() | ||
|
||
logger.info(starboard) | ||
|
||
if not starboard: | ||
stmt = select(Starboard).where(Starboard.starboard_message == event.message_id) | ||
result = await conn.execute(stmt) | ||
starboard = result.first() | ||
|
||
logger.info(starboard) | ||
|
||
embed = hikari.Embed(description=f"⭐ {star_count}\n[link]({message.make_link(event.guild_id)})") | ||
|
||
# TODO: handle starring the starboard message | ||
# i.e. don't create a starboard message for the starboard message | ||
|
||
if not starboard: | ||
try: | ||
logger.info("Creating message") | ||
message = await plugin.client.rest.create_message( | ||
settings.channel, | ||
embed, | ||
) | ||
stmt = insert(Starboard).values( | ||
channel=event.channel_id, | ||
message=event.message_id, | ||
stars=star_count, | ||
starboard_channel=settings.channel, | ||
starboard_message=message.id, | ||
starboard_stars=0, | ||
) | ||
|
||
async with session.begin() as conn: | ||
await conn.execute(stmt) | ||
await conn.commit() | ||
except hikari.ForbiddenError: | ||
logger.info("Can't access starboard channel") | ||
stmt = update(StarboardSettings).where(StarboardSettings.guild == event.guild_id).values( | ||
channel=None) | ||
|
||
async with session.begin() as conn: | ||
await conn.execute(stmt) | ||
await conn.commit() | ||
|
||
else: | ||
try: | ||
logger.info("Editing message") | ||
await plugin.client.rest.edit_message( | ||
starboard.starboard_channel, | ||
starboard.starboard_message, | ||
embed | ||
) | ||
except hikari.ForbiddenError: | ||
logger.info("Can't edit starboard message") | ||
stmt = delete(StarboardSettings).where(StarboardSettings.guild == event.guild_id) | ||
|
||
async with session.begin() as conn: | ||
await conn.execute(stmt) | ||
await conn.commit() | ||
|
||
@plugin.include | ||
@arc.slash_command("starboard", "Edit or view starboard settings.", default_permissions=hikari.Permissions.MANAGE_GUILD) | ||
async def starboard_settings( | ||
ctx: arc.GatewayContext, | ||
channel: arc.Option[hikari.TextableGuildChannel | None, arc.ChannelParams("The channel to post starboard messages to.")] = None, | ||
threshold: arc.Option[int | None, arc.IntParams("The minimum number of stars before this message is posted to the starboard", min=1)] = None, | ||
session: AsyncEngine = arc.inject(), | ||
) -> None: | ||
assert ctx.guild_id | ||
|
||
stmt = select(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id) | ||
async with session.connect() as conn: | ||
result = await conn.execute(stmt) | ||
|
||
settings = result.first() | ||
|
||
if not channel and not threshold: | ||
if not settings: | ||
await ctx.respond("This server has no starboard settings.", flags=hikari.MessageFlag.EPHEMERAL) | ||
else: | ||
# TODO: `channel` and `threshold` can be None | ||
embed = hikari.Embed( | ||
title="Starboard Settings", | ||
description=( | ||
f"**Channel:** <#{settings.channel}>\n" | ||
f"**Threshold:** {settings.threshold}" | ||
), | ||
) | ||
await ctx.respond(embed) | ||
|
||
return | ||
|
||
if not settings: | ||
stmt = insert(StarboardSettings).values(guild=ctx.guild_id) | ||
else: | ||
stmt = update(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id) | ||
|
||
# TODO: simplify logic | ||
if channel and threshold: | ||
stmt = stmt.values(channel=channel.id, threshold=threshold) | ||
elif channel: | ||
stmt = stmt.values(channel=channel.id) | ||
elif threshold: | ||
stmt = stmt.values(threshold=threshold) | ||
|
||
async with session.begin() as conn: | ||
await conn.execute(stmt) | ||
await conn.commit() | ||
|
||
# TODO: respond with embed of new settings? | ||
await ctx.respond("Settings updated.", flags=hikari.MessageFlag.EPHEMERAL) | ||
|
||
@arc.loader | ||
def loader(client: arc.GatewayClient) -> None: | ||
client.add_plugin(plugin) |