From 698dafade4cde7c3d691babb37e27430a29be0dd Mon Sep 17 00:00:00 2001 From: bobloy Date: Thu, 8 Jul 2021 08:29:18 -0400 Subject: [PATCH 1/2] Add chatter initialization --- chatter/__init__.py | 6 ++++-- chatter/chat.py | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/chatter/__init__.py b/chatter/__init__.py index 9447c6a..663dadf 100644 --- a/chatter/__init__.py +++ b/chatter/__init__.py @@ -1,8 +1,10 @@ from .chat import Chatter -def setup(bot): - bot.add_cog(Chatter(bot)) +async def setup(bot): + cog = Chatter(bot) + await cog.initialize() + bot.add_cog(cog) # __all__ = ( diff --git a/chatter/chat.py b/chatter/chat.py index 66ff116..4655da8 100644 --- a/chatter/chat.py +++ b/chatter/chat.py @@ -19,6 +19,7 @@ from redbot.core.utils.predicates import MessagePredicate from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2 +chatterbot_log = logging.getLogger("red.fox_v3.chatterbot") log = logging.getLogger("red.fox_v3.chatter") @@ -58,11 +59,14 @@ class Chatter(Cog): This cog trains a chatbot that will talk like members of your Guild """ + models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF] + algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance] + def __init__(self, bot): super().__init__() self.bot = bot self.config = Config.get_conf(self, identifier=6710497116116101114) - default_global = {"learning": True} + default_global = {"learning": True, "model_number": 0, "algo_number": 0, "threshold": 0.90} self.default_guild = { "whitelist": None, "days": 1, @@ -79,7 +83,7 @@ class Chatter(Cog): self.tagger_language = ENG_SM self.similarity_algo = SpacySimilarity self.similarity_threshold = 0.90 - self.chatbot = self._create_chatbot() + self.chatbot = None # self.chatbot.set_trainer(ListTrainer) # self.trainer = ListTrainer(self.chatbot) @@ -98,6 +102,18 @@ class Chatter(Cog): """Nothing to delete""" return + async def initialize(self): + all_config = dict(self.config.defaults["GLOBAL"]) + all_config.update(await self.config.all()) + model_number = all_config["model_number"] + algo_number = all_config["algo_number"] + threshold = all_config["threshold"] + + self.tagger_language = self.models[model_number] + self.similarity_algo = self.algos[algo_number] + self.similarity_threshold = threshold + self.chatbot = self._create_chatbot() + def _create_chatbot(self): return ChatBot( @@ -110,7 +126,7 @@ class Chatter(Cog): logic_adapters=["chatterbot.logic.BestMatch"], maximum_similarity_threshold=self.similarity_threshold, tagger_language=self.tagger_language, - logger=log, + logger=chatterbot_log, ) async def _get_conversation(self, ctx, in_channels: List[discord.TextChannel]): @@ -334,15 +350,12 @@ class Chatter(Cog): self, ctx: commands.Context, algo_number: int, threshold: float = None ): """ - Switch the active logic algorithm to one of the three. Default after reload is Spacy + Switch the active logic algorithm to one of the three. Default is Spacy 0: Spacy 1: Jaccard 2: Levenshtein """ - - algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance] - if algo_number < 0 or algo_number > 2: await ctx.send_help() return @@ -355,8 +368,11 @@ class Chatter(Cog): return else: self.similarity_threshold = threshold + await self.config.threshold.set(self.similarity_threshold) + + self.similarity_algo = self.algos[algo_number] + await self.config.algo_number.set(algo_number) - self.similarity_algo = algos[algo_number] async with ctx.typing(): self.chatbot = self._create_chatbot() @@ -366,21 +382,18 @@ class Chatter(Cog): @chatter.command(name="model") async def chatter_model(self, ctx: commands.Context, model_number: int): """ - Switch the active model to one of the three. Default after reload is Medium + Switch the active model to one of the three. Default is Small 0: Small - 1: Medium + 1: Medium (Requires additional setup) 2: Large (Requires additional setup) 3. Accurate (Requires additional setup) """ - - models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF] - if model_number < 0 or model_number > 3: await ctx.send_help() return - if model_number == 2: + if model_number >= 0: await ctx.maybe_send_embed( "Additional requirements needed. See guide before continuing.\n" "Continue?" ) @@ -393,7 +406,8 @@ class Chatter(Cog): if not pred.result: return - self.tagger_language = models[model_number] + self.tagger_language = self.models[model_number] + await self.config.model_number.set(model_number) async with ctx.typing(): self.chatbot = self._create_chatbot() From a5ff888f4c19c1297a47f15c04f9891cb4e29c7d Mon Sep 17 00:00:00 2001 From: bobloy Date: Thu, 8 Jul 2021 08:33:29 -0400 Subject: [PATCH 2/2] black reformat --- chatter/storage_adapters.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/chatter/storage_adapters.py b/chatter/storage_adapters.py index b2dc02a..706f96f 100644 --- a/chatter/storage_adapters.py +++ b/chatter/storage_adapters.py @@ -18,9 +18,7 @@ class MyDumbSQLStorageAdapter(SQLStorageAdapter): if not self.database_uri: self.database_uri = "sqlite:///db.sqlite3" - self.engine = create_engine( - self.database_uri, connect_args={"check_same_thread": False} - ) + self.engine = create_engine(self.database_uri, connect_args={"check_same_thread": False}) if self.database_uri.startswith("sqlite://"): from sqlalchemy.engine import Engine @@ -31,7 +29,7 @@ class MyDumbSQLStorageAdapter(SQLStorageAdapter): dbapi_connection.execute("PRAGMA journal_mode=WAL") dbapi_connection.execute("PRAGMA synchronous=NORMAL") - if not inspect(self.engine).has_table('Statement'): + if not inspect(self.engine).has_table("Statement"): self.create_database() self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)