Some progress on updated ubuntu trainer

pull/175/head
bobloy 4 years ago
parent 14f8b825d8
commit 337def2fa3

@ -17,7 +17,7 @@ from redbot.core.commands import Cog
from redbot.core.data_manager import cog_data_path
from redbot.core.utils.predicates import MessagePredicate
from chatter.trainers import TwitterCorpusTrainer
from chatter.trainers import TwitterCorpusTrainer, UbuntuCorpusTrainer2
log = logging.getLogger("red.fox_v3.chatter")
@ -168,6 +168,10 @@ class Chatter(Cog):
trainer.train()
return True
async def _train_ubuntu2(self):
trainer = UbuntuCorpusTrainer2(self.chatbot, cog_data_path(self))
await trainer.asynctrain()
def _train_english(self):
trainer = ChatterBotCorpusTrainer(self.chatbot)
# try:
@ -353,6 +357,15 @@ class Chatter(Cog):
await self.config.guild(ctx.guild).days.set(days)
await ctx.tick()
@commands.is_owner()
@chatter.command(name="kaggle")
async def chatter_kaggle(self, ctx: commands.Context):
"""Register with the kaggle API to download additional datasets for training"""
if not await self.check_for_kaggle():
await ctx.maybe_send_embed(
"[Click here for instructions to setup the kaggle api](https://github.com/Kaggle/kaggle-api#api-credentials)"
)
@commands.is_owner()
@chatter.command(name="backup")
async def backup(self, ctx, backupname):
@ -376,7 +389,13 @@ class Chatter(Cog):
await ctx.maybe_send_embed("Error occurred :(")
@commands.is_owner()
@chatter.command(name="trainubuntu")
@chatter.group(name="train")
async def chatter_train(self, ctx: commands.Context):
"""Commands for training the bot"""
pass
@commands.is_owner()
@chatter_train.command(name="ubuntu")
async def chatter_train_ubuntu(self, ctx: commands.Context, confirmation: bool = False):
"""
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
@ -385,7 +404,7 @@ class Chatter(Cog):
if not confirmation:
await ctx.maybe_send_embed(
"Warning: This command downloads ~500MB then eats your CPU for training\n"
"If you're sure you want to continue, run `[p]chatter trainubuntu True`"
"If you're sure you want to continue, run `[p]chatter train ubuntu True`"
)
return
@ -398,7 +417,29 @@ class Chatter(Cog):
await ctx.send("Error occurred :(")
@commands.is_owner()
@chatter.command(name="trainenglish")
@chatter_train.command(name="ubuntu2")
async def chatter_train_ubuntu2(self, ctx: commands.Context, confirmation: bool = False):
"""
WARNING: Large Download! Trains the bot using *NEW* Ubuntu Dialog Corpus data.
"""
if not confirmation:
await ctx.maybe_send_embed(
"Warning: This command downloads ~800 then eats your CPU for training\n"
"If you're sure you want to continue, run `[p]chatter train ubuntu2 True`"
)
return
async with ctx.typing():
future = await self._train_ubuntu2()
if future:
await ctx.send("Training successful!")
else:
await ctx.send("Error occurred :(")
@commands.is_owner()
@chatter_train.command(name="english")
async def chatter_train_english(self, ctx: commands.Context):
"""
Trains the bot in english
@ -412,10 +453,27 @@ class Chatter(Cog):
await ctx.maybe_send_embed("Error occurred :(")
@commands.is_owner()
@chatter.command()
async def train(self, ctx: commands.Context, channel: discord.TextChannel):
@chatter_train.command(name="list")
async def chatter_train_list(self, ctx: commands.Context):
"""Trains the bot based on an uploaded list.
Must be a file in the format of a python list: ['prompt', 'response1', 'response2']
"""
if not ctx.message.attachments:
await ctx.maybe_send_embed("You must upload a file when using this command")
return
attachment: discord.Attachment = ctx.message.attachments[0]
a_bytes = await attachment.read()
await ctx.send("Not yet implemented")
@commands.is_owner()
@chatter_train.command(name="channel")
async def chatter_train_channel(self, ctx: commands.Context, channel: discord.TextChannel):
"""
Trains the bot based on language in this guild
Trains the bot based on language in this guild.
"""
await ctx.maybe_send_embed(
@ -502,7 +560,7 @@ class Chatter(Cog):
if self._last_message_per_channel[ctx.channel.id] is not None:
last_m: discord.Message = self._last_message_per_channel[ctx.channel.id]
minutes = self._guild_cache[ctx.guild.id]["convo_delta"]
if (datetime.utcnow() - last_m.created_at).seconds > minutes*60:
if (datetime.utcnow() - last_m.created_at).seconds > minutes * 60:
in_response_to = None
else:
in_response_to = last_m.content
@ -511,7 +569,7 @@ class Chatter(Cog):
if in_response_to is None:
log.debug("Generating response")
Statement = self.chatbot.storage.get_object('statement')
Statement = self.chatbot.storage.get_object("statement")
future = await self.loop.run_in_executor(
None, self.chatbot.generate_response, Statement(text)
)
@ -525,3 +583,6 @@ class Chatter(Cog):
self._last_message_per_channel[ctx.channel.id] = await ctx.send(str(future))
else:
await ctx.send(":thinking:")
async def check_for_kaggle(self):
return False

@ -17,7 +17,8 @@
"pytz",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz#egg=en_core_web_md",
"spacy>=2.3,<2.4"
"spacy>=2.3,<2.4",
"kaggle"
],
"short": "Local Chatbot run on machine learning",
"end_user_data_statement": "This cog only stores anonymous conversations data; no End User Data is stored.",

