Allow specifying algorithm and model
This commit is contained in:
parent
e5947953aa
commit
4fcc12a2d8
@ -167,7 +167,38 @@ settings. This can take a long time to process.
|
||||
```
|
||||
[p]chatter algorithm X
|
||||
```
|
||||
or
|
||||
```
|
||||
[p]chatter algo X 0.95
|
||||
```
|
||||
|
||||
Chatter can be configured to use one of three different similarity algorithms.
|
||||
|
||||
Changing this can help if the response speed is too slow, but can reduce the accuracy of results.
|
||||
|
||||
The second argument is the minimum similarity threshold,
|
||||
raising this will make the bot be more selective with the responses it finds.
|
||||
|
||||
Default minimum similarity threshold is 0.90
|
||||
|
||||
|
||||
## Switching Pretrained Models
|
||||
|
||||
```
|
||||
[p]chatter model X
|
||||
```
|
||||
|
||||
Chatter can be configured to use one of three pretrained statistical models for English.
|
||||
|
||||
I have not noticed any advantage to changing this,
|
||||
but supposedly it would help by splitting the search term into more useful parts.
|
||||
|
||||
See [here](https://spacy.io/models) for more info on spaCy models.
|
||||
|
||||
Before you're able to use the *large* model (option 3), you must install it through pip.
|
||||
|
||||
*Warning:* This is a ~800MB download.
|
||||
|
||||
```
|
||||
[p]pipinstall https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-2.3.1/en_core_web_lg-2.3.1.tar.gz#egg=en_core_web_lg
|
||||
```
|
||||
|
100
chatter/chat.py
100
chatter/chat.py
@ -12,11 +12,12 @@ from chatterbot.trainers import ChatterBotCorpusTrainer, ListTrainer, UbuntuCorp
|
||||
from redbot.core import Config, commands
|
||||
from redbot.core.commands import Cog
|
||||
from redbot.core.data_manager import cog_data_path
|
||||
from redbot.core.utils.predicates import MessagePredicate
|
||||
|
||||
log = logging.getLogger("red.fox_v3.chat")
|
||||
|
||||
|
||||
class ENG_LG: # TODO: Add option to use this large model
|
||||
class ENG_LG:
|
||||
ISO_639_1 = "en_core_web_lg"
|
||||
ISO_639 = "eng"
|
||||
ENGLISH_NAME = "English"
|
||||
@ -28,6 +29,12 @@ class ENG_MD:
|
||||
ENGLISH_NAME = "English"
|
||||
|
||||
|
||||
class ENG_SM:
    """
    Language descriptor for the small English model.

    Passed to the chatbot as its ``tagger_language``; only the three
    constants below are defined here.
    """

    # Human-readable language name.
    ENGLISH_NAME = "English"
    # Three-letter language code.
    ISO_639 = "eng"
    # Holds the model package name (presumably the spaCy package to load).
    ISO_639_1 = "en_core_web_sm"
|
||||
|
||||
|
||||
class Chatter(Cog):
|
||||
"""
|
||||
This cog trains a chatbot that will talk like members of your Guild
|
||||
@ -42,7 +49,13 @@ class Chatter(Cog):
|
||||
path: pathlib.Path = cog_data_path(self)
|
||||
self.data_path = path / "database.sqlite3"
|
||||
|
||||
self.chatbot = self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
|
||||
# TODO: Move training_model and similarity_algo to config
|
||||
# TODO: Add an option to see current settings
|
||||
|
||||
self.tagger_language = ENG_MD
|
||||
self.similarity_algo = SpacySimilarity
|
||||
self.similarity_threshold = 0.90
|
||||
self.chatbot = self._create_chatbot()
|
||||
# self.chatbot.set_trainer(ListTrainer)
|
||||
|
||||
# self.trainer = ListTrainer(self.chatbot)
|
||||
@ -52,18 +65,17 @@ class Chatter(Cog):
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def _create_chatbot(self):
    """
    Build a new ChatBot from the settings currently stored on the cog.

    Reads ``self.data_path``, ``self.similarity_algo``,
    ``self.similarity_threshold`` and ``self.tagger_language``, so callers
    change those attributes first and then call this to rebuild.

    Returns the newly constructed ChatBot; the caller is responsible for
    assigning it to ``self.chatbot``.
    """
    return ChatBot(
        "ChatterBot",
        storage_adapter="chatterbot.storage.SQLStorageAdapter",
        database_uri="sqlite:///" + str(self.data_path),
        statement_comparison_function=self.similarity_algo,
        response_selection_method=get_random_response,
        logic_adapters=["chatterbot.logic.BestMatch"],
        maximum_similarity_threshold=self.similarity_threshold,
        tagger_language=self.tagger_language,
        logger=log,
    )
|
||||
|
||||
@ -103,7 +115,7 @@ class Chatter(Cog):
|
||||
try:
|
||||
|
||||
async for message in channel.history(
|
||||
limit=None, after=after, oldest_first=True
|
||||
limit=None, after=after, oldest_first=True
|
||||
).filter(
|
||||
predicate=predicate
|
||||
): # type: discord.Message
|
||||
@ -195,12 +207,14 @@ class Chatter(Cog):
|
||||
"Failed to clear training database. Please wait a bit and try again"
|
||||
)
|
||||
|
||||
self._create_chatbot(self.data_path, SpacySimilarity, 0.45, ENG_MD)
|
||||
self._create_chatbot()
|
||||
|
||||
await ctx.tick()
|
||||
|
||||
@chatter.command(name="algorithm", aliases=["algo"])
async def chatter_algorithm(
    self, ctx: commands.Context, algo_number: int, threshold: float = None
):
    """
    Switch the active logic algorithm to one of the three. Default after reload is Spacy

    0: Spacy
    1: Jaccard
    2: Levenshtein

    The optional second argument sets the similarity threshold,
    a number between 0 and 1 (exclusive).
    """
    # Index in this list is the number the user passes on the command line.
    algos = [SpacySimilarity, JaccardSimilarity, LevenshteinDistance]

    if algo_number < 0 or algo_number > 2:
        await ctx.send_help()
        return

    if threshold is not None:
        # Chained comparison also rejects NaN, which would slip through
        # a pair of `>=` / `<=` checks.
        if not 0 < threshold < 1:
            await ctx.maybe_send_embed(
                "Threshold must be a number between 0 and 1 (exclusive)"
            )
            return
        # Store on the threshold attribute; writing it to similarity_algo
        # would be overwritten just below and the threshold never applied.
        self.similarity_threshold = threshold

    self.similarity_algo = algos[algo_number]
    async with ctx.typing():
        # Rebuild the bot so the new algorithm/threshold take effect.
        self.chatbot = self._create_chatbot()

        await ctx.tick()
|
||||
|
||||
@chatter.command(name="model")
async def chatter_model(self, ctx: commands.Context, model_number: int):
    """
    Switch the active model to one of the three. Default after reload is Medium

    0: Small
    1: Medium
    2: Large (Requires additional setup)
    """
    # Index in this list is the number the user passes on the command line.
    models = [ENG_SM, ENG_MD, ENG_LG]

    if model_number < 0 or model_number > 2:
        await ctx.send_help()
        return

    if model_number == 2:
        # The large model needs a separate install, so ask for confirmation.
        await ctx.maybe_send_embed(
            "Additional requirements needed. See guide before continuing.\n" "Continue?"
        )
        pred = MessagePredicate.yes_or_no(ctx)
        try:
            await self.bot.wait_for("message", check=pred, timeout=30)
        except asyncio.TimeoutError:
            # wait_for signals a timeout with asyncio.TimeoutError, which is
            # not the builtin TimeoutError on Python versions before 3.11.
            await ctx.send("Response timed out, please try again later.")
            return
        if not pred.result:
            return

    previous_language = self.tagger_language
    self.tagger_language = models[model_number]
    async with ctx.typing():
        try:
            self.chatbot = self._create_chatbot()
        except Exception:
            # Keep the settings consistent with the bot that is still
            # active if the new model could not be loaded, then re-raise.
            self.tagger_language = previous_language
            raise

        await ctx.maybe_send_embed(f"Model has been switched to {self.tagger_language.ISO_639_1}")
|
||||
|
||||
@chatter.command(name="minutes")
|
||||
async def minutes(self, ctx: commands.Context, minutes: int):
|
||||
@ -278,8 +336,10 @@ class Chatter(Cog):
|
||||
"""
|
||||
|
||||
if not confirmation:
|
||||
await ctx.maybe_send_embed("Warning: This command downloads ~500MB then eats your CPU for training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter trainubuntu True`")
|
||||
await ctx.maybe_send_embed(
|
||||
"Warning: This command downloads ~500MB then eats your CPU for training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter trainubuntu True`"
|
||||
)
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
|
Loading…
x
Reference in New Issue
Block a user