Add chatter initialization

pull/198/merge
bobloy 4 years ago
parent b752bfd153
commit 698dafade4

@ -1,8 +1,10 @@
from .chat import Chatter
def setup(bot):
bot.add_cog(Chatter(bot))
async def setup(bot):
cog = Chatter(bot)
await cog.initialize()
bot.add_cog(cog)
# __all__ = (

@ -19,6 +19,7 @@ from redbot.core.utils.predicates import MessagePredicate
from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2
chatterbot_log = logging.getLogger("red.fox_v3.chatterbot")
log = logging.getLogger("red.fox_v3.chatter")
@ -58,11 +59,14 @@ class Chatter(Cog):
This cog trains a chatbot that will talk like members of your Guild
"""
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
def __init__(self, bot):
super().__init__()
self.bot = bot
self.config = Config.get_conf(self, identifier=6710497116116101114)
default_global = {"learning": True}
default_global = {"learning": True, "model_number": 0, "algo_number": 0, "threshold": 0.90}
self.default_guild = {
"whitelist": None,
"days": 1,
@ -79,7 +83,7 @@ class Chatter(Cog):
self.tagger_language = ENG_SM
self.similarity_algo = SpacySimilarity
self.similarity_threshold = 0.90
self.chatbot = self._create_chatbot()
self.chatbot = None
# self.chatbot.set_trainer(ListTrainer)
# self.trainer = ListTrainer(self.chatbot)
@ -98,6 +102,18 @@ class Chatter(Cog):
"""Nothing to delete"""
return
async def initialize(self):
all_config = dict(self.config.defaults["GLOBAL"])
all_config.update(await self.config.all())
model_number = all_config["model_number"]
algo_number = all_config["algo_number"]
threshold = all_config["threshold"]
self.tagger_language = self.models[model_number]
self.similarity_algo = self.algos[algo_number]
self.similarity_threshold = threshold
self.chatbot = self._create_chatbot()
def _create_chatbot(self):
return ChatBot(
@ -110,7 +126,7 @@ class Chatter(Cog):
logic_adapters=["chatterbot.logic.BestMatch"],
maximum_similarity_threshold=self.similarity_threshold,
tagger_language=self.tagger_language,
logger=log,
logger=chatterbot_log,
)
async def _get_conversation(self, ctx, in_channels: List[discord.TextChannel]):
@ -334,15 +350,12 @@ class Chatter(Cog):
self, ctx: commands.Context, algo_number: int, threshold: float = None
):
"""
Switch the active logic algorithm to one of the three. Default after reload is Spacy
Switch the active logic algorithm to one of the three. Default is Spacy
0: Spacy
1: Jaccard
2: Levenshtein
"""
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
if algo_number < 0 or algo_number > 2:
await ctx.send_help()
return
@ -355,8 +368,11 @@ class Chatter(Cog):
return
else:
self.similarity_threshold = threshold
await self.config.threshold.set(self.similarity_threshold)
self.similarity_algo = self.algos[algo_number]
await self.config.algo_number.set(algo_number)
self.similarity_algo = algos[algo_number]
async with ctx.typing():
self.chatbot = self._create_chatbot()
@ -366,21 +382,18 @@ class Chatter(Cog):
@chatter.command(name="model")
async def chatter_model(self, ctx: commands.Context, model_number: int):
"""
Switch the active model to one of the three. Default after reload is Medium
Switch the active model to one of the three. Default is Small
0: Small
1: Medium
1: Medium (Requires additional setup)
2: Large (Requires additional setup)
3. Accurate (Requires additional setup)
"""
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
if model_number < 0 or model_number > 3:
await ctx.send_help()
return
if model_number == 2:
if model_number >= 0:
await ctx.maybe_send_embed(
"Additional requirements needed. See guide before continuing.\n" "Continue?"
)
@ -393,7 +406,8 @@ class Chatter(Cog):
if not pred.result:
return
self.tagger_language = models[model_number]
self.tagger_language = self.models[model_number]
await self.config.model_number.set(model_number)
async with ctx.typing():
self.chatbot = self._create_chatbot()

Loading…
Cancel
Save