Some progress on updated ubuntu trainer
This commit is contained in: parent 14f8b825d8, commit 337def2fa3
@@ -17,7 +17,7 @@ from redbot.core.commands import Cog
 from redbot.core.data_manager import cog_data_path
 from redbot.core.utils.predicates import MessagePredicate
 
-from chatter.trainers import TwitterCorpusTrainer
+from chatter.trainers import TwitterCorpusTrainer, UbuntuCorpusTrainer2
 
 log = logging.getLogger("red.fox_v3.chatter")
 
@@ -168,6 +168,10 @@ class Chatter(Cog):
         trainer.train()
         return True
 
+    async def _train_ubuntu2(self):
+        trainer = UbuntuCorpusTrainer2(self.chatbot, cog_data_path(self))
+        await trainer.asynctrain()
+
     def _train_english(self):
         trainer = ChatterBotCorpusTrainer(self.chatbot)
         # try:
@@ -353,6 +357,15 @@ class Chatter(Cog):
         await self.config.guild(ctx.guild).days.set(days)
         await ctx.tick()
 
+    @commands.is_owner()
+    @chatter.command(name="kaggle")
+    async def chatter_kaggle(self, ctx: commands.Context):
+        """Register with the kaggle API to download additional datasets for training"""
+        if not await self.check_for_kaggle():
+            await ctx.maybe_send_embed(
+                "[Click here for instructions to setup the kaggle api](https://github.com/Kaggle/kaggle-api#api-credentials)"
+            )
+
     @commands.is_owner()
     @chatter.command(name="backup")
     async def backup(self, ctx, backupname):
@@ -376,7 +389,13 @@ class Chatter(Cog):
             await ctx.maybe_send_embed("Error occurred :(")
 
     @commands.is_owner()
-    @chatter.command(name="trainubuntu")
+    @chatter.group(name="train")
+    async def chatter_train(self, ctx: commands.Context):
+        """Commands for training the bot"""
+        pass
+
+    @commands.is_owner()
+    @chatter_train.command(name="ubuntu")
     async def chatter_train_ubuntu(self, ctx: commands.Context, confirmation: bool = False):
         """
         WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
@@ -385,7 +404,7 @@ class Chatter(Cog):
         if not confirmation:
             await ctx.maybe_send_embed(
                 "Warning: This command downloads ~500MB then eats your CPU for training\n"
-                "If you're sure you want to continue, run `[p]chatter trainubuntu True`"
+                "If you're sure you want to continue, run `[p]chatter train ubuntu True`"
             )
             return
 
@@ -398,7 +417,29 @@ class Chatter(Cog):
             await ctx.send("Error occurred :(")
 
     @commands.is_owner()
-    @chatter.command(name="trainenglish")
+    @chatter_train.command(name="ubuntu2")
+    async def chatter_train_ubuntu2(self, ctx: commands.Context, confirmation: bool = False):
+        """
+        WARNING: Large Download! Trains the bot using *NEW* Ubuntu Dialog Corpus data.
+        """
+
+        if not confirmation:
+            await ctx.maybe_send_embed(
+                "Warning: This command downloads ~800MB then eats your CPU for training\n"
+                "If you're sure you want to continue, run `[p]chatter train ubuntu2 True`"
+            )
+            return
+
+        async with ctx.typing():
+            future = await self._train_ubuntu2()
+
+        if future:
+            await ctx.send("Training successful!")
+        else:
+            await ctx.send("Error occurred :(")
+
+    @commands.is_owner()
+    @chatter_train.command(name="english")
     async def chatter_train_english(self, ctx: commands.Context):
         """
         Trains the bot in english
@@ -412,10 +453,27 @@ class Chatter(Cog):
             await ctx.maybe_send_embed("Error occurred :(")
 
     @commands.is_owner()
-    @chatter.command()
-    async def train(self, ctx: commands.Context, channel: discord.TextChannel):
+    @chatter_train.command(name="list")
+    async def chatter_train_list(self, ctx: commands.Context):
+        """Trains the bot based on an uploaded list.
+
+        Must be a file in the format of a python list: ['prompt', 'response1', 'response2']
         """
-        Trains the bot based on language in this guild
+        if not ctx.message.attachments:
+            await ctx.maybe_send_embed("You must upload a file when using this command")
+            return
+
+        attachment: discord.Attachment = ctx.message.attachments[0]
+
+        a_bytes = await attachment.read()
+
+        await ctx.send("Not yet implemented")
+
+    @commands.is_owner()
+    @chatter_train.command(name="channel")
+    async def chatter_train_channel(self, ctx: commands.Context, channel: discord.TextChannel):
+        """
+        Trains the bot based on language in this guild.
         """
 
         await ctx.maybe_send_embed(
@@ -502,7 +560,7 @@ class Chatter(Cog):
         if self._last_message_per_channel[ctx.channel.id] is not None:
             last_m: discord.Message = self._last_message_per_channel[ctx.channel.id]
             minutes = self._guild_cache[ctx.guild.id]["convo_delta"]
-            if (datetime.utcnow() - last_m.created_at).seconds > minutes*60:
+            if (datetime.utcnow() - last_m.created_at).seconds > minutes * 60:
                 in_response_to = None
             else:
                 in_response_to = last_m.content
@@ -511,7 +569,7 @@ class Chatter(Cog):
 
         if in_response_to is None:
             log.debug("Generating response")
-            Statement = self.chatbot.storage.get_object('statement')
+            Statement = self.chatbot.storage.get_object("statement")
             future = await self.loop.run_in_executor(
                 None, self.chatbot.generate_response, Statement(text)
             )
@@ -525,3 +583,6 @@ class Chatter(Cog):
             self._last_message_per_channel[ctx.channel.id] = await ctx.send(str(future))
         else:
             await ctx.send(":thinking:")
+
+    async def check_for_kaggle(self):
+        return False
@@ -17,7 +17,8 @@
     "pytz",
     "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm",
     "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz#egg=en_core_web_md",
-    "spacy>=2.3,<2.4"
+    "spacy>=2.3,<2.4",
+    "kaggle"
   ],
   "short": "Local Chatbot run on machine learning",
   "end_user_data_statement": "This cog only stores anonymous conversations data; no End User Data is stored.",
