Compare commits

...

5 commits

Author SHA1 Message Date
36a1867ed5
fix bot replying to its own message 2023-10-13 21:14:39 +02:00
f85d0a0f20
reformat; add expiration time to tokens 2023-10-13 21:14:36 +02:00
a2ad9ea428
update docker configuration 2023-10-13 20:43:38 +02:00
ac07f3c3dc
add DB migrations 2023-10-13 20:43:35 +02:00
80a19c1e73
refactoring 2023-10-13 18:24:09 +02:00
11 changed files with 242 additions and 158 deletions

View file

@ -1,13 +1,30 @@
FROM python:3.10-alpine FROM python:3.10-alpine as base
FROM base as builder
RUN apk update && apk add cmake olm make alpine-sdk RUN apk update && apk add cmake olm make alpine-sdk
RUN mkdir /install
COPY requirements.txt /requirements.txt
RUN pip install --prefix=/install -r /requirements.txt
FROM base
RUN apk update && apk add olm
COPY --from=builder /install /usr/local
WORKDIR /app WORKDIR /app
COPY requirements.txt /app/requirements.txt
RUN pip install -r requirements.txt STOPSIGNAL SIGINT
RUN mkdir /data RUN mkdir /data
COPY matrix-invitation-dealer /app/matrix-invitation-dealer COPY matrix-invitation-dealer /app/matrix-invitation-dealer
COPY sql /app/sql
COPY docker.env /app/.env COPY docker.env /app/.env
CMD ["python3", "-m", "matrix-invitation-dealer"] CMD ["python3", "-m", "matrix-invitation-dealer"]

View file

@ -6,3 +6,19 @@ services:
network_mode: "host" # FIXME network_mode: "host" # FIXME
volumes: volumes:
- ./data:/data - ./data:/data
environment:
# matrix credentials
MATRIX_HOMESERVER: "matrix.nolog.chat"
MATRIX_USER_ID: "@invite:nolog.chat"
MATRIX_USER_PASSWORD: "REDACTED"
# admin credentials
SYNAPSE_ADMIN_HOMESERVER: "http://127.0.0.1:8008"
SYNAPSE_ADMIN_ACCESS_TOKEN: "REDACTED"
# accept invites from users with the following suffix
USER_ID_SUFFIX: "nolog.chat"
# restrictions and quotas
USER_REQUIRED_AGE: "14d"
INVITE_CODE_QUOTA: "10/7d"

View file

@ -1,9 +1,13 @@
import asyncio import asyncio
import logging import logging
import os
import json
from .env import CREDENTIALS_FILE
from .client import create_bot from .client import create_client
from .main import Bot from .main import Bot
from .setup import matrix_account_setup
from .migrate import db_apply_migrations
logging.basicConfig(level=logging.WARNING) logging.basicConfig(level=logging.WARNING)
@ -11,9 +15,20 @@ logging.getLogger(__package__).setLevel(logging.DEBUG)
async def go(): async def go():
client = await create_bot() if not os.path.exists(CREDENTIALS_FILE):
bot = Bot(client) await matrix_account_setup()
await bot.run() else:
await db_apply_migrations() # create and update db
with open(CREDENTIALS_FILE) as f:
config = json.load(f)
client = await create_client(config)
bot = Bot(client)
loop = asyncio.get_running_loop()
await bot.run()
asyncio.run(go()) asyncio.run(go())

View file

@ -1,6 +1,9 @@
from typing import Optional from typing import Optional
from aiohttp import ClientSession from aiohttp import ClientSession
import logging import logging
import time
from .env import INVITE_CODE_EXPIRATION
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,6 +49,7 @@ class SynapseAdmin:
}, },
json={ json={
"uses_allowed": 1, "uses_allowed": 1,
"expiry_time": int((time.time() + INVITE_CODE_EXPIRATION.total_seconds())*1000),
}, },
) )
if not resp.ok: if not resp.ok:
@ -54,9 +58,3 @@ class SynapseAdmin:
json = await resp.json() json = await resp.json()
return json["token"] return json["token"]
async def delete_token(self):
pass
async def get_token(self):
pass

View file

