add option to specify num of tokens in cmd; quota bypass list

This commit is contained in:
bain 2024-10-31 22:36:11 +01:00
parent 93aebbf77f
commit eaa83d5865
Signed by: bain
GPG key ID: 31F0F25E3BED0B9B
2 changed files with 46 additions and 24 deletions

View file

@ -67,3 +67,5 @@ INVITE_CODE_QUOTA_TIMESPAN = td_parse(icq_timespan)
INVITE_CODE_EXPIRATION = env( INVITE_CODE_EXPIRATION = env(
td_parse, "INVITE_CODE_EXPIRATION", datetime.timedelta(days=7) td_parse, "INVITE_CODE_EXPIRATION", datetime.timedelta(days=7)
) )
USERS_WITHOUT_QUOTA = env(lambda x: x.split(), "USERS_WITHOUT_QUOTA", [])

View file

@ -1,10 +1,12 @@
import asyncio
import datetime import datetime
import logging
import re import re
import time import time
import asyncio from collections import deque
from typing import Deque, Optional from typing import Deque, Optional
import aiosqlite
import aiosqlite
from nio import ( from nio import (
AsyncClient, AsyncClient,
InviteMemberEvent, InviteMemberEvent,
@ -14,8 +16,6 @@ from nio import (
RoomMessage, RoomMessage,
RoomMessageText, RoomMessageText,
) )
from collections import deque
import logging
from . import env from . import env
from .admin import SynapseAdmin from .admin import SynapseAdmin
@ -82,7 +82,10 @@ class Bot:
user = event.sender user = event.sender
if type(event) is not RoomMessageText or event.body != "!new": if (
type(event) is not RoomMessageText
or (command := event.body.split())[0] != "!new"
):
await self.send_hi_message(user, room) await self.send_hi_message(user, room)
return return
@ -101,23 +104,38 @@ class Bot:
) )
return return
if await self.quota_exceeded(user): num = 1
try:
if len(command) > 1:
num = int(command[1])
except ValueError:
pass
async with self.db_lock:
if (
user not in env.USERS_WITHOUT_QUOTA
and await self.remaining_quota(user) - num < 0
):
await self.send_message( await self.send_message(
room.room_id, room.room_id,
plain="Sorry, you can't create any more invites right now. Come back later.", plain="Sorry, you can't create any more invites right now. Come back later.",
) )
return return
token = await self.admin_api.create_token() tokens: list[str] = list(filter(lambda t: t, [await self.admin_api.create_token() for _ in range(num)])) # type: ignore
async with self.db_lock:
await self.db.execute( await self.db.executemany(
"INSERT INTO tokens (user, token) VALUES (?, ?);", (user, token) "INSERT INTO tokens (user, token) VALUES (?, ?);",
((user, token) for token in tokens),
) )
await self.db.commit() await self.db.commit()
await self.send_message( await self.send_message(
room.room_id, room.room_id,
formatted=f"<code>{token}</code>", formatted="<code>{}</code>".format('\n'.join(tokens)),
) )
async def send_hi_message(self, user_id: str, room: MatrixRoom): async def send_hi_message(self, user_id: str, room: MatrixRoom):
@ -158,15 +176,18 @@ class Bot:
>= env.USER_REQUIRED_AGE >= env.USER_REQUIRED_AGE
) )
async def quota_exceeded(self, user: str) -> bool: async def remaining_quota(self, user: str) -> int:
timespan = env.INVITE_CODE_QUOTA_TIMESPAN.total_seconds() timespan = env.INVITE_CODE_QUOTA_TIMESPAN.total_seconds()
async with self.db_lock:
async with self.db.execute( async with self.db.execute(
"SELECT count(token) AS amount FROM tokens WHERE unixepoch(CURRENT_TIMESTAMP)-unixepoch(created) < ? AND user = ?;", "SELECT count(token) AS amount FROM tokens WHERE unixepoch(CURRENT_TIMESTAMP)-unixepoch(created) < ? AND user = ?;",
(timespan, user), (timespan, user),
) as cursor: ) as cursor:
res = await cursor.fetchone() res = await cursor.fetchone()
return res is not None and res[0] >= env.INVITE_CODE_QUOTA_AMOUNT if res is None:
return 0
else:
return env.INVITE_CODE_QUOTA_AMOUNT - res[0]
async def require_dm_partner(self, room: MatrixRoom) -> Optional[str]: async def require_dm_partner(self, room: MatrixRoom) -> Optional[str]:
""" """
@ -246,7 +267,6 @@ class Bot:
# and then gracefully shutdown all connections # and then gracefully shutdown all connections
logger.info("Gracefully shutting down") logger.info("Gracefully shutting down")
await self.db_lock.acquire() await self.db_lock.acquire()
try: try: