Better conversation grouping and formatting

pull/107/head
bobloy 5 years ago
parent 30564ed306
commit 524b71cc83

@ -4,7 +4,7 @@ from datetime import datetime, timedelta
import discord
from chatterbot import ChatBot
from chatterbot.comparisons import JaccardSimilarity, SpacySimilarity, LevenshteinDistance
from chatterbot.comparisons import JaccardSimilarity, LevenshteinDistance, SpacySimilarity
from chatterbot.response_selection import get_first_response
from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer
from redbot.core import Config, commands
@ -13,15 +13,15 @@ from redbot.core.data_manager import cog_data_path
class ENG_LG: # TODO: Add option to use this large model
ISO_639_1 = 'en_core_web_lg'
ISO_639 = 'eng'
ENGLISH_NAME = 'English'
ISO_639_1 = "en_core_web_lg"
ISO_639 = "eng"
ENGLISH_NAME = "English"
class ENG_MD:
ISO_639_1 = 'en_core_web_md'
ISO_639 = 'eng'
ENGLISH_NAME = 'English'
ISO_639_1 = "en_core_web_md"
ISO_639 = "eng"
ENGLISH_NAME = "English"
class Chatter(Cog):
@ -48,7 +48,9 @@ class Chatter(Cog):
self.loop = asyncio.get_event_loop()
def _create_chatbot(self, data_path, similarity_algorithm, similarity_threshold, tagger_language):
def _create_chatbot(
self, data_path, similarity_algorithm, similarity_threshold, tagger_language
):
return ChatBot(
"ChatterBot",
storage_adapter="chatterbot.storage.SQLStorageAdapter",
@ -57,7 +59,7 @@ class Chatter(Cog):
response_selection_method=get_first_response,
logic_adapters=["chatterbot.logic.BestMatch"],
maximum_similarity_threshold=similarity_threshold,
tagger_language=tagger_language
tagger_language=tagger_language,
)
async def _get_conversation(self, ctx, in_channel: discord.TextChannel = None):
@ -70,12 +72,15 @@ class Chatter(Cog):
after = datetime.today() - timedelta(days=(await self.config.guild(ctx.guild).days()))
convo_delta = timedelta(minutes=(await self.config.guild(ctx.guild).convo_delta()))
def new_message(msg, sent, out_in, delta):
if sent is None:
return False
def new_conversation(msg, sent, out_in, delta):
# if sent is None:
# return False
if len(out_in) < 2:
return False
# Don't do "too short" processing here. Sometimes people don't respond.
# if len(out_in) < 2:
# return False
# print(msg.created_at - sent)
return msg.created_at - sent >= delta
@ -85,18 +90,24 @@ class Chatter(Cog):
await ctx.send("Gathering {}".format(channel.mention))
user = None
i = 0
send_time = None
send_time = after - timedelta(days=100) # Makes the first message a new message
try:
async for message in channel.history(limit=None, after=after):
async for message in channel.history(
limit=None, after=after
): # type: discord.Message
# if message.author.bot: # Skip bot messages
# continue
if new_message(message, send_time, out[i], convo_delta):
if new_conversation(message, send_time, out[i], convo_delta):
out.append([])
i += 1
user = None
else:
send_time = message.created_at + timedelta(seconds=1)
send_time = (
message.created_at
) # + timedelta(seconds=1) # Can't remember why I added 1 second
if user == message.author:
out[i][-1] += "\n" + message.clean_content
else:
@ -116,9 +127,7 @@ class Chatter(Cog):
def _train_english(self):
trainer = ChatterBotCorpusTrainer(self.chatbot)
try:
trainer.train(
"chatterbot.corpus.english"
)
trainer.train("chatterbot.corpus.english")
except:
return False
return True
@ -127,8 +136,9 @@ class Chatter(Cog):
trainer = ListTrainer(self.chatbot)
try:
for convo in data:
# self.chatbot.train(convo)
trainer.train(convo)
if len(convo) > 1:
trainer.train(convo)
except:
return False
return True
@ -157,7 +167,9 @@ class Chatter(Cog):
await ctx.send_help()
return
self.chatbot = self._create_chatbot(self.data_path, algos[algo_number][0], algos[algo_number][1], ENG_MD)
self.chatbot = self._create_chatbot(
self.data_path, algos[algo_number][0], algos[algo_number][1], ENG_MD
)
await ctx.tick()
@ -254,7 +266,7 @@ class Chatter(Cog):
try:
await temp_message.delete()
except:
except discord.Forbidden:
pass
if future:
@ -296,8 +308,7 @@ class Chatter(Cog):
when_mentionables = commands.when_mentioned(self.bot, message)
prefix = my_local_get_prefix(when_mentionables, message.content
)
prefix = my_local_get_prefix(when_mentionables, message.content)
if prefix is None:
# print("not mentioned")

Loading…
Cancel
Save