@ -1,34 +1,28 @@
#!/usr/bin/env python3 from .env import STORE_PATH
import json
import os
from .env import CREDENTIALS_FILE, STORE_PATH
from nio import AsyncClient, AsyncClientConfig from nio import AsyncClient, AsyncClientConfig
import logging
logger = logging.getLogger(__name__)
async def create_bot() -> AsyncClient: async def create_client(credentials) -> AsyncClient:
if not os.path.exists(CREDENTIALS_FILE): cfg = AsyncClientConfig(
print("Please first run setup to create initial connection parameters and database") encryption_enabled=True,
exit(1) store_sync_tokens=True,
else: store_name="test_store",
with open(CREDENTIALS_FILE, "r") as f: )
config = json.load(f) client = AsyncClient(
credentials["homeserver"], config=cfg, store_path=str(STORE_PATH)
)
client.access_token = credentials["access_token"]
cfg = AsyncClientConfig( client.user_id = credentials["user_id"]
encryption_enabled=True,
store_sync_tokens=True,
store_name="test_store",
)
client = AsyncClient(config["homeserver"], config=cfg, store_path=STORE_PATH)
client.access_token = config["access_token"] client.device_id = credentials["device_id"]
client.user_id = config["user_id"] client.load_store()
if client.should_upload_keys:
client.device_id = config["device_id"] await client.keys_upload()
client.download
client.load_store()
if client.should_upload_keys:
await client.keys_upload()
return client return client

View file

@ -1,12 +0,0 @@
from .env import DATABASE_FILE
import aiosqlite
class CRUD:
def __init__(self, connection: aiosqlite.Connection):
self.connection = connection
@classmethod
async def at(cls, file: str):
conn = await aiosqlite.connect(file)
return cls(conn)

View file

@ -1,6 +1,7 @@
import os import os
import re import re
import datetime import datetime
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Dict from typing import Any, Callable, Optional, TypeVar, Dict
R = TypeVar("R") R = TypeVar("R")
@ -16,10 +17,12 @@ def td_parse(v: str) -> datetime.timedelta:
match = re.match(r"^(?:(\d+)d)?\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?$", v) match = re.match(r"^(?:(\d+)d)?\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?$", v)
if match is None: if match is None:
raise ValueError(f'Cannot parse "{v}" into timedelta') raise ValueError(f'Cannot parse "{v}" into timedelta')
days = int(match.group(1) or 0) days = int(match.group(1) or 0)
hours = int(match.group(2) or 0) hours = int(match.group(2) or 0)
minutes = int(match.group(3) or 0) minutes = int(match.group(3) or 0)
seconds = int(match.group(4) or 0) seconds = int(match.group(4) or 0)
return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
@ -44,24 +47,23 @@ if os.path.isfile(_ENV_FILE):
continue continue
_DOTENV[split[0]] = "=".join(split[1:]) _DOTENV[split[0]] = "=".join(split[1:])
USER_REQUIRED_AGE: datetime.timedelta = env( USER_REQUIRED_AGE = env(td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14))
td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14)
)
SYNAPSE_ADMIN_ACCESS_TOKEN: str = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "") SYNAPSE_ADMIN_ACCESS_TOKEN = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "")
SYNAPSE_ADMIN_HOMESERVER: str = env( SYNAPSE_ADMIN_HOMESERVER = env(str, "SYNAPSE_ADMIN_HOMESERVER", "http://127.0.0.1:8008")
str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "http://127.0.0.1:8008"
)
DATABASE_FILE: str = env(str, "DATABASE_FILE", "data.sqlite") DATABASE_FILE = env(Path, "DATABASE_FILE", "data.sqlite")
CREDENTIALS_FILE: str = env(str, "CREDENTIALS_FILE", "credentials.json") CREDENTIALS_FILE = env(Path, "CREDENTIALS_FILE", "credentials.json")
STORE_PATH: str = env(str, "STORE_PATH", "store") STORE_PATH = env(Path, "STORE_PATH", "store")
USER_ID_SUFFIX: str = env(str, "USER_ID_SUFFIX", "nolog.chat") USER_ID_SUFFIX = env(str, "USER_ID_SUFFIX", "nolog.chat")
INVITE_CODE_QUOTA: str = env(str, "INVITE_CODE_QUOTA", "10/7d") INVITE_CODE_QUOTA = env(str, "INVITE_CODE_QUOTA", "10/7d")
icq_amount, icq_timespan = INVITE_CODE_QUOTA.split("/") icq_amount, icq_timespan = INVITE_CODE_QUOTA.split("/")
INVITE_CODE_QUOTA_AMOUNT: int = int(icq_amount) INVITE_CODE_QUOTA_AMOUNT = int(icq_amount)
INVITE_CODE_QUOTA_TIMESPAN: datetime.timedelta = td_parse(icq_timespan) INVITE_CODE_QUOTA_TIMESPAN = td_parse(icq_timespan)
INVITE_CODE_EXPIRATION = env(
td_parse, "INVITE_CODE_EXPIRATION", datetime.timedelta(days=7)
)

View file

