async def _train_ubuntu2(self):
    """Train the chatbot on the Kaggle-hosted Ubuntu dialogue corpus.

    Builds an ``UbuntuCorpusTrainer2`` rooted at this cog's data directory
    and awaits its ``asynctrain`` coroutine.

    Returns:
        bool: ``True`` when training ran to completion, ``False`` when it
        raised. The calling command branches on this value, so the result
        must be explicit: falling off the end would return ``None`` and a
        successful run would be reported as an error.
    """
    trainer = UbuntuCorpusTrainer2(self.chatbot, cog_data_path(self))
    try:
        await trainer.asynctrain()
    except Exception:
        # Command-level boundary: the download, the extraction and the CSV
        # parsing can each fail. Keep the traceback in the log and let the
        # caller tell the user that training did not succeed.
        log.exception("Ubuntu dialogue corpus (v2) training failed")
        return False
    return True
@commands.is_owner()
@chatter_train.command(name="ubuntu2")
async def chatter_train_ubuntu2(self, ctx: commands.Context, confirmation: bool = False):
    """
    WARNING: Large Download! Trains the bot using *NEW* Ubuntu Dialog Corpus data.
    """
    # The docstring above doubles as the command's help text, so it is kept
    # verbatim; implementation notes live in comments instead.

    if not confirmation:
        # Require an explicit opt-in before starting a large download and a
        # CPU-heavy training run. The size now carries its unit (MB), matching
        # the wording of the sibling `train ubuntu` warning.
        await ctx.maybe_send_embed(
            "Warning: This command downloads ~800MB then eats your CPU for training\n"
            "If you're sure you want to continue, run `[p]chatter train ubuntu2 True`"
        )
        return

    # Show the typing indicator for the whole duration of the training.
    async with ctx.typing():
        succeeded = await self._train_ubuntu2()

    # NOTE(review): this branch relies on `_train_ubuntu2` returning a truthy
    # value on success; a helper that returns None always lands in `else`.
    if succeeded:
        await ctx.send("Training successful!")
    else:
        await ctx.send("Error occurred :(")
async def check_for_kaggle(self):
    """Report whether Kaggle API credentials appear to be configured.

    Looks for credentials in the places described by the Kaggle API README
    (NOTE(review): confirm against the installed ``kaggle`` version):

    * the ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables, or
    * a ``kaggle.json`` file containing ``username`` and ``key`` entries,
      located in ``$KAGGLE_CONFIG_DIR`` when that variable is set, otherwise
      in ``~/.kaggle`` or ``$XDG_CONFIG_HOME/kaggle`` (``~/.config/kaggle``).

    Only the presence of credentials is checked; nothing is sent to Kaggle,
    so a ``True`` result does not prove the token is valid.

    Returns:
        bool: ``True`` if credentials were found, ``False`` otherwise.
    """
    # Imported locally so the check does not depend on the module's imports.
    import json
    import os

    if os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"):
        return True

    override_dir = os.environ.get("KAGGLE_CONFIG_DIR")
    if override_dir:
        # An explicit override is the only directory that is consulted.
        config_dirs = [override_dir]
    else:
        home = os.path.expanduser("~")
        xdg_home = os.environ.get("XDG_CONFIG_HOME") or os.path.join(home, ".config")
        config_dirs = [os.path.join(home, ".kaggle"), os.path.join(xdg_home, "kaggle")]

    for config_dir in config_dirs:
        try:
            with open(os.path.join(config_dir, "kaggle.json"), encoding="utf-8") as fp:
                credentials = json.load(fp)
        except (OSError, ValueError):
            # Missing, unreadable or malformed file: treat as "not configured".
            continue
        if (
            isinstance(credentials, dict)
            and credentials.get("username")
            and credentials.get("key")
        ):
            return True

    return False
