Compare commits
5 commits
a2247a4b2a
...
36a1867ed5
Author | SHA1 | Date | |
---|---|---|---|
36a1867ed5 | |||
f85d0a0f20 | |||
a2ad9ea428 | |||
ac07f3c3dc | |||
80a19c1e73 |
11 changed files with 242 additions and 158 deletions
23
Dockerfile
23
Dockerfile
|
@ -1,13 +1,30 @@
|
|||
FROM python:3.10-alpine
|
||||
FROM python:3.10-alpine as base
|
||||
|
||||
FROM base as builder
|
||||
|
||||
RUN apk update && apk add cmake olm make alpine-sdk
|
||||
|
||||
RUN mkdir /install
|
||||
|
||||
COPY requirements.txt /requirements.txt
|
||||
|
||||
RUN pip install --prefix=/install -r /requirements.txt
|
||||
|
||||
|
||||
FROM base
|
||||
|
||||
RUN apk update && apk add olm
|
||||
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
WORKDIR /app
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
STOPSIGNAL SIGINT
|
||||
|
||||
RUN mkdir /data
|
||||
|
||||
COPY matrix-invitation-dealer /app/matrix-invitation-dealer
|
||||
COPY sql /app/sql
|
||||
COPY docker.env /app/.env
|
||||
|
||||
CMD ["python3", "-m", "matrix-invitation-dealer"]
|
||||
|
|
|
@ -6,3 +6,19 @@ services:
|
|||
network_mode: "host" # FIXME
|
||||
volumes:
|
||||
- ./data:/data
|
||||
environment:
|
||||
# matrix credentials
|
||||
MATRIX_HOMESERVER: "matrix.nolog.chat"
|
||||
MATRIX_USER_ID: "@invite:nolog.chat"
|
||||
MATRIX_USER_PASSWORD: "REDACTED"
|
||||
|
||||
# admin credentials
|
||||
SYNAPSE_ADMIN_HOMESERVER: "http://127.0.0.1:8008"
|
||||
SYNAPSE_ADMIN_ACCESS_TOKEN: "REDACTED"
|
||||
|
||||
# accept invites from users with the following suffix
|
||||
USER_ID_SUFFIX: "nolog.chat"
|
||||
|
||||
# restrictions and quotas
|
||||
USER_REQUIRED_AGE: "14d"
|
||||
INVITE_CODE_QUOTA: "10/7d"
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
from .client import create_bot
|
||||
from .env import CREDENTIALS_FILE
|
||||
from .client import create_client
|
||||
from .main import Bot
|
||||
from .setup import matrix_account_setup
|
||||
from .migrate import db_apply_migrations
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
@ -11,9 +15,20 @@ logging.getLogger(__package__).setLevel(logging.DEBUG)
|
|||
|
||||
|
||||
async def go():
|
||||
client = await create_bot()
|
||||
bot = Bot(client)
|
||||
await bot.run()
|
||||
if not os.path.exists(CREDENTIALS_FILE):
|
||||
await matrix_account_setup()
|
||||
else:
|
||||
await db_apply_migrations() # create and update db
|
||||
|
||||
with open(CREDENTIALS_FILE) as f:
|
||||
config = json.load(f)
|
||||
client = await create_client(config)
|
||||
|
||||
bot = Bot(client)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
await bot.run()
|
||||
|
||||
|
||||
asyncio.run(go())
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from typing import Optional
|
||||
from aiohttp import ClientSession
|
||||
import logging
|
||||
import time
|
||||
|
||||
from .env import INVITE_CODE_EXPIRATION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,6 +49,7 @@ class SynapseAdmin:
|
|||
},
|
||||
json={
|
||||
"uses_allowed": 1,
|
||||
"expiry_time": int((time.time() + INVITE_CODE_EXPIRATION.total_seconds())*1000),
|
||||
},
|
||||
)
|
||||
if not resp.ok:
|
||||
|
@ -54,9 +58,3 @@ class SynapseAdmin:
|
|||
json = await resp.json()
|
||||
|
||||
return json["token"]
|
||||
|
||||
async def delete_token(self):
|
||||
pass
|
||||
|
||||
async def get_token(self):
|
||||
pass
|
||||
|
|
|
@ -1,34 +1,28 @@
|
|||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
from .env import CREDENTIALS_FILE, STORE_PATH
|
||||
from .env import STORE_PATH
|
||||
|
||||
from nio import AsyncClient, AsyncClientConfig
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_bot() -> AsyncClient:
|
||||
if not os.path.exists(CREDENTIALS_FILE):
|
||||
print("Please first run setup to create initial connection parameters and database")
|
||||
exit(1)
|
||||
else:
|
||||
with open(CREDENTIALS_FILE, "r") as f:
|
||||
config = json.load(f)
|
||||
async def create_client(credentials) -> AsyncClient:
|
||||
cfg = AsyncClientConfig(
|
||||
encryption_enabled=True,
|
||||
store_sync_tokens=True,
|
||||
store_name="test_store",
|
||||
)
|
||||
client = AsyncClient(
|
||||
credentials["homeserver"], config=cfg, store_path=str(STORE_PATH)
|
||||
)
|
||||
client.access_token = credentials["access_token"]
|
||||
|
||||
cfg = AsyncClientConfig(
|
||||
encryption_enabled=True,
|
||||
store_sync_tokens=True,
|
||||
store_name="test_store",
|
||||
)
|
||||
client = AsyncClient(config["homeserver"], config=cfg, store_path=STORE_PATH)
|
||||
client.user_id = credentials["user_id"]
|
||||
|
||||
client.access_token = config["access_token"]
|
||||
client.device_id = credentials["device_id"]
|
||||
|
||||
client.user_id = config["user_id"]
|
||||
|
||||
client.device_id = config["device_id"]
|
||||
|
||||
client.download
|
||||
client.load_store()
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
client.load_store()
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
|
||||
return client
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
from .env import DATABASE_FILE
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class CRUD:
|
||||
def __init__(self, connection: aiosqlite.Connection):
|
||||
self.connection = connection
|
||||
|
||||
@classmethod
|
||||
async def at(cls, file: str):
|
||||
conn = await aiosqlite.connect(file)
|
||||
return cls(conn)
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import re
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, TypeVar, Dict
|
||||
|
||||
R = TypeVar("R")
|
||||
|
@ -16,10 +17,12 @@ def td_parse(v: str) -> datetime.timedelta:
|
|||
match = re.match(r"^(?:(\d+)d)?\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?$", v)
|
||||
if match is None:
|
||||
raise ValueError(f'Cannot parse "{v}" into timedelta')
|
||||
|
||||
days = int(match.group(1) or 0)
|
||||
hours = int(match.group(2) or 0)
|
||||
minutes = int(match.group(3) or 0)
|
||||
seconds = int(match.group(4) or 0)
|
||||
|
||||
return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
||||
|
||||
|
||||
|
@ -44,24 +47,23 @@ if os.path.isfile(_ENV_FILE):
|
|||
continue
|
||||
_DOTENV[split[0]] = "=".join(split[1:])
|
||||
|
||||
USER_REQUIRED_AGE: datetime.timedelta = env(
|
||||
td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14)
|
||||
)
|
||||
USER_REQUIRED_AGE = env(td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14))
|
||||
|
||||
SYNAPSE_ADMIN_ACCESS_TOKEN: str = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "")
|
||||
SYNAPSE_ADMIN_HOMESERVER: str = env(
|
||||
str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "http://127.0.0.1:8008"
|
||||
)
|
||||
SYNAPSE_ADMIN_ACCESS_TOKEN = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "")
|
||||
SYNAPSE_ADMIN_HOMESERVER = env(str, "SYNAPSE_ADMIN_HOMESERVER", "http://127.0.0.1:8008")
|
||||
|
||||
DATABASE_FILE: str = env(str, "DATABASE_FILE", "data.sqlite")
|
||||
CREDENTIALS_FILE: str = env(str, "CREDENTIALS_FILE", "credentials.json")
|
||||
STORE_PATH: str = env(str, "STORE_PATH", "store")
|
||||
DATABASE_FILE = env(Path, "DATABASE_FILE", "data.sqlite")
|
||||
CREDENTIALS_FILE = env(Path, "CREDENTIALS_FILE", "credentials.json")
|
||||
STORE_PATH = env(Path, "STORE_PATH", "store")
|
||||
|
||||
USER_ID_SUFFIX: str = env(str, "USER_ID_SUFFIX", "nolog.chat")
|
||||
USER_ID_SUFFIX = env(str, "USER_ID_SUFFIX", "nolog.chat")
|
||||
|
||||
INVITE_CODE_QUOTA: str = env(str, "INVITE_CODE_QUOTA", "10/7d")
|
||||
INVITE_CODE_QUOTA = env(str, "INVITE_CODE_QUOTA", "10/7d")
|
||||
|
||||
icq_amount, icq_timespan = INVITE_CODE_QUOTA.split("/")
|
||||
INVITE_CODE_QUOTA_AMOUNT: int = int(icq_amount)
|
||||
INVITE_CODE_QUOTA_TIMESPAN: datetime.timedelta = td_parse(icq_timespan)
|
||||
INVITE_CODE_QUOTA_AMOUNT = int(icq_amount)
|
||||
INVITE_CODE_QUOTA_TIMESPAN = td_parse(icq_timespan)
|
||||
|
||||
INVITE_CODE_EXPIRATION = env(
|
||||
td_parse, "INVITE_CODE_EXPIRATION", datetime.timedelta(days=7)
|
||||
)
|
||||
|
|
|
@ -37,14 +37,13 @@ class Bot:
|
|||
if (
|
||||
event.membership == "invite"
|
||||
and event.state_key == self.client.user_id # event about me
|
||||
and event.sender.endswith(':' + env.USER_ID_SUFFIX)
|
||||
and event.content.get('is_direct', False)
|
||||
and event.sender.endswith(":" + env.USER_ID_SUFFIX)
|
||||
and event.content.get("is_direct", False)
|
||||
):
|
||||
# we've got a valid invite!
|
||||
logger.debug("joining DM of %s", event.sender)
|
||||
await self.client.join(room.room_id)
|
||||
elif event.membership == "invite" and event.state_key == self.client.user_id:
|
||||
print(event.content)
|
||||
await self.client.room_leave(room.room_id)
|
||||
|
||||
async def room_member_update_callback(
|
||||
|
@ -60,42 +59,22 @@ class Bot:
|
|||
if not user:
|
||||
return
|
||||
|
||||
allowed = await self.user_allowed(user)
|
||||
|
||||
if allowed is None:
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
|
||||
)
|
||||
await self.leave(room)
|
||||
return
|
||||
|
||||
if not allowed:
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
|
||||
)
|
||||
await self.leave(room)
|
||||
return
|
||||
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
|
||||
)
|
||||
await self.send_hi_message(user, room)
|
||||
|
||||
if event.membership == "leave" and room.joined_count == 1:
|
||||
# leave rooms where we're alone
|
||||
await self.leave(room)
|
||||
|
||||
async def room_message_callback(self, room: MatrixRoom, event: RoomMessage):
|
||||
if type(room) is not MatrixRoom:
|
||||
return
|
||||
|
||||
if type(event) is not RoomMessageText or event.body != "!new":
|
||||
if type(room) is not MatrixRoom or event.sender == self.client.user_id:
|
||||
return
|
||||
|
||||
user = event.sender
|
||||
|
||||
if type(event) is not RoomMessageText or event.body != "!new":
|
||||
await self.send_hi_message(user, room)
|
||||
return
|
||||
|
||||
allowed = await self.user_allowed(user)
|
||||
if allowed is None:
|
||||
await self.send_message(
|
||||
|
@ -130,6 +109,29 @@ class Bot:
|
|||
formatted=f"<code>{token}</code>",
|
||||
)
|
||||
|
||||
async def send_hi_message(self, user_id: str, room: MatrixRoom):
|
||||
allowed = await self.user_allowed(user_id)
|
||||
|
||||
if allowed is None:
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
|
||||
)
|
||||
return
|
||||
|
||||
if not allowed:
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
|
||||
)
|
||||
return
|
||||
|
||||
await self.send_message(
|
||||
room.room_id,
|
||||
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
|
||||
)
|
||||
return
|
||||
|
||||
async def user_allowed(self, user: str) -> Optional[bool]:
|
||||
"""
|
||||
Checks both that the user is from the homeserver we are managing, and
|
||||
|
@ -164,7 +166,6 @@ class Bot:
|
|||
return None
|
||||
|
||||
not_me = list(filter(lambda u: u.user_id != self.client.user_id, users.members))
|
||||
print(room.users)
|
||||
if len(not_me) > 1:
|
||||
# This shouldn't really happen, since we're trying our best to stay out of
|
||||
# rooms with multiple people.
|
||||
|
@ -213,20 +214,34 @@ class Bot:
|
|||
|
||||
async def run(self):
|
||||
self.admin_api = await SynapseAdmin.at(
|
||||
env.SYNAPSE_ADMIN_HOMESERVER, env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token
|
||||
env.SYNAPSE_ADMIN_HOMESERVER or self.client.homeserver,
|
||||
env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token,
|
||||
)
|
||||
self.db = await aiosqlite.connect(env.DATABASE_FILE)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.gather(self.client.sync_forever(30_000), self.main())
|
||||
except Exception:
|
||||
# TODO: better restart system
|
||||
logger.exception("Restarting")
|
||||
await asyncio.sleep(15)
|
||||
finally:
|
||||
await self.client.close()
|
||||
await self.db.close()
|
||||
await self.admin_api.session.close()
|
||||
while True:
|
||||
future = asyncio.gather(self.client.sync_forever(30_000), self.main())
|
||||
try:
|
||||
await asyncio.shield(future)
|
||||
except asyncio.CancelledError:
|
||||
# When we are getting cancelled, we want to first finish writing to the DB,
|
||||
# and then gracefully shutdown all connections
|
||||
logger.info("Gracefully shutting down")
|
||||
|
||||
await self.db_lock.acquire()
|
||||
|
||||
try:
|
||||
future.cancel()
|
||||
await future
|
||||
except:
|
||||
pass
|
||||
|
||||
await self.client.close()
|
||||
await self.db.close()
|
||||
await self.admin_api.session.close()
|
||||
|
||||
return
|
||||
except Exception:
|
||||
# TODO: better restart system
|
||||
logger.exception("Restarting")
|
||||
await asyncio.sleep(15)
|
||||
|
|
36
matrix-invitation-dealer/migrate.py
Normal file
36
matrix-invitation-dealer/migrate.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from pathlib import Path
|
||||
import glob
|
||||
import aiosqlite
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .env import DATABASE_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def db_apply_migrations():
|
||||
if not os.path.exists(DATABASE_FILE):
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
script = Path("sql/init.sql").read_text()
|
||||
await db.executescript(script)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
async with db.execute("SELECT filename FROM sch_updates") as cursor:
|
||||
already_applied = [r[0] for r in (await cursor.fetchall())]
|
||||
|
||||
for filename in glob.glob("sql/update*"):
|
||||
file = Path(filename)
|
||||
if file.name not in already_applied:
|
||||
try:
|
||||
await db.executescript(file.read_text())
|
||||
await db.execute(
|
||||
"INSERT INTO sch_updates (filename) VALUES (?)", file.name
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to migrate!")
|
||||
await db.rollback()
|
||||
exit(1)
|
||||
else:
|
||||
await db.commit()
|
|
@ -1,11 +1,12 @@
|
|||
import asyncio
|
||||
import json
|
||||
import getpass
|
||||
import aiosqlite
|
||||
import os
|
||||
import logging
|
||||
from nio import AsyncClient, AsyncClientConfig, LoginResponse
|
||||
|
||||
from .env import CREDENTIALS_FILE, DATABASE_FILE, STORE_PATH
|
||||
from .env import CREDENTIALS_FILE, STORE_PATH, env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
|
||||
"""
|
||||
|
@ -24,57 +25,49 @@ def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
|
|||
f,
|
||||
)
|
||||
|
||||
async def main():
|
||||
print(
|
||||
"First time use. Did not find credential file. Asking for "
|
||||
"homeserver, user, and password to create credential file."
|
||||
)
|
||||
|
||||
homeserver = input(f"Enter your homeserver URL: ")
|
||||
async def matrix_account_setup():
|
||||
MATRIX_HOMESERVER = env(str, "MATRIX_HOMESERVER")
|
||||
MATRIX_USER_ID = env(str, "MATRIX_USER_ID")
|
||||
MATRIX_USER_PASSWORD = env(str, "MATRIX_USER_PASSWORD")
|
||||
MATRIX_DEVICE_NAME = env(str, "MATRIX_DEVICE_NAME", "matrix-invitation-dealer")
|
||||
|
||||
homeserver = MATRIX_HOMESERVER
|
||||
|
||||
if not (homeserver.startswith("https://") or homeserver.startswith("http://")):
|
||||
homeserver = "https://" + homeserver
|
||||
homeserver = "https://" + MATRIX_HOMESERVER
|
||||
|
||||
user_id = input(f"Enter your full user ID: ")
|
||||
cfg = AsyncClientConfig(
|
||||
encryption_enabled=True,
|
||||
store_sync_tokens=True,
|
||||
store_name="account_store",
|
||||
)
|
||||
os.makedirs(STORE_PATH, exist_ok=True)
|
||||
|
||||
device_name = input(f"Choose a name for this device: [matrix-nio] ") or "matrix-nio"
|
||||
client = AsyncClient(
|
||||
homeserver, MATRIX_USER_ID, config=cfg, store_path=str(STORE_PATH)
|
||||
)
|
||||
|
||||
cfg = AsyncClientConfig(
|
||||
encryption_enabled=True,
|
||||
store_sync_tokens=True,
|
||||
store_name="test_store",
|
||||
resp = await client.login(MATRIX_USER_PASSWORD, device_name=MATRIX_DEVICE_NAME)
|
||||
|
||||
client.load_store()
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
|
||||
if isinstance(resp, LoginResponse):
|
||||
write_details_to_disk(resp, homeserver)
|
||||
|
||||
assert client.olm is not None
|
||||
key = client.olm.account.identity_keys["ed25519"]
|
||||
|
||||
logger.info("Logged in as %s. Please manually verify this session.")
|
||||
logger.info(
|
||||
"Session fingerprint %s",
|
||||
" ".join([key[i : i + 4] for i in range(0, len(key), 4)]),
|
||||
)
|
||||
os.makedirs(STORE_PATH, exist_ok=True)
|
||||
client = AsyncClient(homeserver, user_id, config=cfg, store_path=STORE_PATH)
|
||||
|
||||
pw = getpass.getpass()
|
||||
else:
|
||||
print(f'homeserver = "{homeserver}"; user = "{MATRIX_USER_ID}"')
|
||||
|
||||
resp = await client.login(pw, device_name=device_name)
|
||||
|
||||
# check that we logged in successfully
|
||||
|
||||
if isinstance(resp, LoginResponse):
|
||||
write_details_to_disk(resp, homeserver)
|
||||
print(
|
||||
"Logged in using a password. Credentials were stored.",
|
||||
"Try running the script again to login with credentials.",
|
||||
)
|
||||
|
||||
else:
|
||||
print(f'homeserver = "{homeserver}"; user = "{user_id}"')
|
||||
|
||||
print(f"Failed to log in: {resp}")
|
||||
exit(1)
|
||||
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
await db.executescript('''
|
||||
CREATE TABLE IF NOT EXISTS tokens (
|
||||
user TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
''')
|
||||
|
||||
await db.commit()
|
||||
|
||||
asyncio.run(main())
|
||||
print(f"Failed to log in: {resp}")
|
||||
exit(1)
|
||||
|
|
10
sql/init.sql
Normal file
10
sql/init.sql
Normal file
|
@ -0,0 +1,10 @@
|
|||
CREATE TABLE tokens (
|
||||
user TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE sch_updates (
|
||||
filename TEXT NOT NULL,
|
||||
applied TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
Loading…
Reference in a new issue