import asyncio import csv import html import logging import os import pathlib import time from functools import partial from chatterbot import utils from chatterbot.conversation import Statement from chatterbot.tagging import PosLemmaTagger from chatterbot.trainers import Trainer from redbot.core.bot import Red from dateutil import parser as date_parser from redbot.core.utils import AsyncIter log = logging.getLogger("red.fox_v3.chatter.trainers") class KaggleTrainer(Trainer): def __init__(self, chatbot, datapath: pathlib.Path, **kwargs): super().__init__(chatbot, **kwargs) self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download") self.kaggle_dataset = kwargs.get( "kaggle_dataset", "Cornell-University/movie-dialog-corpus", ) # Create the data directory if it does not already exist if not os.path.exists(self.data_directory): os.makedirs(self.data_directory) def is_downloaded(self, file_path): """ Check if the data file is already downloaded. """ if os.path.exists(file_path): self.chatbot.logger.info("File is already downloaded") return True return False async def download(self, dataset): import kaggle # This triggers the API token check future = await asyncio.get_event_loop().run_in_executor( None, partial( kaggle.api.dataset_download_files, dataset=dataset, path=self.data_directory, quiet=False, unzip=True, ), ) def train(self, *args, **kwargs): log.error("See asynctrain instead") def asynctrain(self, *args, **kwargs): raise self.TrainerInitializationException() class SouthParkTrainer(KaggleTrainer): def __init__(self, chatbot, datapath: pathlib.Path, **kwargs): super().__init__( chatbot, datapath, downloadpath="ubuntu_data_v2", kaggle_dataset="tovarischsukhov/southparklines", **kwargs, ) class MovieTrainer(KaggleTrainer): def __init__(self, chatbot, datapath: pathlib.Path, **kwargs): super().__init__( chatbot, datapath, downloadpath="kaggle_movies", kaggle_dataset="Cornell-University/movie-dialog-corpus", **kwargs, ) async def run_movie_training(self): dialogue_file = "movie_lines.tsv" conversation_file = "movie_conversations.tsv" log.info(f"Beginning dialogue training on {dialogue_file}") start_time = time.time() tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language) # [lineID, characterID, movieID, character name, text of utterance] # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv: conv_lines = conv_tsv.readlines() with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv: dialog_lines = lines_tsv.readlines() # trans_dict = str.maketrans({"": "__", "": "__", '""': '"'}) lines_dict = {} for line in dialog_lines: _line = line[:-1].strip('"').split("\t") if len(_line) >= 5: # Only good lines lines_dict[_line[0]] = ( html.unescape(("".join(_line[4:])).strip()) .replace("", "__") .replace("", "__") .replace('""', '"') ) else: log.debug(f"Bad line {_line}") # collecting line ids for each conversation conv = [] for line in conv_lines[:-1]: _line = line[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",") conv.append(_line.split(",")) # conversations = csv.reader(conv_tsv, delimiter="\t") # # reader = csv.reader(lines_tsv, delimiter="\t") # # # # lines_dict = {} # for row in reader: # try: # lines_dict[row[0].strip('"')] = row[4] # except: # log.exception(f"Bad line: {row}") # pass # else: # # log.info(f"Good line: {row}") # pass # # # lines_dict = {row[0].strip('"'): row[4] for row in reader_list} statements_from_file = [] save_every = 300 count = 0 # [characterID of first, characterID of second, movieID, list of utterances] async for lines in AsyncIter(conv): previous_statement_text = None previous_statement_search_text = "" for line in lines: text = lines_dict[line] statement = Statement( text=text, in_response_to=previous_statement_text, conversation="training", ) for preprocessor in self.chatbot.preprocessors: statement = preprocessor(statement) statement.search_text = tagger.get_text_index_string(statement.text) statement.search_in_response_to = previous_statement_search_text previous_statement_text = statement.text previous_statement_search_text = statement.search_text statements_from_file.append(statement) count += 1 if count >= save_every: if statements_from_file: self.chatbot.storage.create_many(statements_from_file) statements_from_file = [] count = 0 if statements_from_file: self.chatbot.storage.create_many(statements_from_file) log.info(f"Training took {time.time() - start_time} seconds.") async def asynctrain(self, *args, **kwargs): extracted_lines = self.data_directory / "movie_lines.tsv" extracted_lines: pathlib.Path # Download and extract the Ubuntu dialog corpus if needed if not extracted_lines.exists(): await self.download(self.kaggle_dataset) else: log.info("Movie dialog already downloaded") if not extracted_lines.exists(): raise FileNotFoundError(f"{extracted_lines}") await self.run_movie_training() return True # train_dialogue = kwargs.get("train_dialogue", True) # train_196_dialogue = kwargs.get("train_196", False) # train_301_dialogue = kwargs.get("train_301", False) # # if train_dialogue: # await self.run_dialogue_training(extracted_dir, "dialogueText.csv") # # if train_196_dialogue: # await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv") # # if train_301_dialogue: # await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv") class UbuntuCorpusTrainer2(KaggleTrainer): def __init__(self, chatbot, datapath: pathlib.Path, **kwargs): super().__init__( chatbot, datapath, downloadpath="kaggle_ubuntu", kaggle_dataset="rtatman/ubuntu-dialogue-corpus", **kwargs, ) async def asynctrain(self, *args, **kwargs): extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus" # Download and extract the Ubuntu dialog corpus if needed if not extracted_dir.exists(): await self.download(self.kaggle_dataset) else: log.info("Ubuntu dialogue already downloaded") if not extracted_dir.exists(): raise FileNotFoundError("Did not extract in the expected way") train_dialogue = kwargs.get("train_dialogue", True) train_196_dialogue = kwargs.get("train_196", False) train_301_dialogue = kwargs.get("train_301", False) if train_dialogue: await self.run_dialogue_training(extracted_dir, "dialogueText.csv") if train_196_dialogue: await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv") if train_301_dialogue: await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv") return True async def run_dialogue_training(self, extracted_dir, dialogue_file): log.info(f"Beginning dialogue training on {dialogue_file}") start_time = time.time() tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language) with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg: reader = csv.DictReader(dg) next(reader) # Skip the header last_dialogue_id = None previous_statement_text = None previous_statement_search_text = "" statements_from_file = [] save_every = 50 count = 0 async for row in AsyncIter(reader): dialogue_id = row["dialogueID"] if dialogue_id != last_dialogue_id: previous_statement_text = None previous_statement_search_text = "" last_dialogue_id = dialogue_id count += 1 if count >= save_every: if statements_from_file: self.chatbot.storage.create_many(statements_from_file) statements_from_file = [] count = 0 if len(row) > 0: statement = Statement( text=row["text"], in_response_to=previous_statement_text, conversation="training", # created_at=date_parser.parse(row["date"]), persona=row["from"], ) for preprocessor in self.chatbot.preprocessors: statement = preprocessor(statement) statement.search_text = tagger.get_text_index_string(statement.text) statement.search_in_response_to = previous_statement_search_text previous_statement_text = statement.text previous_statement_search_text = statement.search_text statements_from_file.append(statement) if statements_from_file: self.chatbot.storage.create_many(statements_from_file) log.info(f"Training took {time.time() - start_time} seconds.") class TwitterCorpusTrainer(Trainer): pass # def train(self, *args, **kwargs): # """ # Train the chat bot based on the provided list of # statements that represents a single conversation. # """ # import twint # # c = twint.Config() # c.__dict__.update(kwargs) # twint.run.Search(c) # # # previous_statement_text = None # previous_statement_search_text = '' # # statements_to_create = [] # # for conversation_count, text in enumerate(conversation): # if self.show_training_progress: # utils.print_progress_bar( # 'List Trainer', # conversation_count + 1, len(conversation) # ) # # statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text) # # statement = self.get_preprocessed_statement( # Statement( # text=text, # search_text=statement_search_text, # in_response_to=previous_statement_text, # search_in_response_to=previous_statement_search_text, # conversation='training' # ) # ) # # previous_statement_text = statement.text # previous_statement_search_text = statement_search_text # # statements_to_create.append(statement) # # self.chatbot.storage.create_many(statements_to_create)