From 220b1e3b64825a0cc1ac5ebca48a07dd5bccc01f Mon Sep 17 00:00:00 2001 From: Thor Harald Johansen Date: Wed, 14 Jul 2021 15:00:38 +0200 Subject: [PATCH] Refactored cringebot.py to use bot.py --- cringebot.py | 460 ++++++++++++++++++++------------------------------- 1 file changed, 176 insertions(+), 284 deletions(-) diff --git a/cringebot.py b/cringebot.py index 0b41c2a..078de80 100644 --- a/cringebot.py +++ b/cringebot.py @@ -3,21 +3,14 @@ import sys import time from datetime import datetime, timezone, timedelta import json -import pprint import threading -import traceback import bogofilter import html2text +from collections import deque from mastodon import Mastodon, MastodonNotFoundError -def log_print(source, text = ""): - prefix = "{}: ".format(source) - text = (prefix + text.strip()).replace("\n", "\n" + prefix) - print(text) - -def log_pprint(source, obj): - log_print(source, pprint.pformat(obj)) +from bot import Bot, BotClient def encode_time(dt): return int(dt.strftime("%Y%m%d%H%M")) @@ -28,74 +21,31 @@ def decode_time(value): else: return dt.strptime(str(value), "%Y%m%d%H") -class Instance: - def __init__(self, name, config): - self.name = name - self.config = config - - self.base_url = "https://{}".format(name) - self.client_file = "secret/{}.client".format(name) - self.user_file = "secret/{}.user".format(name) +class CringeBotClient(BotClient): + def __init__(self, bot, config): + super().__init__(bot, config) - self.state_file = "state/{}.state".format(name) - self.state_lock = threading.Lock() - + self.h2t = html2text.HTML2Text() + self.h2t.ignore_links = True + self.spawner_thread = threading.Thread( target = self.spawner, - name = self.name + " spawner", + name = self.config["name"] + " spawner", args = (), kwargs = {}, daemon = True) - self.tracker_thread = threading.Thread( - target = self.tracker, - name = self.name + " tracker", - args = (), - kwargs = {}, - daemon = True) - self.purger_thread = threading.Thread( target = self.purger, - name = self.name + " purger", + name = self.config["name"] + " purger", args = (), kwargs = {}, daemon = True) - def setup(self): - if not os.path.exists(self.client_file): - Mastodon.create_app( - 'Cringefilter', - api_base_url = self.base_url, - to_file = self.client_file) - - if not os.path.exists(self.user_file): - api = Mastodon( - api_base_url = self.base_url, - client_id = self.client_file) - - auth_url = api.auth_request_url() - - print("Go to:") - print(auth_url) - print() - - auth_code = input("Enter code: ") - - print() - - api.log_in(code = auth_code, to_file = self.user_file) - - def start(self): + def on_start(self): self.spawner_thread.start() def spawner(self): - self.load_state() - - self.api = Mastodon( - access_token = self.user_file, - api_base_url = self.base_url) - - self.tracker_thread.start() if not learning: self.purger_thread.start() @@ -104,141 +54,125 @@ class Instance: time.sleep(60) def respond(self, status, message): - log_print(self.name, "Responded with: {}".format(message)) - self.api.status_reply(status, "{}\n#{}".format(message, config["tag"]), visibility = "direct", untag = True) - - def tracker(self): - my_id = self.api.me()["id"] - - while True: - try: - - statuses = self.api.account_statuses(my_id, min_id = self.state["min_id"]) - - h2t = html2text.HTML2Text() - h2t.ignore_links = True - - while not statuses is None and len(statuses) > 0: - log_print(self.name, "Found {} new status(es)".format(len(statuses))) - - for status in sorted(statuses, - key = lambda status: status["created_at"]): - - self.state_lock.acquire() - self.state["min_id"] = status["id"] - self.state_lock.release() - - if status["reblog"]: - continue - - md_text = h2t.handle(status["content"]) - - if config["tag"] in md_text: - continue - - mail_text = toot_dict_to_mail(status).format() - - if learning: - preview = toot_dict_to_mail(status) - preview.body = md_text - preview_text = preview.format() - - log_print(self.name, preview_text) - log_print(self.name) - - category = None - while not category in bogofilter.categories: - category = input("H(am), S(pam) or U(nknown)? ").upper() - - if category != bogofilter.UNSURE: - bogofilter.run(mail_text, [category]) - - if category == bogofilter.SPAM: - self.track_status(status) + self.log("Responded with:") + self.log(message) + self.api.status_reply(status, "{}\n{}".format(message, self.config["tag"]), visibility = "direct", untag = True) + time.sleep(1) + + def on_status(self, status): + if status["account"]["id"] != self.api.me()["id"]: + return + + if status["reblog"]: + return + + md_text = self.h2t.handle(status["content"]) + if self.config["tag"] in md_text.split(): + return + mail_text = toot_dict_to_mail(status).format() + + preview = toot_dict_to_mail(status) + preview.body = md_text + preview_text = preview.format() + + if learning: + self.log(preview_text) + self.log() + + category = None + while not category in bogofilter.categories: + category = input("H(am), S(pam) or U(nknown)? ").upper() + + if category != bogofilter.UNSURE: + bogofilter.run(mail_text, [category]) - log_print(self.name) + if category == bogofilter.SPAM: + self.track_status(status) - self.save_state() - else: - replied_id = status.get("in_reply_to_id", None) - if replied_id: + self.log() + else: + replied_id = status.get("in_reply_to_id", None) + if replied_id: + try: + replied_status = self.api.status(replied_id) + replied_tokens = self.h2t.handle(replied_status["content"]).split() + + if self.config["tag"] in replied_tokens: + target_status_id = replied_status.get("in_reply_to_id", None) + if target_status_id: + try: + target_status = self.api.status(target_status_id) + target_timeslot_key = encode_time(target_status["created_at"]) + target_mail_text = toot_dict_to_mail(target_status).format() + + command = self.h2t.handle(status["content"]).strip() + tokens = deque(command.split()) + self.log("Received command: {}".format(command)) try: - replied_status = self.api.status(replied_id) - replied_md = h2t.handle(replied_status["content"]) - - if config["tag"] in replied_md: - target_status_id = replied_status.get("in_reply_to_id", None) - if target_status_id: - try: - target_status = self.api.status(target_status_id) - target_timeslot_key = encode_time(target_status["created_at"]) - target_mail_text = toot_dict_to_mail(target_status).format() - - command = h2t.handle(status["content"]).strip() - log_print(self.name, "Received command: {}".format(command)) - if command == "learn spam": - bogofilter.run(target_mail_text, [bogofilter.LEARN_SPAM]) - self.track_status(target_status) - self.respond(status, "Learned as spam") - elif command == "unlearn spam": - bogofilter.run(target_mail_text, [bogofilter.UNLEARN_SPAM]) - self.expire_status(target_timeslot_key, target_status_id) - self.respond(status, "Unlearned as spam") - elif command == "relearn spam": - bogofilter.run(target_mail_text, [bogofilter.UNLEARN_HAM, bogofilter.LEARN_SPAM]) - self.track_status(target_status) - self.respond(status, "Relearned as spam") - elif command == "learn ham": - bogofilter.run(target_mail_text, [bogofilter.LEARN_HAM]) - self.expire_status(target_timeslot_key, target_status_id) - self.respond(status, "Learned as ham") - elif command == "unlearn ham": - bogofilter.run(target_mail_text, [bogofilter.UNLEARN_HAM]) - self.respond(status, "Unlearned as ham") - elif command == "relearn ham": - bogofilter.run(target_mail_text, [bogofilter.UNLEARN_SPAM, bogofilter.LEARN_HAM]) - self.expire_status(target_timeslot_key, target_status_id) - self.respond(status, "Relearned as ham") - else: - self.respond(status, "Unknown command") - except MastodonNotFoundError: - self.respond(status, "Original status is missing") - else: - self.respond(status, "Original status not referenced") - continue - except MastodonNotFoundError: - pass - - result = bogofilter.run(mail_text, [bogofilter.CLASSIFY, bogofilter.REGISTER]) - bogo_report = "Bogofilter: Category={}, Score={}".format(result.category, "{:.4f}".format(result.score)) - if result.category == bogofilter.SPAM: - log_print(self.name, "SPAM: Tracking status with ID {} as spam".format(status["id"])) - self.respond(status, "Categorised as spam\n{}".format(bogo_report)) - self.track_status(status) - time.sleep(1) - elif result.category == bogofilter.UNSURE: - log_print(self.name, "UNSURE: Not tracking status with ID {} as spam".format(status["id"])) - self.respond(status, "Categorised as unsure\n{}".format(bogo_report)) - time.sleep(1) - else: - log_print(self.name, "HAM: Not tracking status with ID {} as spam".format(status["id"])) - self.respond(status, "Categorised as ham\n{}".format(bogo_report)) - time.sleep(1) - - log_print(self.name) - log_print(self.name, mail_text) - - self.save_state() - - statuses = self.api.fetch_previous(statuses) - - # Rate limit (max 300 requests per 5 minutes, i.e. 1 per second) - time.sleep(1) - - except: - log_print(self.name, traceback.format_exc()) - - time.sleep(5) + while True: + token = tokens.popleft() + if token == "learn": + token = tokens.popleft() + if token == "spam": + bogofilter.run(target_mail_text, [bogofilter.LEARN_SPAM]) + self.track_status(target_status) + self.respond(status, "Learned as spam") + break + elif token == "ham": + bogofilter.run(target_mail_text, [bogofilter.LEARN_HAM]) + self.expire_status(target_timeslot_key, target_status_id) + self.respond(status, "Learned as ham") + break + elif token == "unlearn": + token = tokens.popleft() + if token == "spam": + bogofilter.run(target_mail_text, [bogofilter.UNLEARN_SPAM]) + self.expire_status(target_timeslot_key, target_status_id) + self.respond(status, "Unlearned as spam") + break + elif token == "ham": + bogofilter.run(target_mail_text, [bogofilter.UNLEARN_SPAM]) + self.expire_status(target_timeslot_key, target_status_id) + self.respond(status, "Unlearned as spam") + break + elif token == "relearn": + token = tokens.popleft() + if token == "spam": + bogofilter.run(target_mail_text, [bogofilter.UNLEARN_HAM, bogofilter.LEARN_SPAM]) + self.track_status(target_status) + self.respond(status, "Relearned as spam") + break + elif token == "ham": + bogofilter.run(target_mail_text, [bogofilter.UNLEARN_SPAM, bogofilter.LEARN_HAM]) + self.expire_status(target_timeslot_key, target_status_id) + self.respond(status, "Relearned as as ham") + break + except IndexError: + self.respond(status, "Invalid command") + except MastodonNotFoundError: + self.respond(status, "Original status is missing") + else: + self.respond(status, "Original status is missing") + return + except MastodonNotFoundError: + pass + + result = bogofilter.run(mail_text, [bogofilter.CLASSIFY, bogofilter.REGISTER]) + bogo_report = "Bogofilter: Category={}, Score={}".format(result.category, "{:.4f}".format(result.score)) + if result.category == bogofilter.SPAM: + self.log("SPAM: Tracking status with ID {} as spam".format(status["id"])) + self.respond(status, "Categorised as spam\n{}".format(bogo_report)) + self.track_status(status) + elif result.category == bogofilter.UNSURE: + self.log("UNSURE: Not tracking status with ID {} as spam".format(status["id"])) + self.respond(status, "Categorised as unsure\n{}".format(bogo_report)) + else: + self.log("HAM: Not tracking status with ID {} as spam".format(status["id"])) + self.respond(status, "Categorised as ham\n{}".format(bogo_report)) + + self.log() + self.log(preview_text) + self.log() def purger(self): while True: @@ -248,13 +182,12 @@ class Instance: if not timeslot_key is None: try: - log_print(self.name, "Deleting status {} in timeslot {}".format(status_id, timeslot_key)) + self.log("Deleting status {} in timeslot {}".format(status_id, timeslot_key)) self.api.status_delete(status_id) deleted = True except MastodonNotFoundError: - log_print(self.name, - "Cannot find status {} on server".format(status_id)) + self.log("Cannot find status {} on server".format(status_id)) self.expire_status(timeslot_key, status_id) @@ -262,106 +195,74 @@ class Instance: time.sleep(60) else: time.sleep(1) - except: - log_print(self.name, traceback.format_exc()) + self.log(traceback.format_exc()) time.sleep(60) - def load_state(self): - self.state_lock.acquire() + def on_load_state(self): + state = super().on_load_state() + state["timeslots"] = state.get("timeslots", {}) + state["timeslots"] = dict(map(lambda kv: (int(kv[0]), set(kv[1])), state["timeslots"])) + return state - if not os.path.exists(self.state_file): - self.state = dict( - min_id = "0", - timeslots = {}) - else: - with open(self.state_file) as json_file: - self.state = json.load(json_file) - - self.state["timeslots"] = dict(map(lambda kv: (int(kv[0]), set(kv[1])), self.state["timeslots"])) - - self.state_lock.release() - - def save_state(self): - self.state_lock.acquire() - - json_state = self.state.copy() - json_state["timeslots"] = list(map(lambda kv: [kv[0], list(kv[1])], json_state["timeslots"].items())) - - self.state_lock.release() - - with open(self.state_file, "w") as json_file: - json.dump(json_state, json_file, indent = 4) + def on_save_state(self, state): + state["timeslots"] = list(map(lambda kv: [kv[0], list(kv[1])], state["timeslots"].items())) + super().on_save_state(state) def tracker_report(self): - self.state_lock.acquire() - - total_timeslots = len(self.state["timeslots"]) - total_statuses = 0 - for timeslot_key, status_ids in self.state["timeslots"].items(): - total_statuses += len(status_ids) + with self.state_lock: + total_timeslots = len(self.state["timeslots"]) + total_statuses = 0 + for timeslot_key, status_ids in self.state["timeslots"].items(): + total_statuses += len(status_ids) - self.state_lock.release() - - log_print(self.name, "Tracking {} statuses across {} timeslots".format( - total_statuses, total_timeslots)) + self.log("Tracking {} statuses across {} timeslots".format(total_statuses, total_timeslots)) def track_status(self, status): status_id = str(status["id"]) timeslot_key = encode_time(status["created_at"]) - self.state_lock.acquire() - if status["reblog"] is None: - timeslots = self.state["timeslots"] - if not timeslot_key in timeslots: - timeslots[timeslot_key] = set() - timeslots[timeslot_key].add(status_id) - - self.state_lock.release() + with self.state_lock: + if status["reblog"] is None: + timeslots = self.state["timeslots"] + if not timeslot_key in timeslots: + timeslots[timeslot_key] = set() + timeslots[timeslot_key].add(status_id) def next_expired(self): now = datetime.now(timezone.utc) - min_timeslot_key = encode_time(now - timedelta(minutes = config["max_age"])) + min_timeslot_key = encode_time(now - timedelta(minutes = self.config["max_age"])) - self.state_lock.acquire() - - timeslot_key, status_ids = next(iter(self.state["timeslots"].items()), (None, None)) + with self.state_lock: + timeslot_key, status_ids = next(iter(self.state["timeslots"].items()), (None, None)) - if not timeslot_key is None and timeslot_key < min_timeslot_key: - status_id = next(iter(status_ids), None) - else: - timeslot_key = None - status_id = None - - self.state_lock.release() - + if not timeslot_key is None and timeslot_key < min_timeslot_key: + status_id = next(iter(status_ids), None) + else: + timeslot_key = None + status_id = None + return (timeslot_key, status_id) def expire_status(self, timeslot_key, status_id): - self.state_lock.acquire() + with self.state_lock: + timeslots = self.state["timeslots"] + if timeslot_key in timeslots: + if status_id in timeslots[timeslot_key]: + self.log("Expiring status {} from timeslot {}".format(status_id, timeslot_key)) + timeslots[timeslot_key].remove(status_id) + else: + self.log("Cannot expire missing status {} from timeslot {}".format( + status_id, timeslot_key)) - timeslots = self.state["timeslots"] - if timeslot_key in timeslots: - if status_id in timeslots[timeslot_key]: - log_print(self.name, "Expiring status {} from timeslot {}".format(status_id, timeslot_key)) - timeslots[timeslot_key].remove(status_id) + if len(timeslots[timeslot_key]) == 0: + self.log("Removing empty timeslot {}".format(timeslot_key)) + del timeslots[timeslot_key] else: - log_print(self.name, "Cannot expire missing status {} from timeslot {}".format( + self.log("Cannot expire status {} from missing timeslot {}".format( status_id, timeslot_key)) - if len(timeslots[timeslot_key]) == 0: - log_print(self.name, "Removing empty timeslot {}".format(timeslot_key)) - del timeslots[timeslot_key] - else: - log_print(self.name, "Cannot expire status {} from missing timeslot {}".format( - status_id, timeslot_key)) - - self.state_lock.release() - self.save_state() - def toot_dict_to_mail(toot_dict): - #log_pprint("toot_dict_to_mail", toot_dict) - flags = [] if toot_dict.get("sensitive", False): @@ -376,8 +277,8 @@ def toot_dict_to_mail(toot_dict): if toot_dict.get("reblogged", False): flags.append("reblogged") - #if toot_dict.get("favourited", False): - # flags.append("favourited") + if toot_dict.get("favourited", False): + flags.append("favourited") if toot_dict.get("bookmarked", False): flags.append("bookmarked") @@ -426,17 +327,8 @@ def toot_dict_to_mail(toot_dict): learning = "-l" in sys.argv[1:] with open("config.json") as json_file: - config = json.load(json_file) - -instances = {} -for name in config["instances"]: - instances[name] = Instance(name = name, config = config) - instances[name].setup() - -start_interval = 60.0 / len(config["instances"]) -for instance in instances.values(): - instance.start() - time.sleep(start_interval) + bot = Bot(CringeBotClient, json.load(json_file)) +bot.start() while True: time.sleep(1)