diff --git a/chatter/chat.py b/chatter/chat.py index f7e8944..1473ad3 100644 --- a/chatter/chat.py +++ b/chatter/chat.py @@ -1,4 +1,5 @@ import asyncio +import logging import os import pathlib from datetime import datetime, timedelta @@ -7,11 +8,13 @@ import discord from chatterbot import ChatBot from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity from chatterbot.response_selection import get_random_response -from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer +from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorpusTrainer from redbot.core import Config, commands from redbot.core.commands import Cog from redbot.core.data_manager import cog_data_path +log = logging.getLogger("red.fox_v3.chat") + class ENG_LG: # TODO: Add option to use this large model ISO_639_1 = "en_core_web_lg" @@ -50,7 +53,7 @@ class Chatter(Cog): self.loop = asyncio.get_event_loop() def _create_chatbot( - self, data_path, similarity_algorithm, similarity_threshold, tagger_language + self, data_path, similarity_algorithm, similarity_threshold, tagger_language ): return ChatBot( "ChatterBot", @@ -61,6 +64,7 @@ class Chatter(Cog): logic_adapters=["chatterbot.logic.BestMatch"], # maximum_similarity_threshold=similarity_threshold, tagger_language=tagger_language, + logger=log, ) async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None): @@ -99,7 +103,7 @@ class Chatter(Cog): try: async for message in channel.history( - limit=None, after=after, oldest_first=True + limit=None, after=after, oldest_first=True ).filter( predicate=predicate ): # type: discord.Message @@ -130,6 +134,11 @@ class Chatter(Cog): return out + def _train_ubuntu(self): + trainer = UbuntuCorpusTrainer(self.chatbot) + trainer.train() + return True + def _train_english(self): trainer = ChatterBotCorpusTrainer(self.chatbot) # try: @@ -182,7 +191,9 @@ class Chatter(Cog): try: os.remove(self.data_path) except PermissionError: - await ctx.maybe_send_embed("Failed to clear training database. Please wait a bit and try again") + await ctx.maybe_send_embed( + "Failed to clear training database. Please wait a bit and try again" + ) self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) @@ -260,6 +271,19 @@ class Chatter(Cog): else: await ctx.send("Error occurred :(") + @chatter.command(name="trainubuntu") + async def chatter_train_ubuntu(self, ctx: commands.Context): + """ + WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data. + """ + async with ctx.typing(): + future = await self.loop.run_in_executor(None, self._train_ubuntu) + + if future: + await ctx.send("Training successful!") + else: + await ctx.send("Error occurred :(") + @chatter.command(name="trainenglish") async def chatter_train_english(self, ctx: commands.Context): """