Logger and Ubuntu trainer
This commit is contained in:
parent
6e9d31df03
commit
a98eb75c0f
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from datetime import datetime, timedelta
|
||||
@ -7,11 +8,13 @@ import discord
|
||||
from chatterbot import ChatBot
|
||||
from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
|
||||
from chatterbot.response_selection import get_random_response
|
||||
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer
|
||||
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorpusTrainer
|
||||
from redbot.core import Config, commands
|
||||
from redbot.core.commands import Cog
|
||||
from redbot.core.data_manager import cog_data_path
|
||||
|
||||
log = logging.getLogger("red.fox_v3.chat")
|
||||
|
||||
|
||||
class ENG_LG: # TODO: Add option to use this large model
|
||||
ISO_639_1 = "en_core_web_lg"
|
||||
@ -50,7 +53,7 @@ class Chatter(Cog):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def _create_chatbot(
|
||||
self, data_path, similarity_algorithm, similarity_threshold, tagger_language
|
||||
self, data_path, similarity_algorithm, similarity_threshold, tagger_language
|
||||
):
|
||||
return ChatBot(
|
||||
"ChatterBot",
|
||||
@ -61,6 +64,7 @@ class Chatter(Cog):
|
||||
logic_adapters=["chatterbot.logic.BestMatch"],
|
||||
# maximum_similarity_threshold=similarity_threshold,
|
||||
tagger_language=tagger_language,
|
||||
logger=log,
|
||||
)
|
||||
|
||||
async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
|
||||
@ -99,7 +103,7 @@ class Chatter(Cog):
|
||||
try:
|
||||
|
||||
async for message in channel.history(
|
||||
limit=None, after=after, oldest_first=True
|
||||
limit=None, after=after, oldest_first=True
|
||||
).filter(
|
||||
predicate=predicate
|
||||
): # type: discord.Message
|
||||
@ -130,6 +134,11 @@ class Chatter(Cog):
|
||||
|
||||
return out
|
||||
|
||||
def _train_ubuntu(self):
|
||||
trainer = UbuntuCorpusTrainer(self.chatbot)
|
||||
trainer.train()
|
||||
return True
|
||||
|
||||
def _train_english(self):
|
||||
trainer = ChatterBotCorpusTrainer(self.chatbot)
|
||||
# try:
|
||||
@ -182,7 +191,9 @@ class Chatter(Cog):
|
||||
try:
|
||||
os.remove(self.data_path)
|
||||
except PermissionError:
|
||||
await ctx.maybe_send_embed("Failed to clear training database. Please wait a bit and try again")
|
||||
await ctx.maybe_send_embed(
|
||||
"Failed to clear training database. Please wait a bit and try again"
|
||||
)
|
||||
|
||||
self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
|
||||
|
||||
@ -260,6 +271,19 @@ class Chatter(Cog):
|
||||
else:
|
||||
await ctx.send("Error occurred :(")
|
||||
|
||||
@chatter.command(name="trainubuntu")
|
||||
async def chatter_train_ubuntu(self, ctx: commands.Context):
|
||||
"""
|
||||
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
|
||||
"""
|
||||
async with ctx.typing():
|
||||
future = await self.loop.run_in_executor(None, self._train_ubuntu)
|
||||
|
||||
if future:
|
||||
await ctx.send("Training successful!")
|
||||
else:
|
||||
await ctx.send("Error occurred :(")
|
||||
|
||||
@chatter.command(name="trainenglish")
|
||||
async def chatter_train_english(self, ctx: commands.Context):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user