Allow specifying algorithm and model

pull/118/head
bobloy 5 years ago
parent e5947953aa
commit 4fcc12a2d8

@ -167,7 +167,38 @@ settings. This can take a long time to process.
``` ```
[p]chatter algorithm X [p]chatter algorithm X
``` ```
or
```
[p]chatter algo X 0.95
```
Chatter can be configured to use one of three different Similarity algorithms. Chatter can be configured to use one of three different Similarity algorithms.
Changing this can help if the response speed is too slow, but can reduce the accuracy of results. Changing this can help if the response speed is too slow, but can reduce the accuracy of results.
The second argument is the minimum similarity threshold;
raising this will make the bot be more selective with the responses it finds.
The default minimum similarity threshold is 0.90.
## Switching Pretrained Models
```
[p]chatter model X
```
Chatter can be configured to use one of three pretrained statistical models for English.
I have not noticed any advantage to changing this,
but supposedly it would help by splitting the search term into more useful parts.
See [here](https://spacy.io/models) for more info on spaCy models.
Before you're able to use the *large* model (option 3), you must install it through pip.
*Warning:* This is a ~800MB download.
```
[p]pipinstall https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-2.3.1/en_core_web_lg-2.3.1.tar.gz#egg=en_core_web_lg
```

@ -12,11 +12,12 @@ from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorp
from redbot.core import Config, commands from redbot.core import Config, commands
from redbot.core.commands import Cog from redbot.core.commands import Cog
from redbot.core.data_manager import cog_data_path from redbot.core.data_manager import cog_data_path
from redbot.core.utils.predicates import MessagePredicate
log = logging.getLogger("red.fox_v3.chat") log = logging.getLogger("red.fox_v3.chat")
class ENG_LG: # TODO: Add option to use this large model class ENG_LG:
ISO_639_1 = "en_core_web_lg" ISO_639_1 = "en_core_web_lg"
ISO_639 = "eng" ISO_639 = "eng"
ENGLISH_NAME = "English" ENGLISH_NAME = "English"
@ -28,6 +29,12 @@ class ENG_MD:
ENGLISH_NAME = "English" ENGLISH_NAME = "English"
class ENG_SM:
    """Language descriptor for English backed by the small pretrained model."""

    # Holds the pretrained model's package name rather than a two-letter
    # language code (presumably what the tagger loads -- confirm against
    # the ChatterBot tagging code).
    ISO_639_1 = "en_core_web_sm"
    ISO_639 = "eng"
    ENGLISH_NAME = "English"
class Chatter(Cog): class Chatter(Cog):
""" """
This cog trains a chatbot that will talk like members of your Guild This cog trains a chatbot that will talk like members of your Guild
@ -42,7 +49,13 @@ class Chatter(Cog):
path: pathlib.Path = cog_data_path(self) path: pathlib.Path = cog_data_path(self)
self.data_path = path / "database.sqlite3" self.data_path = path / "database.sqlite3"
self.chatbot = self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) # TODO: Move training_model and similarity_algo to config
# TODO: Add an option to see current settings
self.tagger_language = ENG_MD
self.similarity_algo = SpacySimilarity
self.similarity_threshold = 0.90
self.chatbot = self._create_chatbot()
# self.chatbot.set_trainer(ListTrainer) # self.chatbot.set_trainer(ListTrainer)
# self.trainer = ListTrainer(self.chatbot) # self.trainer = ListTrainer(self.chatbot)
@ -52,18 +65,17 @@ class Chatter(Cog):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
def _create_chatbot(self):
    """
    Build and return a new ChatBot from the cog's current settings.

    Reads ``self.data_path``, ``self.similarity_algo``,
    ``self.similarity_threshold`` and ``self.tagger_language``. The new
    instance is only returned; callers decide whether to store it.
    """
    bot_options = {
        "storage_adapter": "chatterbot.storage.SQLStorageAdapter",
        "database_uri": "sqlite:///" + str(self.data_path),
        "statement_comparison_function": self.similarity_algo,
        "response_selection_method": get_random_response,
        "logic_adapters": ["chatterbot.logic.BestMatch"],
        "maximum_similarity_threshold": self.similarity_threshold,
        "tagger_language": self.tagger_language,
        "logger": log,
    }
    return ChatBot("ChatterBot", **bot_options)
