Add chatter initialization
This commit is contained in:
parent
b752bfd153
commit
698dafade4
@ -1,8 +1,10 @@
|
||||
from .chat import Chatter
|
||||
|
||||
|
||||
def setup(bot):
|
||||
bot.add_cog(Chatter(bot))
|
||||
async def setup(bot):
|
||||
cog = Chatter(bot)
|
||||
await cog.initialize()
|
||||
bot.add_cog(cog)
|
||||
|
||||
|
||||
# __all__ = (
|
||||
|
@ -19,6 +19,7 @@ from redbot.core.utils.predicates import MessagePredicate
|
||||
|
||||
from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2
|
||||
|
||||
chatterbot_log = logging.getLogger("red.fox_v3.chatterbot")
|
||||
log = logging.getLogger("red.fox_v3.chatter")
|
||||
|
||||
|
||||
@ -58,11 +59,14 @@ class Chatter(Cog):
|
||||
This cog trains a chatbot that will talk like members of your Guild
|
||||
"""
|
||||
|
||||
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
|
||||
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
self.config = Config.get_conf(self, identifier=6710497116116101114)
|
||||
default_global = {"learning": True}
|
||||
default_global = {"learning": True, "model_number": 0, "algo_number": 0, "threshold": 0.90}
|
||||
self.default_guild = {
|
||||
"whitelist": None,
|
||||
"days": 1,
|
||||
@ -79,7 +83,7 @@ class Chatter(Cog):
|
||||
self.tagger_language = ENG_SM
|
||||
self.similarity_algo = SpacySimilarity
|
||||
self.similarity_threshold = 0.90
|
||||
self.chatbot = self._create_chatbot()
|
||||
self.chatbot = None
|
||||
# self.chatbot.set_trainer(ListTrainer)
|
||||
|
||||
# self.trainer = ListTrainer(self.chatbot)
|
||||
@ -98,6 +102,18 @@ class Chatter(Cog):
|
||||
"""Nothing to delete"""
|
||||
return
|
||||
|
||||
async def initialize(self):
|
||||
all_config = dict(self.config.defaults["GLOBAL"])
|
||||
all_config.update(await self.config.all())
|
||||
model_number = all_config["model_number"]
|
||||
algo_number = all_config["algo_number"]
|
||||
threshold = all_config["threshold"]
|
||||
|
||||
self.tagger_language = self.models[model_number]
|
||||
self.similarity_algo = self.algos[algo_number]
|
||||
self.similarity_threshold = threshold
|
||||
self.chatbot = self._create_chatbot()
|
||||
|
||||
def _create_chatbot(self):
|
||||
|
||||
return ChatBot(
|
||||
@ -110,7 +126,7 @@ class Chatter(Cog):
|
||||
logic_adapters=["chatterbot.logic.BestMatch"],
|
||||
maximum_similarity_threshold=self.similarity_threshold,
|
||||
tagger_language=self.tagger_language,
|
||||
logger=log,
|
||||
logger=chatterbot_log,
|
||||
)
|
||||
|
||||
async def _get_conversation(self, ctx, in_channels: List[discord.TextChannel]):
|
||||
@ -334,15 +350,12 @@ class Chatter(Cog):
|
||||
self, ctx: commands.Context, algo_number: int, threshold: float = None
|
||||
):
|
||||
"""
|
||||
Switch the active logic algorithm to one of the three. Default after reload is Spacy
|
||||
Switch the active logic algorithm to one of the three. Default is Spacy
|
||||
|
||||
0: Spacy
|
||||
1: Jaccard
|
||||
2: Levenshtein
|
||||
"""
|
||||
|
||||
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
|
||||
|
||||
if algo_number < 0 or algo_number > 2:
|
||||
await ctx.send_help()
|
||||
return
|
||||
@ -355,8 +368,11 @@ class Chatter(Cog):
|
||||
return
|
||||
else:
|
||||
self.similarity_threshold = threshold
|
||||
await self.config.threshold.set(self.similarity_threshold)
|
||||
|
||||
self.similarity_algo = self.algos[algo_number]
|
||||
await self.config.algo_number.set(algo_number)
|
||||
|
||||
self.similarity_algo = algos[algo_number]
|
||||
async with ctx.typing():
|
||||
self.chatbot = self._create_chatbot()
|
||||
|
||||
@ -366,21 +382,18 @@ class Chatter(Cog):
|
||||
@chatter.command(name="model")
|
||||
async def chatter_model(self, ctx: commands.Context, model_number: int):
|
||||
"""
|
||||
Switch the active model to one of the three. Default after reload is Medium
|
||||
Switch the active model to one of the three. Default is Small
|
||||
|
||||
0: Small
|
||||
1: Medium
|
||||
1: Medium (Requires additional setup)
|
||||
2: Large (Requires additional setup)
|
||||
3. Accurate (Requires additional setup)
|
||||
"""
|
||||
|
||||
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
|
||||
|
||||
if model_number < 0 or model_number > 3:
|
||||
await ctx.send_help()
|
||||
return
|
||||
|
||||
if model_number == 2:
|
||||
if model_number >= 0:
|
||||
await ctx.maybe_send_embed(
|
||||
"Additional requirements needed. See guide before continuing.\n" "Continue?"
|
||||
)
|
||||
@ -393,7 +406,8 @@ class Chatter(Cog):
|
||||
if not pred.result:
|
||||
return
|
||||
|
||||
self.tagger_language = models[model_number]
|
||||
self.tagger_language = self.models[model_number]
|
||||
await self.config.model_number.set(model_number)
|
||||
async with ctx.typing():
|
||||
self.chatbot = self._create_chatbot()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user