Boil bot lib down to essentials

master
Thor 3 years ago
parent 38426d176f
commit 906d044b8c
  1. 414
      bot.py

414
bot.py

@ -1,322 +1,196 @@
import os
import sys
import time
from datetime import datetime, timezone, timedelta
import copy
import json
import pprint
import threading
import traceback
import bogofilter
import html2text
from mastodon import Mastodon, MastodonNotFoundError
from mastodon import Mastodon
def log_print(source, text = ""):
prefix = "{}: ".format(source)
text = (prefix + text.strip()).replace("\n", "\n" + prefix)
print(text)
def log_pprint(source, obj):
log_print(source, pprint.pformat(obj))
def log_obj_str(obj):
if isinstance(obj, str):
return obj.strip()
else:
return pprint.pformat(obj).strip()
class Instance:
def __init__(self, name, config):
self.name = name
self.config = config
self.base_url = "https://{}".format(name)
self.client_file = "secret/{}.client".format(name)
self.user_file = "secret/{}.user".format(name)
self.spawner_thread = threading.Thread(
target = self.spawner,
name = self.name + " spawner",
args = (),
kwargs = {},
daemon = True)
class BotClient:
DEFAULT_STATE = {"min_status_id": "0"}
def __init__(self, bot, config):
self.bot = bot
self.config = {**{
"base_url": "https://{}".format(config["name"]),
"client_file": "secret/{}.client".format(config["name"]),
"user_file": "secret/{}.user".format(config["name"]),
"state_file": "state/{}.state".format(config["name"])}, **config}
self.state_lock = threading.Lock()
self.load_state()
self.tracker_thread = threading.Thread(
target = self.tracker,
name = self.name + " tracker",
self.poll_thread = threading.Thread(
target = self.poll_loop,
name = "{} Poll Loop".format(self.config["name"]),
args = (),
kwargs = {},
daemon = True)
def setup(self, app_name):
if not os.path.exists(self.client_file):
def log_str(self, obj, infix = str()):
return self.bot.log_str(obj, infix = "{}: {}".format(self.config["name"], infix))
def log(self, obj, infix = str()):
return self.bot.log(obj, infix = "{}: {}".format(self.config["name"], infix))
def setup(self):
if not os.path.exists(self.config["client_file"]):
Mastodon.create_app(
app_name,
api_base_url = self.base_url,
to_file = self.client_file)
self.app_name,
api_base_url = self.config["base_url"],
to_file = self.config["client_file"])
if not os.path.exists(self.user_file):
if not os.path.exists(self.config["user_file"]):
api = Mastodon(
api_base_url = self.base_url,
client_id = self.client_file)
api_base_url = self.config["base_url"],
client_id = self.config["client_file"])
auth_url = api.auth_request_url()
print("Go to:")
print(auth_url)
print()
self.log("Go to:")
self.log(auth_url)
self.log()
auth_code = input("Enter code: ")
auth_code = input(log_string("Enter code: "))
print()
api.log_in(code = auth_code, to_file = self.user_file)
self.log()
def start(self):
self.spawner_thread.start()
def spawner(self):
self.load_state()
api.log_in(code = auth_code, to_file = self.config["user_file"])
self.api = Mastodon(
access_token = self.user_file,
api_base_url = self.base_url)
self.tracker_thread.start()
self.purger_thread.start()
while True:
self.tracker_report()
time.sleep(60)
access_token = self.config["user_file"],
api_base_url = self.config["base_url"])
def tracker(self):
my_id = self.api.me()["id"]
def start(self):
self.poll_thread.start()
def poll_loop(self):
while True:
try:
self.state_lock.acquire()
self.state_lock.release()
statuses = self.api.account_statuses(my_id, min_id = self.state["min_id"])
statuses = self.api.timeline(min_id = self.state["min_status_id"])
h2t = html2text.HTML2Text()
h2t.ignore_links = True
if len(statuses) == 0:
self.on_poll()
time.sleep(self.config["poll_interval"])
else:
self.on_wake()
while len(statuses) > 0:
self.on_poll_page(statuses)
for status in sorted(statuses,
key = lambda status: status["created_at"]):
while not statuses is None and len(statuses) > 0:
log_print(self.name, "Found {} new status(es)".format(len(statuses)))
self.on_status(status)
for status in sorted(statuses,
key = lambda status: status["created_at"]):
with self.state_lock:
self.state["min_status_id"] = status["id"]
self.state["min_id"] = status["id"]
self.save_state()
md_text = h2t.handle(status["content"])
self.track_status(status)
self.save_state()
statuses = self.api.fetch_previous(statuses)
# Rate limit (max 300 requests per 5 minutes, i.e. 1 per second)
time.sleep(1)
except:
log_print(self.name, traceback.format_exc())
time.sleep(5)
def purger(self):
while True:
try:
deleted = False
timeslot_key, status_id = self.next_expired()
if not timeslot_key is None:
try:
log_print(self.name, "Deleting status {} in timeslot {}".format(status_id, timeslot_key))
self.api.status_delete(status_id)
deleted = True
except MastodonNotFoundError:
log_print(self.name,
"Cannot find status {} on server".format(status_id))
self.expire_status(timeslot_key, status_id)
if deleted:
time.sleep(60)
else:
time.sleep(1)
except:
log_print(self.name, traceback.format_exc())
time.sleep(60)
time.sleep(self.config["rate_limit"])
statuses = self.api.fetch_previous(statuses)
except Exception as exc:
self.save_state()
self.on_poll_exception(exc)
def load_state(self):
self.state_lock.acquire()
if not os.path.exists(self.state_file):
self.state = dict(
min_id = "0",
timeslots = {})
else:
with open(self.state_file) as json_file:
self.state = json.load(json_file)
self.state["timeslots"] = dict(map(lambda kv: (int(kv[0]), set(kv[1])), self.state["timeslots"]))
self.state_lock.release()
with self.state_lock:
self.state = self.on_load_state()
self.on_state_loaded(self.state)
def save_state(self):
self.state_lock.acquire()
json_state = self.state.copy()
json_state["timeslots"] = list(map(lambda kv: [kv[0], list(kv[1])], json_state["timeslots"].items()))
self.state_lock.release()
with self.state_lock:
self.on_save_state(copy.deepcopy(self.state))
self.on_state_saved(self.state)
with open(self.state_file, "w") as json_file:
json.dump(json_state, json_file, indent = 4)
def on_poll(self):
pass
def tracker_report(self):
self.state_lock.acquire()
def on_poll_exception(self, exc):
pass
total_timeslots = len(self.state["timeslots"])
total_statuses = 0
for timeslot_key, status_ids in self.state["timeslots"].items():
total_statuses += len(status_ids)
def on_wake(self):
pass
self.state_lock.release()
def on_status_page(self, statuses):
pass
log_print(self.name, "Tracking {} statuses across {} timeslots".format(
total_statuses, total_timeslots))
def on_status(self, status):
pass
def track_status(self, status):
status_id = str(status["id"])
timeslot_key = encode_time(status["created_at"])
def on_load_state(self):
if os.path.exists(self.config["state_file"]):
with open(self.config["state_file"]) as json_file:
return json.load(json_file)
self.state_lock.acquire()
if status["reblog"] is None:
timeslots = self.state["timeslots"]
if not timeslot_key in timeslots:
timeslots[timeslot_key] = set()
timeslots[timeslot_key].add(status_id)
self.state_lock.release()
def next_expired(self):
now = datetime.now(timezone.utc)
min_timeslot_key = encode_time(now - timedelta(minutes = config["max_age"]))
self.state_lock.acquire()
timeslot_key, status_ids = next(iter(self.state["timeslots"].items()), (None, None))
if not timeslot_key is None and timeslot_key < min_timeslot_key:
status_id = next(iter(status_ids), None)
else:
timeslot_key = None
status_id = None
self.state_lock.release()
return (timeslot_key, status_id)
def expire_status(self, timeslot_key, status_id):
self.state_lock.acquire()
timeslots = self.state["timeslots"]
if timeslot_key in timeslots:
if status_id in timeslots[timeslot_key]:
log_print(self.name, "Expiring status {} from timeslot {}".format(status_id, timeslot_key))
timeslots[timeslot_key].remove(status_id)
else:
log_print(self.name, "Cannot expire missing status {} from timeslot {}".format(
status_id, timeslot_key))
if len(timeslots[timeslot_key]) == 0:
log_print(self.name, "Removing empty timeslot {}".format(timeslot_key))
del timeslots[timeslot_key]
else:
log_print(self.name, "Cannot expire status {} from missing timeslot {}".format(
status_id, timeslot_key))
self.state_lock.release()
self.save_state()
def toot_dict_to_mail(toot_dict):
#log_pprint("toot_dict_to_mail", toot_dict)
flags = []
if toot_dict.get("sensitive", False):
flags.append("sensitive")
return copy.deepcopy(self.DEFAULT_STATE)
if toot_dict.get("poll", False):
flags.append("poll")
if toot_dict.get("reblog", False):
flags.append("reblog")
if toot_dict.get("reblogged", False):
flags.append("reblogged")
#if toot_dict.get("favourited", False):
# flags.append("favourited")
if toot_dict.get("bookmarked", False):
flags.append("bookmarked")
if toot_dict.get("pinned", False):
flags.append("pinned")
flags = ", ".join(flags)
headers = {}
def on_state_loaded(self, state):
pass
if toot_dict.get("account") and toot_dict["account"].get("acct"):
headers["From"] = toot_dict["account"]["acct"]
def on_save_state(self, state):
with open(self.config["state_file"], "w") as json_file:
json.dump(state, json_file, indent = 4)
if toot_dict.get("created_at"):
headers["Date"] = toot_dict["created_at"]
def on_state_saved(self, state):
pass
if toot_dict.get("visibility"):
headers["X-Visibility"] = toot_dict["visibility"]
class Bot:
DEFAULT_CONFIG = {
"app_name": "Generic Bot",
"rate_limit": 1,
"poll_interval": 10,
"clients": {
"mastodon.social": {}}}
if len(flags) > 0:
headers["X-Flags"] = flags
def __init__(self, client_type = BotClient, config = None):
self.clients = {}
if toot_dict.get("spoiler_text"):
headers["Subject"] = toot_dict["spoiler_text"]
self.client_type = client_type
self.config = config or self.DEFAULT_CONFIG
if toot_dict.get("replies_count", 0) > 0:
headers["X-Replies-Count"] = toot_dict["replies_count"]
def log_str(self, obj, infix = str()):
prefix = "{}: {}".format(self.config["tag"], infix)
return prefix + log_obj_str(obj).replace("\n", "\n" + prefix)
if len(toot_dict.get("media_attachments", [])) > 0:
headers["X-Attachments-Count"] = len(toot_dict["media_attachments"])
if toot_dict.get("reblogs_count", 0) > 0:
headers["X-Reblogs-Count"] = toot_dict["reblogs_count"]
if toot_dict.get("favourites_count", 0) > 0:
headers["X-Favourites-Count"] = toot_dict["favourites_count"]
if toot_dict.get("content") and len(toot_dict["content"]) > 0:
body = toot_dict["content"]
else:
body = None
return bogofilter.Mail(headers = headers, body = body)
learning = "-l" in sys.argv[1:]
with open("config.json") as json_file:
config = json.load(json_file)
def log(self, *args, **kwargs):
print(self.log_str(*args, **kwargs))
instances = {}
for name in config["instances"]:
instances[name] = Instance(name = name, config = config)
instances[name].setup()
start_interval = 60.0 / len(config["instances"])
for instance in instances.values():
instance.start()
time.sleep(start_interval)
while True:
time.sleep(1)
def start(self):
self.clients = self.on_start_clients(self.config["clients"])
self.on_clients_started(self.clients)
def on_start_clients(self, client_configs):
clients = {}
for client_name, client_config in client_configs.items():
client_config = {**{
"name": client_name,
"app_name": self.config["app_name"],
"rate_limit": self.config["rate_limit"],
"poll_interval": self.config["poll_interval"],
}, **client_config}
client = self.on_init_client(client_name, client_config)
client.setup()
clients[client_config["name"]] = client
start_interval = self.config["poll_interval"] / len(self.config["clients"])
for client in clients.values():
client.start()
time.sleep(start_interval)
return clients
def on_init_client(self, client_name, client_config):
return self.client_type(self, client_config)
def on_clients_started(self, clients):
pass

Loading…
Cancel
Save