Merge branch 'master' into XargsUK_master

pull/193/head
bobloy 4 years ago
commit fd80819618

@ -1,8 +1,10 @@
from .chat import Chatter from .chat import Chatter
def setup(bot): async def setup(bot):
bot.add_cog(Chatter(bot)) cog = Chatter(bot)
await cog.initialize()
bot.add_cog(cog)
# __all__ = ( # __all__ = (

@ -19,6 +19,7 @@ from redbot.core.utils.predicates import MessagePredicate
from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2 from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2
chatterbot_log = logging.getLogger("red.fox_v3.chatterbot")
log = logging.getLogger("red.fox_v3.chatter") log = logging.getLogger("red.fox_v3.chatter")
@ -58,11 +59,14 @@ class Chatter(Cog):
This cog trains a chatbot that will talk like members of your Guild This cog trains a chatbot that will talk like members of your Guild
""" """
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
def __init__(self, bot): def __init__(self, bot):
super().__init__() super().__init__()
self.bot = bot self.bot = bot
self.config = Config.get_conf(self, identifier=6710497116116101114) self.config = Config.get_conf(self, identifier=6710497116116101114)
default_global = {"learning": True} default_global = {"learning": True, "model_number": 0, "algo_number": 0, "threshold": 0.90}
self.default_guild = { self.default_guild = {
"whitelist": None, "whitelist": None,
"days": 1, "days": 1,
@ -79,7 +83,7 @@ class Chatter(Cog):
self.tagger_language = ENG_SM self.tagger_language = ENG_SM
self.similarity_algo = SpacySimilarity self.similarity_algo = SpacySimilarity
self.similarity_threshold = 0.90 self.similarity_threshold = 0.90
self.chatbot = self._create_chatbot() self.chatbot = None
# self.chatbot.set_trainer(ListTrainer) # self.chatbot.set_trainer(ListTrainer)
# self.trainer = ListTrainer(self.chatbot) # self.trainer = ListTrainer(self.chatbot)
@ -98,6 +102,18 @@ class Chatter(Cog):
"""Nothing to delete""" """Nothing to delete"""
return return
async def initialize(self):
all_config = dict(self.config.defaults["GLOBAL"])
all_config.update(await self.config.all())
model_number = all_config["model_number"]
algo_number = all_config["algo_number"]
threshold = all_config["threshold"]
self.tagger_language = self.models[model_number]
self.similarity_algo = self.algos[algo_number]
self.similarity_threshold = threshold
self.chatbot = self._create_chatbot()
def _create_chatbot(self): def _create_chatbot(self):
return ChatBot( return ChatBot(
@ -110,7 +126,7 @@ class Chatter(Cog):
logic_adapters=["chatterbot.logic.BestMatch"], logic_adapters=["chatterbot.logic.BestMatch"],
maximum_similarity_threshold=self.similarity_threshold, maximum_similarity_threshold=self.similarity_threshold,
tagger_language=self.tagger_language, tagger_language=self.tagger_language,
logger=log, logger=chatterbot_log,
) )
async def _get_conversation(self, ctx, in_channels: List[discord.TextChannel]): async def _get_conversation(self, ctx, in_channels: List[discord.TextChannel]):
@ -334,15 +350,12 @@ class Chatter(Cog):
self, ctx: commands.Context, algo_number: int, threshold: float = None self, ctx: commands.Context, algo_number: int, threshold: float = None
): ):
""" """
Switch the active logic algorithm to one of the three. Default after reload is Spacy Switch the active logic algorithm to one of the three. Default is Spacy
0: Spacy 0: Spacy
1: Jaccard 1: Jaccard
2: Levenshtein 2: Levenshtein
""" """
algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]
if algo_number < 0 or algo_number > 2: if algo_number < 0 or algo_number > 2:
await ctx.send_help() await ctx.send_help()
return return
@ -355,8 +368,11 @@ class Chatter(Cog):
return return
else: else:
self.similarity_threshold = threshold self.similarity_threshold = threshold
await self.config.threshold.set(self.similarity_threshold)
self.similarity_algo = self.algos[algo_number]
await self.config.algo_number.set(algo_number)
self.similarity_algo = algos[algo_number]
async with ctx.typing(): async with ctx.typing():
self.chatbot = self._create_chatbot() self.chatbot = self._create_chatbot()
@ -366,21 +382,18 @@ class Chatter(Cog):
@chatter.command(name="model") @chatter.command(name="model")
async def chatter_model(self, ctx: commands.Context, model_number: int): async def chatter_model(self, ctx: commands.Context, model_number: int):
""" """
Switch the active model to one of the three. Default after reload is Medium Switch the active model to one of the three. Default is Small
0: Small 0: Small
1: Medium 1: Medium (Requires additional setup)
2: Large (Requires additional setup) 2: Large (Requires additional setup)
3. Accurate (Requires additional setup) 3. Accurate (Requires additional setup)
""" """
models = [ENG_SM, ENG_MD, ENG_LG, ENG_TRF]
if model_number < 0 or model_number > 3: if model_number < 0 or model_number > 3:
await ctx.send_help() await ctx.send_help()
return return
if model_number == 2: if model_number >= 0:
await ctx.maybe_send_embed( await ctx.maybe_send_embed(
"Additional requirements needed. See guide before continuing.\n" "Continue?" "Additional requirements needed. See guide before continuing.\n" "Continue?"
) )
@ -393,7 +406,8 @@ class Chatter(Cog):
if not pred.result: if not pred.result:
return return
self.tagger_language = models[model_number] self.tagger_language = self.models[model_number]
await self.config.model_number.set(model_number)
async with ctx.typing(): async with ctx.typing():
self.chatbot = self._create_chatbot() self.chatbot = self._create_chatbot()

@ -18,9 +18,7 @@ class MyDumbSQLStorageAdapter(SQLStorageAdapter):
if not self.database_uri: if not self.database_uri:
self.database_uri = "sqlite:///db.sqlite3" self.database_uri = "sqlite:///db.sqlite3"
self.engine = create_engine( self.engine = create_engine(self.database_uri, connect_args={"check_same_thread": False})
self.database_uri, connect_args={"check_same_thread": False}
)
if self.database_uri.startswith("sqlite://"): if self.database_uri.startswith("sqlite://"):
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -31,7 +29,7 @@ class MyDumbSQLStorageAdapter(SQLStorageAdapter):
dbapi_connection.execute("PRAGMA journal_mode=WAL") dbapi_connection.execute("PRAGMA journal_mode=WAL")
dbapi_connection.execute("PRAGMA synchronous=NORMAL") dbapi_connection.execute("PRAGMA synchronous=NORMAL")
if not inspect(self.engine).has_table('Statement'): if not inspect(self.engine).has_table("Statement"):
self.create_database() self.create_database()
self.Session = sessionmaker(bind=self.engine, expire_on_commit=True) self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

Loading…
Cancel
Save