From 8feb21e34b70f26acf12c7d5af46e673032c9dc6 Mon Sep 17 00:00:00 2001
From: bobloy
Date: Thu, 25 Mar 2021 09:52:20 -0400
Subject: [PATCH] Add new kaggle trainers

---
NOTE(review): line breaks and indentation were rebuilt from a
whitespace-collapsed copy of this patch. Indentation of the context lines
(in particular the UbuntuCorpusTrainer2 hunks) and the placement of the
per-conversation flush inside the "async for" loop are inferred; confirm them
against the target file. Hunk ranges and the diffstat match the edited
additions below; the blob ids on the "index" line were not recomputed.

 chatter/trainers.py | 172 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 167 insertions(+), 5 deletions(-)

diff --git a/chatter/trainers.py b/chatter/trainers.py
index 1fe5f62..d8de22c 100644
--- a/chatter/trainers.py
+++ b/chatter/trainers.py
@@ -1,5 +1,6 @@
 import asyncio
 import csv
+import html
 import logging
 import os
 import pathlib
@@ -56,13 +57,174 @@ class KaggleTrainer(Trainer):
             ),
         )
 
+    def train(self, *args, **kwargs):
+        """Not supported for Kaggle trainers; logs an error pointing at asynctrain."""
+        log.error("See asynctrain instead")
 
-class UbuntuCorpusTrainer2(KaggleTrainer):
+    def asynctrain(self, *args, **kwargs):
+        """Placeholder for subclasses to override; raises TrainerInitializationException."""
+        raise self.TrainerInitializationException()
+
+
+class SouthParkTrainer(KaggleTrainer):
+    """KaggleTrainer preset for the tovarischsukhov/southparklines dataset."""
+
     def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
         super().__init__(
             chatbot,
             datapath,
-            downloadpath="ubuntu_data_v2",
+            downloadpath="kaggle_southpark",
+            kaggle_dataset="tovarischsukhov/southparklines",
+            **kwargs,
+        )
+
+
+class MovieTrainer(KaggleTrainer):
+    """KaggleTrainer for the Cornell-University/movie-dialog-corpus dataset."""
+
+    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+        super().__init__(
+            chatbot,
+            datapath,
+            downloadpath="kaggle_movies",
+            kaggle_dataset="Cornell-University/movie-dialog-corpus",
+            **kwargs,
+        )
+
+    async def run_movie_training(self):
+        """Read the movie TSV files and store every conversation as chained statements."""
+        dialogue_file = "movie_lines.tsv"
+        conversation_file = "movie_conversations.tsv"
+        log.info(f"Beginning dialogue training on {dialogue_file}")
+        start_time = time.time()
+
+        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
+
+        # [lineID, characterID, movieID, character name, text of utterance]
+        # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
+
+        with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
+            conv_lines = conv_tsv.readlines()
+        with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
+            dialog_lines = lines_tsv.readlines()
+
+        # NOTE(review): tag literals restored as <u>/</u> (were empty strings) - confirm
+        # trans_dict = str.maketrans({"<u>": "__", "</u>": "__", '""': '"'})
+
+        lines_dict = {}
+        for line in dialog_lines:
+            _line = line.rstrip("\n").strip('"').split("\t")
+            if len(_line) >= 5:  # Only good lines
+                lines_dict[_line[0]] = (
+                    html.unescape(("".join(_line[4:])).strip())
+                    .replace("<u>", "__")
+                    .replace("</u>", "__")
+                    .replace('""', '"')
+                )
+            else:
+                log.debug(f"Bad line {_line}")
+
+        # collecting line ids for each conversation
+        conv = []
+        for line in conv_lines:
+            _line = line.rstrip("\n").split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
+            conv.append(_line.split(","))
+
+        # conversations = csv.reader(conv_tsv, delimiter="\t")
+        #
+        # reader = csv.reader(lines_tsv, delimiter="\t")
+        #
+        #
+        #
+        # lines_dict = {}
+        # for row in reader:
+        #     try:
+        #         lines_dict[row[0].strip('"')] = row[4]
+        #     except:
+        #         log.exception(f"Bad line: {row}")
+        #         pass
+        #     else:
+        #         # print(f"Good line: {row}")
+        #         pass
+        #
+        # # lines_dict = {row[0].strip('"'): row[4] for row in reader_list}
+
+        statements_from_file = []
+
+        # [characterID of first, characterID of second, movieID, list of utterances]
+        async for lines in AsyncIter(conv):
+            previous_statement_text = None
+            previous_statement_search_text = ""
+
+            for line in lines:
+                text = lines_dict.get(line)
+                if text is None:
+                    # Id refers to a line rejected above (or is empty); restart the reply chain
+                    log.debug(f"Skipping unknown line id {line!r}")
+                    previous_statement_text = None
+                    previous_statement_search_text = ""
+                    continue
+                statement = Statement(
+                    text=text,
+                    in_response_to=previous_statement_text,
+                    conversation="training",
+                )
+
+                for preprocessor in self.chatbot.preprocessors:
+                    statement = preprocessor(statement)
+
+                statement.search_text = tagger.get_text_index_string(statement.text)
+                statement.search_in_response_to = previous_statement_search_text
+
+                previous_statement_text = statement.text
+                previous_statement_search_text = statement.search_text
+
+                statements_from_file.append(statement)
+
+            if statements_from_file:
+                log.debug(f"Saving {len(statements_from_file)} statements")
+                self.chatbot.storage.create_many(statements_from_file)
+                statements_from_file = []
+
+        print("Training took", time.time() - start_time, "seconds.")
+
+    async def asynctrain(self, *args, **kwargs):
+        """Download the corpus if movie_lines.tsv is absent, then train; returns True."""
+        extracted_lines = self.data_directory / "movie_lines.tsv"
+        extracted_lines: pathlib.Path
+
+        # Download and extract the movie dialog corpus if needed
+        if not extracted_lines.exists():
+            await self.download(self.kaggle_dataset)
+        else:
+            log.info("Movie dialog already downloaded")
+        if not extracted_lines.exists():
+            raise FileNotFoundError(f"{extracted_lines}")
+
+        await self.run_movie_training()
+
+        return True
+
+        # train_dialogue = kwargs.get("train_dialogue", True)
+        # train_196_dialogue = kwargs.get("train_196", False)
+        # train_301_dialogue = kwargs.get("train_301", False)
+        #
+        # if train_dialogue:
+        #     await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
+        #
+        # if train_196_dialogue:
+        #     await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
+        #
+        # if train_301_dialogue:
+        #     await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
+
+
+class UbuntuCorpusTrainer2(KaggleTrainer):
+    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+        super().__init__(
+            chatbot,
+            datapath,
+            downloadpath="kaggle_ubuntu",
             kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
             **kwargs,
         )
@@ -91,6 +253,8 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
         if train_301_dialogue:
             await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
 
+        return True
+
     async def run_dialogue_training(self, extracted_dir, dialogue_file):
         log.info(f"Beginning dialogue training on {dialogue_file}")
         start_time = time.time()
@@ -120,6 +284,7 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
                     if count >= save_every:
                         if statements_from_file:
                             self.chatbot.storage.create_many(statements_from_file)
+                            statements_from_file = []
                         count = 0
 
                 if len(row) > 0:
@@ -147,9 +312,6 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
 
         print("Training took", time.time() - start_time, "seconds.")
 
-    def train(self, *args, **kwargs):
-        log.error("See asynctrain instead")
-
 
 class TwitterCorpusTrainer(Trainer):
     pass