|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import pathlib
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
@ -7,13 +8,16 @@ import discord
|
|
|
|
|
from chatterbot import ChatBot
|
|
|
|
|
from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
|
|
|
|
|
from chatterbot.response_selection import get_random_response
|
|
|
|
|
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer
|
|
|
|
|
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorpusTrainer
|
|
|
|
|
from redbot.core import Config, commands
|
|
|
|
|
from redbot.core.commands import Cog
|
|
|
|
|
from redbot.core.data_manager import cog_data_path
|
|
|
|
|
from redbot.core.utils.predicates import MessagePredicate
|
|
|
|
|
|
|
|
|
|
log = logging.getLogger("red.fox_v3.chat")
|
|
|
|
|
|
|
|
|
|
class ENG_LG: # TODO: Add option to use this large model
|
|
|
|
|
|
|
|
|
|
class ENG_LG:
|
|
|
|
|
ISO_639_1 = "en_core_web_lg"
|
|
|
|
|
ISO_639 = "eng"
|
|
|
|
|
ENGLISH_NAME = "English"
|
|
|
|
@ -25,6 +29,12 @@ class ENG_MD:
|
|
|
|
|
ENGLISH_NAME = "English"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ENG_SM:
|
|
|
|
|
ISO_639_1 = "en_core_web_sm"
|
|
|
|
|
ISO_639 = "eng"
|
|
|
|
|
ENGLISH_NAME = "English"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Chatter(Cog):
|
|
|
|
|
"""
|
|
|
|
|
This cog trains a chatbot that will talk like members of your Guild
|
|
|
|
@ -39,7 +49,13 @@ class Chatter(Cog):
|
|
|
|
|
path: pathlib.Path = cog_data_path(self)
|
|
|
|
|
self.data_path = path / "database.sqlite3"
|
|
|
|
|
|
|
|
|
|
self.chatbot = self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
|
|
|
|
|
# TODO: Move training_model and similarity_algo to config
|
|
|
|
|
# TODO: Add an option to see current settings
|
|
|
|
|
|
|
|
|
|
self.tagger_language = ENG_MD
|
|
|
|
|
self.similarity_algo = SpacySimilarity
|
|
|
|
|
self.similarity_threshold = 0.90
|
|
|
|
|
self.chatbot = self._create_chatbot()
|
|
|
|
|
# self.chatbot.set_trainer(ListTrainer)
|
|
|
|
|
|
|
|
|
|
# self.trainer = ListTrainer(self.chatbot)
|
|
|
|
@ -49,18 +65,18 @@ class Chatter(Cog):
|
|
|
|
|
|
|
|
|
|
self.loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
|
|
def _create_chatbot(
|
|
|
|
|
self, data_path, similarity_algorithm, similarity_threshold, tagger_language
|
|
|
|
|
):
|
|
|
|
|
def _create_chatbot(self):
|
|
|
|
|
|
|
|
|
|
return ChatBot(
|
|
|
|
|
"ChatterBot",
|
|
|
|
|
storage_adapter="chatterbot.storage.SQLStorageAdapter",
|
|
|
|
|
database_uri="sqlite:///" + str(data_path),
|
|
|
|
|
statement_comparison_function=similarity_algorithm,
|
|
|
|
|
database_uri="sqlite:///" + str(self.data_path),
|
|
|
|
|
statement_comparison_function=self.similarity_algo,
|
|
|
|
|
response_selection_method=get_random_response,
|
|
|
|
|
logic_adapters=["chatterbot.logic.BestMatch"],
|
|
|
|
|
# maximum_similarity_threshold=similarity_threshold,
|
|
|
|
|
tagger_language=tagger_language,
|
|
|
|
|
maximum_similarity_threshold=self.similarity_threshold,
|
|
|
|
|
tagger_language=self.tagger_language,
|
|
|
|
|
logger=log,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
|
|
|
|
@ -130,6 +146,11 @@ class Chatter(Cog):
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def _train_ubuntu(self):
|
|
|
|
|
trainer = UbuntuCorpusTrainer(self.chatbot)
|
|
|
|
|
trainer.train()
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def _train_english(self):
|
|
|
|
|
trainer = ChatterBotCorpusTrainer(self.chatbot)
|
|
|
|
|
# try:
|
|
|
|
@ -182,14 +203,18 @@ class Chatter(Cog):
|
|
|
|
|
try:
|
|
|
|
|
os.remove(self.data_path)
|
|
|
|
|
except PermissionError:
|
|
|
|
|
await ctx.maybe_send_embed("Failed to clear training database. Please wait a bit and try again")
|
|
|
|
|
await ctx.maybe_send_embed(
|
|
|
|
|
"Failed to clear training database. Please wait a bit and try again"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
|
|
|
|
|
self._create_chatbot()
|
|
|
|
|
|
|
|
|
|
await ctx.tick()
|
|
|
|
|
|
|
|
|
|
@chatter.command(name="algorithm")
|
|
|
|
|
async def chatter_algorithm(self, ctx: commands.Context, algo_number: int):
|
|
|
|
|
@chatter.command(name="algorithm", aliases=["algo"])
|
|
|
|
|
async def chatter_algorithm(
|
|
|
|
|
self, ctx: commands.Context, algo_number: int, threshold: float = None
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Switch the active logic algorithm to one of the three. Default after reload is Spacy
|
|
|
|
|
|
|
|
|
@ -198,17 +223,61 @@ class Chatter(Cog):
|
|
|
|
|
2: Levenshtein
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
algos = [(SpacySimilarity, 0.45), (JaccardSimilarity, 0.75), (LevenshteinDistance, 0.75)]
|
|
|
|
|
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
|
|
|
|
|
|
|
|
|
|
if algo_number < 0 or algo_number > 2:
|
|
|
|
|
await ctx.send_help()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.chatbot = self._create_chatbot(
|
|
|
|
|
self.data_path, algos[algo_number][0], algos[algo_number][1], ENG_MD
|
|
|
|
|
)
|
|
|
|
|
if threshold is not None:
|
|
|
|
|
if threshold >= 1 or threshold <= 0:
|
|
|
|
|
await ctx.maybe_send_embed(
|
|
|
|
|
"Threshold must be a number between 0 and 1 (exclusive)"
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
else:
|
|
|
|
|
self.similarity_algo = threshold
|
|
|
|
|
|
|
|
|
|
await ctx.tick()
|
|
|
|
|
self.similarity_algo = algos[algo_number]
|
|
|
|
|
async with ctx.typing():
|
|
|
|
|
self.chatbot = self._create_chatbot()
|
|
|
|
|
|
|
|
|
|
await ctx.tick()
|
|
|
|
|
|
|
|
|
|
@chatter.command(name="model")
|
|
|
|
|
async def chatter_model(self, ctx: commands.Context, model_number: int):
|
|
|
|
|
"""
|
|
|
|
|
Switch the active model to one of the three. Default after reload is Medium
|
|
|
|
|
|
|
|
|
|
0: Small
|
|
|
|
|
1: Medium
|
|
|
|
|
2: Large (Requires additional setup)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
models = [ENG_SM, ENG_MD, ENG_LG]
|
|
|
|
|
|
|
|
|
|
if model_number < 0 or model_number > 2:
|
|
|
|
|
await ctx.send_help()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if model_number == 2:
|
|
|
|
|
await ctx.maybe_send_embed(
|
|
|
|
|
"Additional requirements needed. See guide before continuing.\n" "Continue?"
|
|
|
|
|
)
|
|
|
|
|
pred = MessagePredicate.yes_or_no(ctx)
|
|
|
|
|
try:
|
|
|
|
|
await self.bot.wait_for("message", check=pred, timeout=30)
|
|
|
|
|
except TimeoutError:
|
|
|
|
|
await ctx.send("Response timed out, please try again later.")
|
|
|
|
|
return
|
|
|
|
|
if not pred.result:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.tagger_language = models[model_number]
|
|
|
|
|
async with ctx.typing():
|
|
|
|
|
self.chatbot = self._create_chatbot()
|
|
|
|
|
|
|
|
|
|
await ctx.maybe_send_embed(f"Model has been switched to {self.tagger_language.ISO_639_1}")
|
|
|
|
|
|
|
|
|
|
@chatter.command(name="minutes")
|
|
|
|
|
async def minutes(self, ctx: commands.Context, minutes: int):
|
|
|
|
@ -260,6 +329,27 @@ class Chatter(Cog):
|
|
|
|
|
else:
|
|
|
|
|
await ctx.send("Error occurred :(")
|
|
|
|
|
|
|
|
|
|
@chatter.command(name="trainubuntu")
|
|
|
|
|
async def chatter_train_ubuntu(self, ctx: commands.Context, confirmation: bool = False):
|
|
|
|
|
"""
|
|
|
|
|
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not confirmation:
|
|
|
|
|
await ctx.maybe_send_embed(
|
|
|
|
|
"Warning: This command downloads ~500MB then eats your CPU for training\n"
|
|
|
|
|
"If you're sure you want to continue, run `[p]chatter trainubuntu True`"
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
async with ctx.typing():
|
|
|
|
|
future = await self.loop.run_in_executor(None, self._train_ubuntu)
|
|
|
|
|
|
|
|
|
|
if future:
|
|
|
|
|
await ctx.send("Training successful!")
|
|
|
|
|
else:
|
|
|
|
|
await ctx.send("Error occurred :(")
|
|
|
|
|
|
|
|
|
|
@chatter.command(name="trainenglish")
|
|
|
|
|
async def chatter_train_english(self, ctx: commands.Context):
|
|
|
|
|
"""
|
|
|
|
|