@ -1,6 +1,146 @@
import asyncio
import csv
import logging
import os
import pathlib
import time
from functools import partial
from chatterbot import utils
from chatterbot.conversation import Statement
from chatterbot.tagging import PosLemmaTagger
from chatterbot.trainers import Trainer
from redbot.core.bot import Red
from dateutil import parser as date_parser
from redbot.core.utils import AsyncIter
log = logging.getLogger("red.fox_v3.chatter.trainers")
class KaggleTrainer(Trainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(chatbot, **kwargs)
self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")
self.kaggle_dataset = kwargs.get(
"kaggle_dataset",
"Cornell-University/movie-dialog-corpus",
)
# Create the data directory if it does not already exist
if not os.path.exists(self.data_directory):
os.makedirs(self.data_directory)
def is_downloaded(self, file_path):
"""
Check if the data file is already downloaded.
"""
if os.path.exists(file_path):
self.chatbot.logger.info("File is already downloaded")
return True
return False
async def download(self, dataset):
import kaggle # This triggers the API token check
future = await asyncio.get_event_loop().run_in_executor(
None,
partial(
kaggle.api.dataset_download_files,
dataset=dataset,
path=self.data_directory,
quiet=False,
unzip=True,
),
)
class UbuntuCorpusTrainer2(KaggleTrainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(
chatbot,
datapath,
downloadpath="ubuntu_data_v2",
kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
**kwargs
)
async def asynctrain(self, *args, **kwargs):
extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"
# Download and extract the Ubuntu dialog corpus if needed
if not extracted_dir.exists():
await self.download(self.kaggle_dataset)
else:
log.info("Ubuntu dialogue already downloaded")
if not extracted_dir.exists():
raise FileNotFoundError("Did not extract in the expected way")
train_dialogue = kwargs.get("train_dialogue", True)
train_196_dialogue = kwargs.get("train_196", False)
train_301_dialogue = kwargs.get("train_301", False)
if train_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
if train_196_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
if train_301_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
async def run_dialogue_training(self, extracted_dir, dialogue_file):
log.info(f"Beginning dialogue training on {dialogue_file}")
start_time = time.time()
tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
reader = csv.DictReader(dg)
next(reader) # Skip the header
last_dialogue_id = None
previous_statement_text = None
previous_statement_search_text = ""
statements_from_file = []
async for row in AsyncIter(reader):
dialogue_id = row["dialogueID"]
if dialogue_id != last_dialogue_id:
previous_statement_text = None
previous_statement_search_text = ""
last_dialogue_id = dialogue_id
if len(row) > 0:
statement = Statement(
text=row["text"],
in_response_to=previous_statement_text,
conversation="training",
created_at=date_parser.parse(row["date"]),
persona=row["from"],
)
for preprocessor in self.chatbot.preprocessors:
statement = preprocessor(statement)
statement.search_text = tagger.get_text_index_string(statement.text)
statement.search_in_response_to = previous_statement_search_text
previous_statement_text = statement.text
previous_statement_search_text = statement.search_text
statements_from_file.append(statement)
if statements_from_file:
self.chatbot.storage.create_many(statements_from_file)
print("Training took", time.time() - start_time, "seconds.")
def train(self, *args, **kwargs):
log.error("See asynctrain instead")
class TwitterCorpusTrainer(Trainer):

Loading…
Cancel
Save