@@ -1,6 +1,146 @@
+import asyncio
+import csv
 import logging
+import os
+import pathlib
+import time
+from functools import partial
 
+from chatterbot import utils
+from chatterbot.conversation import Statement
+from chatterbot.tagging import PosLemmaTagger
 from chatterbot.trainers import Trainer
+from redbot.core.bot import Red
+from dateutil import parser as date_parser
+from redbot.core.utils import AsyncIter
+
+log = logging.getLogger("red.fox_v3.chatter.trainers")
+
+
+class KaggleTrainer(Trainer):
+    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+        super().__init__(chatbot, **kwargs)
+
+        self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")
+
+        self.kaggle_dataset = kwargs.get(
+            "kaggle_dataset",
+            "Cornell-University/movie-dialog-corpus",
+        )
+
+        # Create the data directory if it does not already exist
+        if not os.path.exists(self.data_directory):
+            os.makedirs(self.data_directory)
+
+    def is_downloaded(self, file_path):
+        """
+        Check if the data file is already downloaded.
+        """
+        if os.path.exists(file_path):
+            self.chatbot.logger.info("File is already downloaded")
+            return True
+
+        return False
+
+    async def download(self, dataset):
+        import kaggle  # This triggers the API token check
+
+        future = await asyncio.get_event_loop().run_in_executor(
+            None,
+            partial(
+                kaggle.api.dataset_download_files,
+                dataset=dataset,
+                path=self.data_directory,
+                quiet=False,
+                unzip=True,
+            ),
+        )
+
+
+class UbuntuCorpusTrainer2(KaggleTrainer):
+    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+        super().__init__(
+            chatbot,
+            datapath,
+            downloadpath="ubuntu_data_v2",
+            kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
+            **kwargs
+        )
+
+    async def asynctrain(self, *args, **kwargs):
+        extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"
+
+        # Download and extract the Ubuntu dialog corpus if needed
+        if not extracted_dir.exists():
+            await self.download(self.kaggle_dataset)
+        else:
+            log.info("Ubuntu dialogue already downloaded")
+        if not extracted_dir.exists():
+            raise FileNotFoundError("Did not extract in the expected way")
+
+        train_dialogue = kwargs.get("train_dialogue", True)
+        train_196_dialogue = kwargs.get("train_196", False)
+        train_301_dialogue = kwargs.get("train_301", False)
+
+        if train_dialogue:
+            await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
+
+        if train_196_dialogue:
+            await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
+
+        if train_301_dialogue:
+            await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
+
+    async def run_dialogue_training(self, extracted_dir, dialogue_file):
+        log.info(f"Beginning dialogue training on {dialogue_file}")
+        start_time = time.time()
+
+        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
+
+        with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
+            reader = csv.DictReader(dg)
+
+            # csv.DictReader already consumes the header row, so no manual skip is needed
+
+            last_dialogue_id = None
+            previous_statement_text = None
+            previous_statement_search_text = ""
+            statements_from_file = []
+
+            async for row in AsyncIter(reader):
+                dialogue_id = row["dialogueID"]
+                if dialogue_id != last_dialogue_id:
+                    previous_statement_text = None
+                    previous_statement_search_text = ""
+                    last_dialogue_id = dialogue_id
+
+                if len(row) > 0:
+                    statement = Statement(
+                        text=row["text"],
+                        in_response_to=previous_statement_text,
+                        conversation="training",
+                        created_at=date_parser.parse(row["date"]),
+                        persona=row["from"],
+                    )
+
+                    for preprocessor in self.chatbot.preprocessors:
+                        statement = preprocessor(statement)
+
+                    statement.search_text = tagger.get_text_index_string(statement.text)
+                    statement.search_in_response_to = previous_statement_search_text
+
+                    previous_statement_text = statement.text
+                    previous_statement_search_text = statement.search_text
+
+                    statements_from_file.append(statement)
+
+            if statements_from_file:
+                self.chatbot.storage.create_many(statements_from_file)
+
+        print("Training took", time.time() - start_time, "seconds.")
+
+    def train(self, *args, **kwargs):
+        log.error("See asynctrain instead")
 
 
 class TwitterCorpusTrainer(Trainer):
@@ -46,4 +186,4 @@ class TwitterCorpusTrainer(Trainer):
         #
         # statements_to_create.append(statement)
         #
-        # self.chatbot.storage.create_many(statements_to_create)
+        # self.chatbot.storage.create_many(statements_to_create)