"https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz#egg=en_core_web_md", - "spacy>=2.3,<2.4" + "spacy>=2.3,<2.4", + "kaggle" ], "short": "Local Chatbot run on machine learning", "end_user_data_statement": "This cog only stores anonymous conversations data; no End User Data is stored.", diff --git a/chatter/trainers.py b/chatter/trainers.py index 42d6288..0b765b7 100644 --- a/chatter/trainers.py +++ b/chatter/trainers.py @@ -1,6 +1,146 @@ +import asyncio +import csv +import logging +import os +import pathlib +import time +from functools import partial + from chatterbot import utils from chatterbot.conversation import Statement +from chatterbot.tagging import PosLemmaTagger from chatterbot.trainers import Trainer +from redbot.core.bot import Red +from dateutil import parser as date_parser +from redbot.core.utils import AsyncIter + +log = logging.getLogger("red.fox_v3.chatter.trainers") + + +class KaggleTrainer(Trainer): + def __init__(self, chatbot, datapath: pathlib.Path, **kwargs): + super().__init__(chatbot, **kwargs) + + self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download") + + self.kaggle_dataset = kwargs.get( + "kaggle_dataset", + "Cornell-University/movie-dialog-corpus", + ) + + # Create the data directory if it does not already exist + if not os.path.exists(self.data_directory): + os.makedirs(self.data_directory) + + def is_downloaded(self, file_path): + """ + Check if the data file is already downloaded. 
class UbuntuCorpusTrainer2(KaggleTrainer):
    """Trainer for the Ubuntu dialogue corpus distributed through Kaggle.

    The dataset ``rtatman/ubuntu-dialogue-corpus`` is downloaded into the
    ``ubuntu_data_v2`` sub-directory of ``datapath`` and its dialogue CSV
    files are converted into chatbot statements.

    Training is asynchronous: use :meth:`asynctrain`. The synchronous
    :meth:`train` only logs an error.
    """

    # Statements are written to storage in chunks of this size so a large
    # CSV file never has to be held in memory as one list.
    _BATCH_SIZE = 10_000

    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        super().__init__(
            chatbot,
            datapath,
            downloadpath="ubuntu_data_v2",
            kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
            **kwargs
        )

    async def asynctrain(self, *args, **kwargs):
        """Download the corpus if needed, then train on the selected files.

        Keyword Args:
            train_dialogue (bool): train on ``dialogueText.csv`` (default True).
            train_196 (bool): train on ``dialogueText_196.csv`` (default False).
            train_301 (bool): train on ``dialogueText_301.csv`` (default False).

        Raises:
            FileNotFoundError: if the download finished but the expected
                ``Ubuntu-dialogue-corpus`` directory does not exist.
        """
        extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"

        if extracted_dir.exists():
            log.info("Ubuntu dialogue already downloaded")
        else:
            await self.download(self.kaggle_dataset)
            # The archive is expected to unpack into a fixed directory name.
            if not extracted_dir.exists():
                raise FileNotFoundError("Did not extract in the expected way")

        selections = (
            ("dialogueText.csv", kwargs.get("train_dialogue", True)),
            ("dialogueText_196.csv", kwargs.get("train_196", False)),
            ("dialogueText_301.csv", kwargs.get("train_301", False)),
        )
        for dialogue_file, enabled in selections:
            if enabled:
                await self.run_dialogue_training(extracted_dir, dialogue_file)

    async def run_dialogue_training(self, extracted_dir, dialogue_file):
        """Create chatbot statements from one dialogue CSV file.

        Reads the ``dialogueID``, ``text``, ``date`` and ``from`` columns.
        Consecutive rows sharing a ``dialogueID`` are chained: each statement
        is stored as a response to the previous statement of that dialogue.

        Args:
            extracted_dir: directory containing the extracted CSV files.
            dialogue_file: file name of the CSV to train on.
        """
        log.info(f"Beginning dialogue training on {dialogue_file}")
        start_time = time.time()

        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)

        last_dialogue_id = None
        previous_statement_text = None
        previous_statement_search_text = ""
        pending = []

        # newline="" is what the csv module documents for files it reads, so
        # newlines embedded in quoted fields are parsed correctly.
        with open(extracted_dir / dialogue_file, "r", encoding="utf-8", newline="") as dg:
            # DictReader consumes the header row itself to obtain the field
            # names; skipping a row by hand would discard the first dialogue
            # line (and raise on a header-only file).
            reader = csv.DictReader(dg)

            # NOTE(review): AsyncIter presumably gives the event loop a turn
            # between rows; confirm against the Red utilities documentation.
            async for row in AsyncIter(reader):
                dialogue_id = row["dialogueID"]
                if dialogue_id != last_dialogue_id:
                    # A new conversation starts: break the response chain.
                    previous_statement_text = None
                    previous_statement_search_text = ""
                    last_dialogue_id = dialogue_id

                # DictReader fills columns missing from a short row with None;
                # such rows (and rows without text or date) cannot become a
                # statement and would otherwise crash the date parser.
                if not row["text"] or not row["date"]:
                    continue

                statement = Statement(
                    text=row["text"],
                    in_response_to=previous_statement_text,
                    conversation="training",
                    created_at=date_parser.parse(row["date"]),
                    persona=row["from"],
                )

                for preprocessor in self.chatbot.preprocessors:
                    statement = preprocessor(statement)

                statement.search_text = tagger.get_text_index_string(statement.text)
                statement.search_in_response_to = previous_statement_search_text

                previous_statement_text = statement.text
                previous_statement_search_text = statement.search_text

                pending.append(statement)
                if len(pending) >= self._BATCH_SIZE:
                    self.chatbot.storage.create_many(pending)
                    pending = []

        # Flush whatever is left after the last full batch.
        if pending:
            self.chatbot.storage.create_many(pending)

        log.info(f"Training on {dialogue_file} took {time.time() - start_time} seconds.")

    def train(self, *args, **kwargs):
        """Synchronous entry point; not supported, use :meth:`asynctrain`."""
        log.error("See asynctrain instead")