@ -37,14 +37,13 @@ class Bot:
if ( if (
event.membership == "invite" event.membership == "invite"
and event.state_key == self.client.user_id # event about me and event.state_key == self.client.user_id # event about me
and event.sender.endswith(':' + env.USER_ID_SUFFIX) and event.sender.endswith(":" + env.USER_ID_SUFFIX)
and event.content.get('is_direct', False) and event.content.get("is_direct", False)
): ):
# we've got a valid invite! # we've got a valid invite!
logger.debug("joining DM of %s", event.sender) logger.debug("joining DM of %s", event.sender)
await self.client.join(room.room_id) await self.client.join(room.room_id)
elif event.membership == "invite" and event.state_key == self.client.user_id: elif event.membership == "invite" and event.state_key == self.client.user_id:
print(event.content)
await self.client.room_leave(room.room_id) await self.client.room_leave(room.room_id)
async def room_member_update_callback( async def room_member_update_callback(
@ -60,42 +59,22 @@ class Bot:
if not user: if not user:
return return
allowed = await self.user_allowed(user) await self.send_hi_message(user, room)
if allowed is None:
await self.send_message(
room.room_id,
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
)
await self.leave(room)
return
if not allowed:
await self.send_message(
room.room_id,
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
)
await self.leave(room)
return
await self.send_message(
room.room_id,
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
)
if event.membership == "leave" and room.joined_count == 1: if event.membership == "leave" and room.joined_count == 1:
# leave rooms where we're alone # leave rooms where we're alone
await self.leave(room) await self.leave(room)
async def room_message_callback(self, room: MatrixRoom, event: RoomMessage): async def room_message_callback(self, room: MatrixRoom, event: RoomMessage):
if type(room) is not MatrixRoom: if type(room) is not MatrixRoom or event.sender == self.client.user_id:
return
if type(event) is not RoomMessageText or event.body != "!new":
return return
user = event.sender user = event.sender
if type(event) is not RoomMessageText or event.body != "!new":
await self.send_hi_message(user, room)
return
allowed = await self.user_allowed(user) allowed = await self.user_allowed(user)
if allowed is None: if allowed is None:
await self.send_message( await self.send_message(
@ -130,6 +109,29 @@ class Bot:
formatted=f"<code>{token}</code>", formatted=f"<code>{token}</code>",
) )
async def send_hi_message(self, user_id: str, room: MatrixRoom):
allowed = await self.user_allowed(user_id)
if allowed is None:
await self.send_message(
room.room_id,
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
)
return
if not allowed:
await self.send_message(
room.room_id,
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
)
return
await self.send_message(
room.room_id,
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
)
return
async def user_allowed(self, user: str) -> Optional[bool]: async def user_allowed(self, user: str) -> Optional[bool]:
""" """
Checks both that the user is from the homeserver we are managing, and Checks both that the user is from the homeserver we are managing, and
@ -164,7 +166,6 @@ class Bot:
return None return None
not_me = list(filter(lambda u: u.user_id != self.client.user_id, users.members)) not_me = list(filter(lambda u: u.user_id != self.client.user_id, users.members))
print(room.users)
if len(not_me) > 1: if len(not_me) > 1:
# This shouldn't really happen, since we're trying our best to stay out of # This shouldn't really happen, since we're trying our best to stay out of
# rooms with multiple people. # rooms with multiple people.
@ -213,20 +214,34 @@ class Bot:
async def run(self): async def run(self):
self.admin_api = await SynapseAdmin.at( self.admin_api = await SynapseAdmin.at(
env.SYNAPSE_ADMIN_HOMESERVER, env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token env.SYNAPSE_ADMIN_HOMESERVER or self.client.homeserver,
env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token,
) )
self.db = await aiosqlite.connect(env.DATABASE_FILE) self.db = await aiosqlite.connect(env.DATABASE_FILE)
try: while True:
while True: future = asyncio.gather(self.client.sync_forever(30_000), self.main())
try: try:
await asyncio.gather(self.client.sync_forever(30_000), self.main()) await asyncio.shield(future)
except Exception: except asyncio.CancelledError:
# TODO: better restart system # When we are getting cancelled, we want to first finish writing to the DB,
logger.exception("Restarting") # and then gracefully shutdown all connections
await asyncio.sleep(15) logger.info("Gracefully shutting down")
finally:
await self.client.close()
await self.db.close()
await self.admin_api.session.close()
await self.db_lock.acquire()
try:
future.cancel()
await future
except:
pass
await self.client.close()
await self.db.close()
await self.admin_api.session.close()
return
except Exception:
# TODO: better restart system
logger.exception("Restarting")
await asyncio.sleep(15)

View file

@ -0,0 +1,36 @@
from pathlib import Path
import glob
import aiosqlite
import os
import logging
from .env import DATABASE_FILE
logger = logging.getLogger(__name__)
async def db_apply_migrations():
if not os.path.exists(DATABASE_FILE):
async with aiosqlite.connect(DATABASE_FILE) as db:
script = Path("sql/init.sql").read_text()
await db.executescript(script)
await db.commit()
async with aiosqlite.connect(DATABASE_FILE) as db:
async with db.execute("SELECT filename FROM sch_updates") as cursor:
already_applied = [r[0] for r in (await cursor.fetchall())]
for filename in glob.glob("sql/update*"):
file = Path(filename)
if file.name not in already_applied:
try:
await db.executescript(file.read_text())
await db.execute(
"INSERT INTO sch_updates (filename) VALUES (?)", file.name
)
except Exception:
logger.exception("Failed to migrate!")
await db.rollback()
exit(1)
else:
await db.commit()

View file

@ -1,11 +1,12 @@
import asyncio
import json import json
import getpass
import aiosqlite
import os import os
import logging
from nio import AsyncClient, AsyncClientConfig, LoginResponse from nio import AsyncClient, AsyncClientConfig, LoginResponse
from .env import CREDENTIALS_FILE, DATABASE_FILE, STORE_PATH from .env import CREDENTIALS_FILE, STORE_PATH, env
logger = logging.getLogger(__name__)
def write_details_to_disk(resp: LoginResponse, homeserver) -> None: def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
""" """
@ -24,57 +25,49 @@ def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
f, f,
) )
async def main():
print(
"First time use. Did not find credential file. Asking for "
"homeserver, user, and password to create credential file."
)
homeserver = input(f"Enter your homeserver URL: ") async def matrix_account_setup():
MATRIX_HOMESERVER = env(str, "MATRIX_HOMESERVER")
MATRIX_USER_ID = env(str, "MATRIX_USER_ID")
MATRIX_USER_PASSWORD = env(str, "MATRIX_USER_PASSWORD")
MATRIX_DEVICE_NAME = env(str, "MATRIX_DEVICE_NAME", "matrix-invitation-dealer")
homeserver = MATRIX_HOMESERVER
if not (homeserver.startswith("https://") or homeserver.startswith("http://")): if not (homeserver.startswith("https://") or homeserver.startswith("http://")):
homeserver = "https://" + homeserver homeserver = "https://" + MATRIX_HOMESERVER
user_id = input(f"Enter your full user ID: ") cfg = AsyncClientConfig(
encryption_enabled=True,
store_sync_tokens=True,
store_name="account_store",
)
os.makedirs(STORE_PATH, exist_ok=True)
device_name = input(f"Choose a name for this device: [matrix-nio] ") or "matrix-nio" client = AsyncClient(
homeserver, MATRIX_USER_ID, config=cfg, store_path=str(STORE_PATH)
)
cfg = AsyncClientConfig( resp = await client.login(MATRIX_USER_PASSWORD, device_name=MATRIX_DEVICE_NAME)
encryption_enabled=True,
store_sync_tokens=True, client.load_store()
store_name="test_store", if client.should_upload_keys:
await client.keys_upload()
if isinstance(resp, LoginResponse):
write_details_to_disk(resp, homeserver)
assert client.olm is not None
key = client.olm.account.identity_keys["ed25519"]
logger.info("Logged in as %s. Please manually verify this session.")
logger.info(
"Session fingerprint %s",
" ".join([key[i : i + 4] for i in range(0, len(key), 4)]),
) )
os.makedirs(STORE_PATH, exist_ok=True)
client = AsyncClient(homeserver, user_id, config=cfg, store_path=STORE_PATH)
pw = getpass.getpass() else:
print(f'homeserver = "{homeserver}"; user = "{MATRIX_USER_ID}"')
resp = await client.login(pw, device_name=device_name) print(f"Failed to log in: {resp}")
exit(1)
# check that we logged in successfully
if isinstance(resp, LoginResponse):
write_details_to_disk(resp, homeserver)
print(
"Logged in using a password. Credentials were stored.",
"Try running the script again to login with credentials.",
)
else:
print(f'homeserver = "{homeserver}"; user = "{user_id}"')
print(f"Failed to log in: {resp}")
exit(1)
async with aiosqlite.connect(DATABASE_FILE) as db:
await db.executescript('''
CREATE TABLE IF NOT EXISTS tokens (
user TEXT NOT NULL,
token TEXT NOT NULL,
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);
''')
await db.commit()
asyncio.run(main())

10
sql/init.sql Normal file
View file

@ -0,0 +1,10 @@
CREATE TABLE tokens (
user TEXT NOT NULL,
token TEXT NOT NULL,
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);
CREATE TABLE sch_updates (
filename TEXT NOT NULL,
applied TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);