Compare commits
5 commits
a2247a4b2a
...
36a1867ed5
Author | SHA1 | Date | |
---|---|---|---|
36a1867ed5 | |||
f85d0a0f20 | |||
a2ad9ea428 | |||
ac07f3c3dc | |||
80a19c1e73 |
11 changed files with 242 additions and 158 deletions
23
Dockerfile
23
Dockerfile
|
@ -1,13 +1,30 @@
|
||||||
FROM python:3.10-alpine
|
FROM python:3.10-alpine as base
|
||||||
|
|
||||||
|
FROM base as builder
|
||||||
|
|
||||||
RUN apk update && apk add cmake olm make alpine-sdk
|
RUN apk update && apk add cmake olm make alpine-sdk
|
||||||
|
|
||||||
|
RUN mkdir /install
|
||||||
|
|
||||||
|
COPY requirements.txt /requirements.txt
|
||||||
|
|
||||||
|
RUN pip install --prefix=/install -r /requirements.txt
|
||||||
|
|
||||||
|
|
||||||
|
FROM base
|
||||||
|
|
||||||
|
RUN apk update && apk add olm
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY requirements.txt /app/requirements.txt
|
|
||||||
RUN pip install -r requirements.txt
|
STOPSIGNAL SIGINT
|
||||||
|
|
||||||
RUN mkdir /data
|
RUN mkdir /data
|
||||||
|
|
||||||
COPY matrix-invitation-dealer /app/matrix-invitation-dealer
|
COPY matrix-invitation-dealer /app/matrix-invitation-dealer
|
||||||
|
COPY sql /app/sql
|
||||||
COPY docker.env /app/.env
|
COPY docker.env /app/.env
|
||||||
|
|
||||||
CMD ["python3", "-m", "matrix-invitation-dealer"]
|
CMD ["python3", "-m", "matrix-invitation-dealer"]
|
||||||
|
|
|
@ -6,3 +6,19 @@ services:
|
||||||
network_mode: "host" # FIXME
|
network_mode: "host" # FIXME
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/data
|
- ./data:/data
|
||||||
|
environment:
|
||||||
|
# matrix credentials
|
||||||
|
MATRIX_HOMESERVER: "matrix.nolog.chat"
|
||||||
|
MATRIX_USER_ID: "@invite:nolog.chat"
|
||||||
|
MATRIX_USER_PASSWORD: "REDACTED"
|
||||||
|
|
||||||
|
# admin credentials
|
||||||
|
SYNAPSE_ADMIN_HOMESERVER: "http://127.0.0.1:8008"
|
||||||
|
SYNAPSE_ADMIN_ACCESS_TOKEN: "REDACTED"
|
||||||
|
|
||||||
|
# accept invites from users with the following suffix
|
||||||
|
USER_ID_SUFFIX: "nolog.chat"
|
||||||
|
|
||||||
|
# restrictions and quotas
|
||||||
|
USER_REQUIRED_AGE: "14d"
|
||||||
|
INVITE_CODE_QUOTA: "10/7d"
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .env import CREDENTIALS_FILE
|
||||||
from .client import create_bot
|
from .client import create_client
|
||||||
from .main import Bot
|
from .main import Bot
|
||||||
|
from .setup import matrix_account_setup
|
||||||
|
from .migrate import db_apply_migrations
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.WARNING)
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
@ -11,9 +15,20 @@ logging.getLogger(__package__).setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
client = await create_bot()
|
if not os.path.exists(CREDENTIALS_FILE):
|
||||||
bot = Bot(client)
|
await matrix_account_setup()
|
||||||
await bot.run()
|
else:
|
||||||
|
await db_apply_migrations() # create and update db
|
||||||
|
|
||||||
|
with open(CREDENTIALS_FILE) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
client = await create_client(config)
|
||||||
|
|
||||||
|
bot = Bot(client)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
await bot.run()
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(go())
|
asyncio.run(go())
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from .env import INVITE_CODE_EXPIRATION
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -46,6 +49,7 @@ class SynapseAdmin:
|
||||||
},
|
},
|
||||||
json={
|
json={
|
||||||
"uses_allowed": 1,
|
"uses_allowed": 1,
|
||||||
|
"expiry_time": int((time.time() + INVITE_CODE_EXPIRATION.total_seconds())*1000),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if not resp.ok:
|
if not resp.ok:
|
||||||
|
@ -54,9 +58,3 @@ class SynapseAdmin:
|
||||||
json = await resp.json()
|
json = await resp.json()
|
||||||
|
|
||||||
return json["token"]
|
return json["token"]
|
||||||
|
|
||||||
async def delete_token(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def get_token(self):
|
|
||||||
pass
|
|
||||||
|
|
|
@ -1,34 +1,28 @@
|
||||||
#!/usr/bin/env python3
|
from .env import STORE_PATH
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from .env import CREDENTIALS_FILE, STORE_PATH
|
|
||||||
from nio import AsyncClient, AsyncClientConfig
|
from nio import AsyncClient, AsyncClientConfig
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def create_bot() -> AsyncClient:
|
async def create_client(credentials) -> AsyncClient:
|
||||||
if not os.path.exists(CREDENTIALS_FILE):
|
cfg = AsyncClientConfig(
|
||||||
print("Please first run setup to create initial connection parameters and database")
|
encryption_enabled=True,
|
||||||
exit(1)
|
store_sync_tokens=True,
|
||||||
else:
|
store_name="test_store",
|
||||||
with open(CREDENTIALS_FILE, "r") as f:
|
)
|
||||||
config = json.load(f)
|
client = AsyncClient(
|
||||||
|
credentials["homeserver"], config=cfg, store_path=str(STORE_PATH)
|
||||||
|
)
|
||||||
|
client.access_token = credentials["access_token"]
|
||||||
|
|
||||||
cfg = AsyncClientConfig(
|
client.user_id = credentials["user_id"]
|
||||||
encryption_enabled=True,
|
|
||||||
store_sync_tokens=True,
|
|
||||||
store_name="test_store",
|
|
||||||
)
|
|
||||||
client = AsyncClient(config["homeserver"], config=cfg, store_path=STORE_PATH)
|
|
||||||
|
|
||||||
client.access_token = config["access_token"]
|
client.device_id = credentials["device_id"]
|
||||||
|
|
||||||
client.user_id = config["user_id"]
|
client.load_store()
|
||||||
|
if client.should_upload_keys:
|
||||||
client.device_id = config["device_id"]
|
await client.keys_upload()
|
||||||
|
|
||||||
client.download
|
|
||||||
client.load_store()
|
|
||||||
if client.should_upload_keys:
|
|
||||||
await client.keys_upload()
|
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
from .env import DATABASE_FILE
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
|
|
||||||
class CRUD:
|
|
||||||
def __init__(self, connection: aiosqlite.Connection):
|
|
||||||
self.connection = connection
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def at(cls, file: str):
|
|
||||||
conn = await aiosqlite.connect(file)
|
|
||||||
return cls(conn)
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import datetime
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Optional, TypeVar, Dict
|
from typing import Any, Callable, Optional, TypeVar, Dict
|
||||||
|
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
@ -16,10 +17,12 @@ def td_parse(v: str) -> datetime.timedelta:
|
||||||
match = re.match(r"^(?:(\d+)d)?\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?$", v)
|
match = re.match(r"^(?:(\d+)d)?\s*(?:(\d+)h)?\s*(?:(\d+)m)?\s*(?:(\d+)s)?$", v)
|
||||||
if match is None:
|
if match is None:
|
||||||
raise ValueError(f'Cannot parse "{v}" into timedelta')
|
raise ValueError(f'Cannot parse "{v}" into timedelta')
|
||||||
|
|
||||||
days = int(match.group(1) or 0)
|
days = int(match.group(1) or 0)
|
||||||
hours = int(match.group(2) or 0)
|
hours = int(match.group(2) or 0)
|
||||||
minutes = int(match.group(3) or 0)
|
minutes = int(match.group(3) or 0)
|
||||||
seconds = int(match.group(4) or 0)
|
seconds = int(match.group(4) or 0)
|
||||||
|
|
||||||
return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,24 +47,23 @@ if os.path.isfile(_ENV_FILE):
|
||||||
continue
|
continue
|
||||||
_DOTENV[split[0]] = "=".join(split[1:])
|
_DOTENV[split[0]] = "=".join(split[1:])
|
||||||
|
|
||||||
USER_REQUIRED_AGE: datetime.timedelta = env(
|
USER_REQUIRED_AGE = env(td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14))
|
||||||
td_parse, "USER_REQUIRED_AGE", datetime.timedelta(days=14)
|
|
||||||
)
|
|
||||||
|
|
||||||
SYNAPSE_ADMIN_ACCESS_TOKEN: str = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "")
|
SYNAPSE_ADMIN_ACCESS_TOKEN = env(str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "")
|
||||||
SYNAPSE_ADMIN_HOMESERVER: str = env(
|
SYNAPSE_ADMIN_HOMESERVER = env(str, "SYNAPSE_ADMIN_HOMESERVER", "http://127.0.0.1:8008")
|
||||||
str, "SYNAPSE_ADMIN_ACCESS_TOKEN", "http://127.0.0.1:8008"
|
|
||||||
)
|
|
||||||
|
|
||||||
DATABASE_FILE: str = env(str, "DATABASE_FILE", "data.sqlite")
|
DATABASE_FILE = env(Path, "DATABASE_FILE", "data.sqlite")
|
||||||
CREDENTIALS_FILE: str = env(str, "CREDENTIALS_FILE", "credentials.json")
|
CREDENTIALS_FILE = env(Path, "CREDENTIALS_FILE", "credentials.json")
|
||||||
STORE_PATH: str = env(str, "STORE_PATH", "store")
|
STORE_PATH = env(Path, "STORE_PATH", "store")
|
||||||
|
|
||||||
USER_ID_SUFFIX: str = env(str, "USER_ID_SUFFIX", "nolog.chat")
|
USER_ID_SUFFIX = env(str, "USER_ID_SUFFIX", "nolog.chat")
|
||||||
|
|
||||||
INVITE_CODE_QUOTA: str = env(str, "INVITE_CODE_QUOTA", "10/7d")
|
INVITE_CODE_QUOTA = env(str, "INVITE_CODE_QUOTA", "10/7d")
|
||||||
|
|
||||||
icq_amount, icq_timespan = INVITE_CODE_QUOTA.split("/")
|
icq_amount, icq_timespan = INVITE_CODE_QUOTA.split("/")
|
||||||
INVITE_CODE_QUOTA_AMOUNT: int = int(icq_amount)
|
INVITE_CODE_QUOTA_AMOUNT = int(icq_amount)
|
||||||
INVITE_CODE_QUOTA_TIMESPAN: datetime.timedelta = td_parse(icq_timespan)
|
INVITE_CODE_QUOTA_TIMESPAN = td_parse(icq_timespan)
|
||||||
|
|
||||||
|
INVITE_CODE_EXPIRATION = env(
|
||||||
|
td_parse, "INVITE_CODE_EXPIRATION", datetime.timedelta(days=7)
|
||||||
|
)
|
||||||
|
|
|
@ -37,14 +37,13 @@ class Bot:
|
||||||
if (
|
if (
|
||||||
event.membership == "invite"
|
event.membership == "invite"
|
||||||
and event.state_key == self.client.user_id # event about me
|
and event.state_key == self.client.user_id # event about me
|
||||||
and event.sender.endswith(':' + env.USER_ID_SUFFIX)
|
and event.sender.endswith(":" + env.USER_ID_SUFFIX)
|
||||||
and event.content.get('is_direct', False)
|
and event.content.get("is_direct", False)
|
||||||
):
|
):
|
||||||
# we've got a valid invite!
|
# we've got a valid invite!
|
||||||
logger.debug("joining DM of %s", event.sender)
|
logger.debug("joining DM of %s", event.sender)
|
||||||
await self.client.join(room.room_id)
|
await self.client.join(room.room_id)
|
||||||
elif event.membership == "invite" and event.state_key == self.client.user_id:
|
elif event.membership == "invite" and event.state_key == self.client.user_id:
|
||||||
print(event.content)
|
|
||||||
await self.client.room_leave(room.room_id)
|
await self.client.room_leave(room.room_id)
|
||||||
|
|
||||||
async def room_member_update_callback(
|
async def room_member_update_callback(
|
||||||
|
@ -60,42 +59,22 @@ class Bot:
|
||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
|
|
||||||
allowed = await self.user_allowed(user)
|
await self.send_hi_message(user, room)
|
||||||
|
|
||||||
if allowed is None:
|
|
||||||
await self.send_message(
|
|
||||||
room.room_id,
|
|
||||||
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
|
|
||||||
)
|
|
||||||
await self.leave(room)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not allowed:
|
|
||||||
await self.send_message(
|
|
||||||
room.room_id,
|
|
||||||
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
|
|
||||||
)
|
|
||||||
await self.leave(room)
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.send_message(
|
|
||||||
room.room_id,
|
|
||||||
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.membership == "leave" and room.joined_count == 1:
|
if event.membership == "leave" and room.joined_count == 1:
|
||||||
# leave rooms where we're alone
|
# leave rooms where we're alone
|
||||||
await self.leave(room)
|
await self.leave(room)
|
||||||
|
|
||||||
async def room_message_callback(self, room: MatrixRoom, event: RoomMessage):
|
async def room_message_callback(self, room: MatrixRoom, event: RoomMessage):
|
||||||
if type(room) is not MatrixRoom:
|
if type(room) is not MatrixRoom or event.sender == self.client.user_id:
|
||||||
return
|
|
||||||
|
|
||||||
if type(event) is not RoomMessageText or event.body != "!new":
|
|
||||||
return
|
return
|
||||||
|
|
||||||
user = event.sender
|
user = event.sender
|
||||||
|
|
||||||
|
if type(event) is not RoomMessageText or event.body != "!new":
|
||||||
|
await self.send_hi_message(user, room)
|
||||||
|
return
|
||||||
|
|
||||||
allowed = await self.user_allowed(user)
|
allowed = await self.user_allowed(user)
|
||||||
if allowed is None:
|
if allowed is None:
|
||||||
await self.send_message(
|
await self.send_message(
|
||||||
|
@ -130,6 +109,29 @@ class Bot:
|
||||||
formatted=f"<code>{token}</code>",
|
formatted=f"<code>{token}</code>",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def send_hi_message(self, user_id: str, room: MatrixRoom):
|
||||||
|
allowed = await self.user_allowed(user_id)
|
||||||
|
|
||||||
|
if allowed is None:
|
||||||
|
await self.send_message(
|
||||||
|
room.room_id,
|
||||||
|
plain="Hello! I couldn't fetch your account information. Sorry. You can try again later.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
await self.send_message(
|
||||||
|
room.room_id,
|
||||||
|
formatted="Hello! You can't create invites <i>just yet</i>. Feel free to message me in a few days to check again.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.send_message(
|
||||||
|
room.room_id,
|
||||||
|
formatted="Hello! <b>You are allowed to create invites</b>, hurray! You can generate a new invite by sending the <code>!new</code> command. I will respond with a single-use code that you can share.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async def user_allowed(self, user: str) -> Optional[bool]:
|
async def user_allowed(self, user: str) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Checks both that the user is from the homeserver we are managing, and
|
Checks both that the user is from the homeserver we are managing, and
|
||||||
|
@ -164,7 +166,6 @@ class Bot:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
not_me = list(filter(lambda u: u.user_id != self.client.user_id, users.members))
|
not_me = list(filter(lambda u: u.user_id != self.client.user_id, users.members))
|
||||||
print(room.users)
|
|
||||||
if len(not_me) > 1:
|
if len(not_me) > 1:
|
||||||
# This shouldn't really happen, since we're trying our best to stay out of
|
# This shouldn't really happen, since we're trying our best to stay out of
|
||||||
# rooms with multiple people.
|
# rooms with multiple people.
|
||||||
|
@ -213,20 +214,34 @@ class Bot:
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
self.admin_api = await SynapseAdmin.at(
|
self.admin_api = await SynapseAdmin.at(
|
||||||
env.SYNAPSE_ADMIN_HOMESERVER, env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token
|
env.SYNAPSE_ADMIN_HOMESERVER or self.client.homeserver,
|
||||||
|
env.SYNAPSE_ADMIN_ACCESS_TOKEN or self.client.access_token,
|
||||||
)
|
)
|
||||||
self.db = await aiosqlite.connect(env.DATABASE_FILE)
|
self.db = await aiosqlite.connect(env.DATABASE_FILE)
|
||||||
|
|
||||||
try:
|
while True:
|
||||||
while True:
|
future = asyncio.gather(self.client.sync_forever(30_000), self.main())
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(self.client.sync_forever(30_000), self.main())
|
await asyncio.shield(future)
|
||||||
except Exception:
|
except asyncio.CancelledError:
|
||||||
# TODO: better restart system
|
# When we are getting cancelled, we want to first finish writing to the DB,
|
||||||
logger.exception("Restarting")
|
# and then gracefully shutdown all connections
|
||||||
await asyncio.sleep(15)
|
logger.info("Gracefully shutting down")
|
||||||
finally:
|
|
||||||
await self.client.close()
|
|
||||||
await self.db.close()
|
|
||||||
await self.admin_api.session.close()
|
|
||||||
|
|
||||||
|
await self.db_lock.acquire()
|
||||||
|
|
||||||
|
try:
|
||||||
|
future.cancel()
|
||||||
|
await future
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self.client.close()
|
||||||
|
await self.db.close()
|
||||||
|
await self.admin_api.session.close()
|
||||||
|
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
# TODO: better restart system
|
||||||
|
logger.exception("Restarting")
|
||||||
|
await asyncio.sleep(15)
|
||||||
|
|
36
matrix-invitation-dealer/migrate.py
Normal file
36
matrix-invitation-dealer/migrate.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import glob
|
||||||
|
import aiosqlite
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .env import DATABASE_FILE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def db_apply_migrations():
|
||||||
|
if not os.path.exists(DATABASE_FILE):
|
||||||
|
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||||
|
script = Path("sql/init.sql").read_text()
|
||||||
|
await db.executescript(script)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||||
|
async with db.execute("SELECT filename FROM sch_updates") as cursor:
|
||||||
|
already_applied = [r[0] for r in (await cursor.fetchall())]
|
||||||
|
|
||||||
|
for filename in glob.glob("sql/update*"):
|
||||||
|
file = Path(filename)
|
||||||
|
if file.name not in already_applied:
|
||||||
|
try:
|
||||||
|
await db.executescript(file.read_text())
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO sch_updates (filename) VALUES (?)", file.name
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to migrate!")
|
||||||
|
await db.rollback()
|
||||||
|
exit(1)
|
||||||
|
else:
|
||||||
|
await db.commit()
|
|
@ -1,11 +1,12 @@
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import getpass
|
|
||||||
import aiosqlite
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
from nio import AsyncClient, AsyncClientConfig, LoginResponse
|
from nio import AsyncClient, AsyncClientConfig, LoginResponse
|
||||||
|
|
||||||
from .env import CREDENTIALS_FILE, DATABASE_FILE, STORE_PATH
|
from .env import CREDENTIALS_FILE, STORE_PATH, env
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
|
def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -24,57 +25,49 @@ def write_details_to_disk(resp: LoginResponse, homeserver) -> None:
|
||||||
f,
|
f,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(
|
|
||||||
"First time use. Did not find credential file. Asking for "
|
|
||||||
"homeserver, user, and password to create credential file."
|
|
||||||
)
|
|
||||||
|
|
||||||
homeserver = input(f"Enter your homeserver URL: ")
|
async def matrix_account_setup():
|
||||||
|
MATRIX_HOMESERVER = env(str, "MATRIX_HOMESERVER")
|
||||||
|
MATRIX_USER_ID = env(str, "MATRIX_USER_ID")
|
||||||
|
MATRIX_USER_PASSWORD = env(str, "MATRIX_USER_PASSWORD")
|
||||||
|
MATRIX_DEVICE_NAME = env(str, "MATRIX_DEVICE_NAME", "matrix-invitation-dealer")
|
||||||
|
|
||||||
|
homeserver = MATRIX_HOMESERVER
|
||||||
|
|
||||||
if not (homeserver.startswith("https://") or homeserver.startswith("http://")):
|
if not (homeserver.startswith("https://") or homeserver.startswith("http://")):
|
||||||
homeserver = "https://" + homeserver
|
homeserver = "https://" + MATRIX_HOMESERVER
|
||||||
|
|
||||||
user_id = input(f"Enter your full user ID: ")
|
cfg = AsyncClientConfig(
|
||||||
|
encryption_enabled=True,
|
||||||
|
store_sync_tokens=True,
|
||||||
|
store_name="account_store",
|
||||||
|
)
|
||||||
|
os.makedirs(STORE_PATH, exist_ok=True)
|
||||||
|
|
||||||
device_name = input(f"Choose a name for this device: [matrix-nio] ") or "matrix-nio"
|
client = AsyncClient(
|
||||||
|
homeserver, MATRIX_USER_ID, config=cfg, store_path=str(STORE_PATH)
|
||||||
|
)
|
||||||
|
|
||||||
cfg = AsyncClientConfig(
|
resp = await client.login(MATRIX_USER_PASSWORD, device_name=MATRIX_DEVICE_NAME)
|
||||||
encryption_enabled=True,
|
|
||||||
store_sync_tokens=True,
|
client.load_store()
|
||||||
store_name="test_store",
|
if client.should_upload_keys:
|
||||||
|
await client.keys_upload()
|
||||||
|
|
||||||
|
if isinstance(resp, LoginResponse):
|
||||||
|
write_details_to_disk(resp, homeserver)
|
||||||
|
|
||||||
|
assert client.olm is not None
|
||||||
|
key = client.olm.account.identity_keys["ed25519"]
|
||||||
|
|
||||||
|
logger.info("Logged in as %s. Please manually verify this session.")
|
||||||
|
logger.info(
|
||||||
|
"Session fingerprint %s",
|
||||||
|
" ".join([key[i : i + 4] for i in range(0, len(key), 4)]),
|
||||||
)
|
)
|
||||||
os.makedirs(STORE_PATH, exist_ok=True)
|
|
||||||
client = AsyncClient(homeserver, user_id, config=cfg, store_path=STORE_PATH)
|
|
||||||
|
|
||||||
pw = getpass.getpass()
|
else:
|
||||||
|
print(f'homeserver = "{homeserver}"; user = "{MATRIX_USER_ID}"')
|
||||||
|
|
||||||
resp = await client.login(pw, device_name=device_name)
|
print(f"Failed to log in: {resp}")
|
||||||
|
exit(1)
|
||||||
# check that we logged in successfully
|
|
||||||
|
|
||||||
if isinstance(resp, LoginResponse):
|
|
||||||
write_details_to_disk(resp, homeserver)
|
|
||||||
print(
|
|
||||||
"Logged in using a password. Credentials were stored.",
|
|
||||||
"Try running the script again to login with credentials.",
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f'homeserver = "{homeserver}"; user = "{user_id}"')
|
|
||||||
|
|
||||||
print(f"Failed to log in: {resp}")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
|
||||||
await db.executescript('''
|
|
||||||
CREATE TABLE IF NOT EXISTS tokens (
|
|
||||||
user TEXT NOT NULL,
|
|
||||||
token TEXT NOT NULL,
|
|
||||||
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
|
||||||
);
|
|
||||||
''')
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
10
sql/init.sql
Normal file
10
sql/init.sql
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
CREATE TABLE tokens (
|
||||||
|
user TEXT NOT NULL,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE sch_updates (
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
applied TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||||
|
);
|
Loading…
Reference in a new issue