Logger and Ubuntu trainer
This commit is contained in:
		
							parent
							
								
									6e9d31df03
								
							
						
					
					
						commit
						a98eb75c0f
					
				@ -1,4 +1,5 @@
 | 
				
			|||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import pathlib
 | 
					import pathlib
 | 
				
			||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
@ -7,11 +8,13 @@ import discord
 | 
				
			|||||||
from chatterbot import ChatBot
 | 
					from chatterbot import ChatBot
 | 
				
			||||||
from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
 | 
					from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
 | 
				
			||||||
from chatterbot.response_selection import get_random_response
 | 
					from chatterbot.response_selection import get_random_response
 | 
				
			||||||
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer
 | 
					from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorpusTrainer
 | 
				
			||||||
from redbot.core import Config, commands
 | 
					from redbot.core import Config, commands
 | 
				
			||||||
from redbot.core.commands import Cog
 | 
					from redbot.core.commands import Cog
 | 
				
			||||||
from redbot.core.data_manager import cog_data_path
 | 
					from redbot.core.data_manager import cog_data_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger("red.fox_v3.chat")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ENG_LG:  # TODO: Add option to use this large model
 | 
					class ENG_LG:  # TODO: Add option to use this large model
 | 
				
			||||||
    ISO_639_1 = "en_core_web_lg"
 | 
					    ISO_639_1 = "en_core_web_lg"
 | 
				
			||||||
@ -50,7 +53,7 @@ class Chatter(Cog):
 | 
				
			|||||||
        self.loop = asyncio.get_event_loop()
 | 
					        self.loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _create_chatbot(
 | 
					    def _create_chatbot(
 | 
				
			||||||
            self, data_path, similarity_algorithm, similarity_threshold, tagger_language
 | 
					        self, data_path, similarity_algorithm, similarity_threshold, tagger_language
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        return ChatBot(
 | 
					        return ChatBot(
 | 
				
			||||||
            "ChatterBot",
 | 
					            "ChatterBot",
 | 
				
			||||||
@ -61,6 +64,7 @@ class Chatter(Cog):
 | 
				
			|||||||
            logic_adapters=["chatterbot.logic.BestMatch"],
 | 
					            logic_adapters=["chatterbot.logic.BestMatch"],
 | 
				
			||||||
            # maximum_similarity_threshold=similarity_threshold,
 | 
					            # maximum_similarity_threshold=similarity_threshold,
 | 
				
			||||||
            tagger_language=tagger_language,
 | 
					            tagger_language=tagger_language,
 | 
				
			||||||
 | 
					            logger=log,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
 | 
					    async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
 | 
				
			||||||
@ -99,7 +103,7 @@ class Chatter(Cog):
 | 
				
			|||||||
            try:
 | 
					            try:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                async for message in channel.history(
 | 
					                async for message in channel.history(
 | 
				
			||||||
                        limit=None, after=after, oldest_first=True
 | 
					                    limit=None, after=after, oldest_first=True
 | 
				
			||||||
                ).filter(
 | 
					                ).filter(
 | 
				
			||||||
                    predicate=predicate
 | 
					                    predicate=predicate
 | 
				
			||||||
                ):  # type: discord.Message
 | 
					                ):  # type: discord.Message
 | 
				
			||||||
@ -130,6 +134,11 @@ class Chatter(Cog):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        return out
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _train_ubuntu(self):
 | 
				
			||||||
 | 
					        trainer = UbuntuCorpusTrainer(self.chatbot)
 | 
				
			||||||
 | 
					        trainer.train()
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _train_english(self):
 | 
					    def _train_english(self):
 | 
				
			||||||
        trainer = ChatterBotCorpusTrainer(self.chatbot)
 | 
					        trainer = ChatterBotCorpusTrainer(self.chatbot)
 | 
				
			||||||
        # try:
 | 
					        # try:
 | 
				
			||||||
@ -182,7 +191,9 @@ class Chatter(Cog):
 | 
				
			|||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    os.remove(self.data_path)
 | 
					                    os.remove(self.data_path)
 | 
				
			||||||
                except PermissionError:
 | 
					                except PermissionError:
 | 
				
			||||||
                    await ctx.maybe_send_embed("Failed to clear training database. Please wait a bit and try again")
 | 
					                    await ctx.maybe_send_embed(
 | 
				
			||||||
 | 
					                        "Failed to clear training database. Please wait a bit and try again"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
 | 
					            self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -260,6 +271,19 @@ class Chatter(Cog):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            await ctx.send("Error occurred :(")
 | 
					            await ctx.send("Error occurred :(")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @chatter.command(name="trainubuntu")
 | 
				
			||||||
 | 
					    async def chatter_train_ubuntu(self, ctx: commands.Context):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        async with ctx.typing():
 | 
				
			||||||
 | 
					            future = await self.loop.run_in_executor(None, self._train_ubuntu)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if future:
 | 
				
			||||||
 | 
					            await ctx.send("Training successful!")
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            await ctx.send("Error occurred :(")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @chatter.command(name="trainenglish")
 | 
					    @chatter.command(name="trainenglish")
 | 
				
			||||||
    async def chatter_train_english(self, ctx: commands.Context):
 | 
					    async def chatter_train_english(self, ctx: commands.Context):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user