From 4fcc12a2d8a04f4d9da7c87497417c0fb56fe474 Mon Sep 17 00:00:00 2001 From: bobloy Date: Thu, 13 Aug 2020 12:02:54 -0400 Subject: [PATCH] Allow specifying algorithm and model --- chatter/README.md | 31 ++++++++++++++ chatter/chat.py | 100 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 111 insertions(+), 20 deletions(-) diff --git a/chatter/README.md b/chatter/README.md index 933162a..5d5446f 100644 --- a/chatter/README.md +++ b/chatter/README.md @@ -167,7 +167,38 @@ settings. This can take a long time to process. ``` [p]chatter algorithm X ``` +or +``` +[p]chatter algo X 0.95 +``` Chatter can be configured to use one of three different Similarity algorithms. Changing this can help if the response speed is too slow, but can reduce the accuracy of results. + +The second argument is the minimum similarity threshold, +raising this will make the bot be more selective with the responses it finds. + +Default minimum similarity threshold is 0.90 + + +## Switching Pretrained Models + +``` +[p]chatter model X +``` + +Chatter can be configured to use one of three pretrained statistical models for English. + +I have not noticed any advantage to changing this, +but supposedly it would help by splitting the search term into more useful parts. + +See [here](https://spacy.io/models) for more info on spaCy models. + +Before you're able to use the *large* model (option 2), you must install it through pip. + +*Warning:* This is a ~800MB download. 
+ +``` +[p]pipinstall https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-2.3.1/en_core_web_lg-2.3.1.tar.gz#egg=en_core_web_lg +``` diff --git a/chatter/chat.py b/chatter/chat.py index 886efc3..76ee56c 100644 --- a/chatter/chat.py +++ b/chatter/chat.py @@ -12,11 +12,12 @@ from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorp from redbot.core import Config, commands from redbot.core.commands import Cog from redbot.core.data_manager import cog_data_path +from redbot.core.utils.predicates import MessagePredicate log = logging.getLogger("red.fox_v3.chat") -class ENG_LG: # TODO: Add option to use this large model +class ENG_LG: ISO_639_1 = "en_core_web_lg" ISO_639 = "eng" ENGLISH_NAME = "English" @@ -28,6 +29,12 @@ class ENG_MD: ENGLISH_NAME = "English" +class ENG_SM: + ISO_639_1 = "en_core_web_sm" + ISO_639 = "eng" + ENGLISH_NAME = "English" + + class Chatter(Cog): """ This cog trains a chatbot that will talk like members of your Guild @@ -42,7 +49,13 @@ class Chatter(Cog): path: pathlib.Path = cog_data_path(self) self.data_path = path / "database.sqlite3" - self.chatbot = self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) + # TODO: Move training_model and similarity_algo to config + # TODO: Add an option to see current settings + + self.tagger_language = ENG_MD + self.similarity_algo = SpacySimilarity + self.similarity_threshold = 0.90 + self.chatbot = self._create_chatbot() # self.chatbot.set_trainer(ListTrainer) # self.trainer = ListTrainer(self.chatbot) @@ -52,18 +65,17 @@ class Chatter(Cog): self.loop = asyncio.get_event_loop() - def _create_chatbot( - self, data_path, similarity_algorithm, similarity_threshold, tagger_language - ): + def _create_chatbot(self): + return ChatBot( "ChatterBot", storage_adapter="chatterbot.storage.SQLStorageAdapter", - database_uri="sqlite:///" + str(data_path), - statement_comparison_function=similarity_algorithm, + database_uri="sqlite:///" + str(self.data_path), + 
statement_comparison_function=self.similarity_algo, response_selection_method=get_random_response, logic_adapters=["chatterbot.logic.BestMatch"], - # maximum_similarity_threshold=similarity_threshold, - tagger_language=tagger_language, + maximum_similarity_threshold=self.similarity_threshold, + tagger_language=self.tagger_language, logger=log, ) @@ -103,7 +115,7 @@ class Chatter(Cog): try: async for message in channel.history( - limit=None, after=after, oldest_first=True + limit=None, after=after, oldest_first=True ).filter( predicate=predicate ): # type: discord.Message @@ -195,12 +207,14 @@ class Chatter(Cog): "Failed to clear training database. Please wait a bit and try again" ) - self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) + self._create_chatbot() await ctx.tick() - @chatter.command(name="algorithm") - async def chatter_algorithm(self, ctx: commands.Context, algo_number: int): + @chatter.command(name="algorithm", aliases=["algo"]) + async def chatter_algorithm( + self, ctx: commands.Context, algo_number: int, threshold: float = None + ): """ Switch the active logic algorithm to one of the three. 
Default after reload is Spacy @@ -209,17 +223,61 @@ class Chatter(Cog): 2: Levenshtein """ - algos = [(SpacySimilarity, 0.45), (JaccardSimilarity, 0.75), (LevenshteinDistance, 0.75)] + algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance] if algo_number < 0 or algo_number > 2: await ctx.send_help() return - self.chatbot = self._create_chatbot( - self.data_path, algos[algo_number][0], algos[algo_number][1], ENG_MD - ) + if threshold is not None: + if threshold >= 1 or threshold <= 0: + await ctx.maybe_send_embed( + "Threshold must be a number between 0 and 1 (exclusive)" + ) + return + else: + self.similarity_algo = threshold - await ctx.tick() + self.similarity_algo = algos[algo_number] + async with ctx.typing(): + self.chatbot = self._create_chatbot() + + await ctx.tick() + + @chatter.command(name="model") + async def chatter_model(self, ctx: commands.Context, model_number: int): + """ + Switch the active model to one of the three. Default after reload is Medium + + 0: Small + 1: Medium + 2: Large (Requires additional setup) + """ + + models = [ENG_SM, ENG_MD, ENG_LG] + + if model_number < 0 or model_number > 2: + await ctx.send_help() + return + + if model_number == 2: + await ctx.maybe_send_embed( + "Additional requirements needed. See guide before continuing.\n" "Continue?" 
+ ) + pred = MessagePredicate.yes_or_no(ctx) + try: + await self.bot.wait_for("message", check=pred, timeout=30) + except TimeoutError: + await ctx.send("Response timed out, please try again later.") + return + if not pred.result: + return + + self.tagger_language = models[model_number] + async with ctx.typing(): + self.chatbot = self._create_chatbot() + + await ctx.maybe_send_embed(f"Model has been switched to {self.tagger_language.ISO_639_1}") @chatter.command(name="minutes") async def minutes(self, ctx: commands.Context, minutes: int): @@ -278,8 +336,10 @@ class Chatter(Cog): """ if not confirmation: - await ctx.maybe_send_embed("Warning: This command downloads ~500MB then eats your CPU for training\n" - "If you're sure you want to continue, run `[p]chatter trainubuntu True`") + await ctx.maybe_send_embed( + "Warning: This command downloads ~500MB then eats your CPU for training\n" + "If you're sure you want to continue, run `[p]chatter trainubuntu True`" + ) return async with ctx.typing():