Logger and Ubuntu trainer

pull/118/head
bobloy 5 years ago
parent 6e9d31df03
commit a98eb75c0f

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
import os import os
import pathlib import pathlib
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -7,11 +8,13 @@ import discord
from chatterbot import ChatBot from chatterbot import ChatBot
from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
from chatterbot.response_selection import get_random_response from chatterbot.response_selection import get_random_response
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorpusTrainer
from redbot.core import Config, commands from redbot.core import Config, commands
from redbot.core.commands import Cog from redbot.core.commands import Cog
from redbot.core.data_manager import cog_data_path from redbot.core.data_manager import cog_data_path
log = logging.getLogger("red.fox_v3.chat")
class ENG_LG: # TODO: Add option to use this large model class ENG_LG: # TODO: Add option to use this large model
ISO_639_1 = "en_core_web_lg" ISO_639_1 = "en_core_web_lg"
@ -50,7 +53,7 @@ class Chatter(Cog):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
def _create_chatbot( def _create_chatbot(
self, data_path, similarity_algorithm, similarity_threshold, tagger_language self, data_path, similarity_algorithm, similarity_threshold, tagger_language
): ):
return ChatBot( return ChatBot(
"ChatterBot", "ChatterBot",
@ -61,6 +64,7 @@ class Chatter(Cog):
logic_adapters=["chatterbot.logic.BestMatch"], logic_adapters=["chatterbot.logic.BestMatch"],
# maximum_similarity_threshold=similarity_threshold, # maximum_similarity_threshold=similarity_threshold,
tagger_language=tagger_language, tagger_language=tagger_language,
logger=log,
) )
async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None): async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
@ -99,7 +103,7 @@ class Chatter(Cog):
try: try:
async for message in channel.history( async for message in channel.history(
limit=None, after=after, oldest_first=True limit=None, after=after, oldest_first=True
).filter( ).filter(
predicate=predicate predicate=predicate
): # type: discord.Message ): # type: discord.Message
@ -130,6 +134,11 @@ class Chatter(Cog):
return out return out
def _train_ubuntu(self):
trainer = UbuntuCorpusTrainer(self.chatbot)
trainer.train()
return True
def _train_english(self): def _train_english(self):
trainer = ChatterBotCorpusTrainer(self.chatbot) trainer = ChatterBotCorpusTrainer(self.chatbot)
# try: # try:
@ -182,7 +191,9 @@ class Chatter(Cog):
try: try:
os.remove(self.data_path) os.remove(self.data_path)
except PermissionError: except PermissionError:
await ctx.maybe_send_embed("Failed to clear training database. Please wait a bit and try again") await ctx.maybe_send_embed(
"Failed to clear training database. Please wait a bit and try again"
)
self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
@ -260,6 +271,19 @@ class Chatter(Cog):
else: else:
await ctx.send("Error occurred :(") await ctx.send("Error occurred :(")
@chatter.command(name="trainubuntu")
async def chatter_train_ubuntu(self, ctx: commands.Context):
"""
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
"""
async with ctx.typing():
future = await self.loop.run_in_executor(None, self._train_ubuntu)
if future:
await ctx.send("Training successful!")
else:
await ctx.send("Error occurred :(")
@chatter.command(name="trainenglish") @chatter.command(name="trainenglish")
async def chatter_train_english(self, ctx: commands.Context): async def chatter_train_english(self, ctx: commands.Context):
""" """

Loading…
Cancel
Save