@ -103,7 +115,7 @@ class Chatter(Cog):
try: try:
async for message in channel.history( async for message in channel.history(
limit=None, after=after, oldest_first=True limit=None, after=after, oldest_first=True
).filter( ).filter(
predicate=predicate predicate=predicate
): # type: discord.Message ): # type: discord.Message
@ -195,12 +207,14 @@ class Chatter(Cog):
"Failed to clear training database. Please wait a bit and try again" "Failed to clear training database. Please wait a bit and try again"
) )
self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD) self._create_chatbot()
await ctx.tick() await ctx.tick()
@chatter.command(name="algorithm", aliases=["algo"])
async def chatter_algorithm(
    self, ctx: commands.Context, algo_number: int, threshold: float = None
):
    """
    Switch the active logic algorithm to one of the three. Default after reload is Spacy

    0: Spacy
    1: Jaccard
    2: Levenshtein

    Optionally also set the similarity threshold (between 0 and 1, exclusive).
    """
    algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]

    if algo_number < 0 or algo_number > 2:
        await ctx.send_help()
        return

    if threshold is not None:
        # Chained comparison also rejects NaN, which would slip through
        # separate ">= 1" / "<= 0" checks.
        if not 0 < threshold < 1:
            await ctx.maybe_send_embed(
                "Threshold must be a number between 0 and 1 (exclusive)"
            )
            return
        # Store in the threshold attribute; writing it to similarity_algo
        # (as before) was overwritten below and the value was lost.
        self.similarity_threshold = threshold

    self.similarity_algo = algos[algo_number]

    # Rebuild the bot so the new algorithm/threshold take effect.
    async with ctx.typing():
        self.chatbot = self._create_chatbot()

        await ctx.tick()
@chatter.command(name="model")
async def chatter_model(self, ctx: commands.Context, model_number: int):
    """
    Switch the active model to one of the three. Default after reload is Medium

    0: Small
    1: Medium
    2: Large (Requires additional setup)
    """
    models = [ENG_SM, ENG_MD, ENG_LG]

    if model_number < 0 or model_number > 2:
        await ctx.send_help()
        return

    if model_number == 2:
        # The large model needs a separate install, so ask for confirmation.
        await ctx.maybe_send_embed(
            "Additional requirements needed. See guide before continuing.\nContinue?"
        )
        pred = MessagePredicate.yes_or_no(ctx)
        try:
            await self.bot.wait_for("message", check=pred, timeout=30)
        except asyncio.TimeoutError:
            # wait_for raises asyncio.TimeoutError, which is not the builtin
            # TimeoutError on Python versions before 3.11.
            await ctx.send("Response timed out, please try again later.")
            return
        if not pred.result:
            return

    previous_model = self.tagger_language
    self.tagger_language = models[model_number]
    async with ctx.typing():
        try:
            self.chatbot = self._create_chatbot()
        except Exception:
            # Keep the recorded model in sync with the still-active chatbot,
            # then let the error propagate to the command error handler.
            self.tagger_language = previous_model
            raise

        await ctx.maybe_send_embed(
            f"Model has been switched to {self.tagger_language.ISO_639_1}"
        )
@chatter.command(name="minutes") @chatter.command(name="minutes")
async def minutes(self, ctx: commands.Context, minutes: int): async def minutes(self, ctx: commands.Context, minutes: int):
@ -278,8 +336,10 @@ class Chatter(Cog):
""" """
if not confirmation: if not confirmation:
await ctx.maybe_send_embed("Warning: This command downloads ~500MB then eats your CPU for training\n" await ctx.maybe_send_embed(
"If you're sure you want to continue, run `[p]chatter trainubuntu True`") "Warning: This command downloads ~500MB then eats your CPU for training\n"
"If you're sure you want to continue, run `[p]chatter trainubuntu True`"
)
return return
async with ctx.typing(): async with ctx.typing():

Loading…
Cancel
Save