Add new Kaggle trainers
This commit is contained in:
parent
04ccb435f8
commit
8feb21e34b
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import html
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
@ -56,13 +57,159 @@ class KaggleTrainer(Trainer):
|
||||
),
|
||||
)
|
||||
|
||||
def train(self, *args, **kwargs):
    """Synchronous entry point; it performs no training.

    All positional and keyword arguments are accepted and ignored. The only
    effect is an error-level log record pointing callers to ``asynctrain``.
    """
    log.error("See asynctrain instead")
    return None
|
||||
|
||||
def asynctrain(self, *args, **kwargs):
    """Base implementation of the asynchronous training entry point.

    Always raises ``self.TrainerInitializationException``; the arguments are
    accepted and ignored. Subclasses are presumably expected to override
    this with a real implementation.

    :raises TrainerInitializationException: unconditionally.
    """
    exception_type = self.TrainerInitializationException
    raise exception_type()
|
||||
|
||||
|
||||
class SouthParkTrainer(KaggleTrainer):
    """Kaggle trainer configured for the South Park lines dataset.

    Only the dataset configuration is defined here; no training routine is
    overridden in this class.
    """

    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        """Forward the dataset configuration to ``KaggleTrainer``.

        :param chatbot: chatbot instance forwarded to the base trainer.
        :param datapath: base data path forwarded to the base trainer.
        :param kwargs: extra options forwarded unchanged.
        """
        super().__init__(
            chatbot,
            datapath,
            # Dedicated folder name: the previous "ubuntu_data_v2" value was
            # copied from the Ubuntu corpus trainer and would have shared its
            # download location (assuming downloadpath names the dataset's
            # folder -- TODO confirm against KaggleTrainer).
            downloadpath="kaggle_southpark",
            kaggle_dataset="tovarischsukhov/southparklines",
            **kwargs,
        )
|
||||
|
||||
|
||||
class MovieTrainer(KaggleTrainer):
    """Kaggle trainer for the Cornell movie-dialog corpus.

    Reads ``movie_lines.tsv`` and ``movie_conversations.tsv`` from
    ``self.data_directory`` and stores every utterance as a ``Statement``
    that responds to the utterance preceding it in the same conversation.
    """

    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        """Forward the dataset configuration to ``KaggleTrainer``.

        :param chatbot: chatbot instance forwarded to the base trainer.
        :param datapath: base data path forwarded to the base trainer.
        :param kwargs: extra options forwarded unchanged.
        """
        super().__init__(
            chatbot,
            datapath,
            downloadpath="kaggle_movies",
            kaggle_dataset="Cornell-University/movie-dialog-corpus",
            **kwargs,
        )

    @staticmethod
    def _parse_dialogue_lines(dialog_lines):
        """Map each line ID to its cleaned utterance text.

        Rows are tab separated:
        [lineID, characterID, movieID, character name, text of utterance].
        Rows with fewer than five fields are logged at debug level and skipped.

        :param dialog_lines: raw rows of the dialogue file.
        :return: dict of line ID -> utterance text.
        """
        lines_dict = {}
        for raw_line in dialog_lines:
            # rstrip rather than [:-1]: a final row without a trailing newline
            # must not lose its last character.
            fields = raw_line.rstrip("\n").strip('"').split("\t")
            if len(fields) < 5:  # Only good lines
                log.debug(f"Bad line {fields}")
                continue
            # NOTE(review): utterance fields split on tabs are re-joined with
            # no separator, exactly as before -- confirm that is intended.
            lines_dict[fields[0]] = (
                html.unescape("".join(fields[4:]).strip())
                .replace("<u>", "__")
                .replace("</u>", "__")
                .replace('""', '"')
            )
        return lines_dict

    @staticmethod
    def _parse_conversations(conv_lines):
        """Collect the line IDs of each conversation row.

        Rows are tab separated:
        [characterID of first, characterID of second, movieID, list of utterances].
        The last column is presumably formatted like ``['L194' 'L195']``; its
        first and last characters are dropped, quotes are removed, and the
        remainder is split on spaces and commas.

        :param conv_lines: raw rows of the conversation file.
        :return: list with one list of line IDs per non-empty row.
        """
        conversations = []
        # Every row is parsed: readlines() yields no trailing empty element,
        # so slicing off the last entry would drop a real conversation.
        for raw_line in conv_lines:
            id_field = raw_line.rstrip("\n").split("\t")[-1][1:-1]
            line_ids = [
                line_id
                for line_id in id_field.replace("'", "").replace(" ", ",").split(",")
                if line_id
            ]
            if line_ids:
                conversations.append(line_ids)
        return conversations

    async def run_movie_training(self):
        """Build statements from the corpus files and save them to storage.

        Both TSV files are read from ``self.data_directory``. Each conversation
        becomes a chain of ``Statement`` objects, and all statements are saved
        with a single ``create_many`` call.

        :raises FileNotFoundError: if either corpus file cannot be opened.
        """
        dialogue_file = "movie_lines.tsv"
        conversation_file = "movie_conversations.tsv"
        log.info(f"Beginning dialogue training on {dialogue_file}")
        start_time = time.time()

        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)

        # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
        with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
            conv_lines = conv_tsv.readlines()
        with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
            dialog_lines = lines_tsv.readlines()

        lines_dict = self._parse_dialogue_lines(dialog_lines)
        conversations = self._parse_conversations(conv_lines)

        statements_from_file = []

        async for line_ids in AsyncIter(conversations):
            previous_statement_text = None
            previous_statement_search_text = ""

            for line_id in line_ids:
                text = lines_dict.get(line_id)
                if text is None:
                    # The dialogue row was malformed or absent. Skip it and
                    # restart the reply chain so the next utterance is not
                    # recorded as a response to the wrong line.
                    log.debug(f"Skipping unknown line ID {line_id}")
                    previous_statement_text = None
                    previous_statement_search_text = ""
                    continue

                statement = Statement(
                    text=text,
                    in_response_to=previous_statement_text,
                    conversation="training",
                )

                for preprocessor in self.chatbot.preprocessors:
                    statement = preprocessor(statement)

                statement.search_text = tagger.get_text_index_string(statement.text)
                statement.search_in_response_to = previous_statement_search_text

                previous_statement_text = statement.text
                previous_statement_search_text = statement.search_text

                statements_from_file.append(statement)

        if statements_from_file:
            # Log only the count; printing the list would dump the whole corpus.
            log.info(f"Saving {len(statements_from_file)} statements")
            self.chatbot.storage.create_many(statements_from_file)

        print("Training took", time.time() - start_time, "seconds.")

    async def asynctrain(self, *args, **kwargs):
        """Download the movie corpus if needed, then train on it.

        The positional and keyword arguments are accepted but not used.

        :return: ``True`` once training has finished.
        :raises FileNotFoundError: if ``movie_lines.tsv`` is still missing
            after the download attempt.
        """
        extracted_lines: pathlib.Path = self.data_directory / "movie_lines.tsv"

        # Download the movie dialog corpus if needed
        if not extracted_lines.exists():
            await self.download(self.kaggle_dataset)
        else:
            log.info("Movie dialog already downloaded")
        if not extracted_lines.exists():
            raise FileNotFoundError(f"{extracted_lines}")

        await self.run_movie_training()

        return True
|
||||
|
||||
|
||||
class UbuntuCorpusTrainer2(KaggleTrainer):
|
||||
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
    """Forward the Ubuntu dialogue corpus configuration to ``KaggleTrainer``.

    :param chatbot: chatbot instance forwarded to the base trainer.
    :param datapath: base data path forwarded to the base trainer.
    :param kwargs: extra options forwarded unchanged.
    """
    super().__init__(
        chatbot,
        datapath,
        # Passed exactly once: repeating the keyword (the superseded
        # "ubuntu_data_v2" value next to this one) is a SyntaxError.
        downloadpath="kaggle_ubuntu",
        kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
        **kwargs,
    )
|
||||
@ -91,6 +238,8 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
|
||||
if train_301_dialogue:
|
||||
await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
|
||||
|
||||
return True
|
||||
|
||||
async def run_dialogue_training(self, extracted_dir, dialogue_file):
|
||||
log.info(f"Beginning dialogue training on {dialogue_file}")
|
||||
start_time = time.time()
|
||||
@ -120,6 +269,7 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
|
||||
if count >= save_every:
|
||||
if statements_from_file:
|
||||
self.chatbot.storage.create_many(statements_from_file)
|
||||
statements_from_file = []
|
||||
count = 0
|
||||
|
||||
if len(row) > 0:
|
||||
@ -147,9 +297,6 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
|
||||
|
||||
print("Training took", time.time() - start_time, "seconds.")
|
||||
|
||||
def train(self, *args, **kwargs):
    """Synchronous entry point; it performs no training.

    All positional and keyword arguments are accepted and ignored. The only
    effect is an error-level log record pointing callers to ``asynctrain``.
    """
    log.error("See asynctrain instead")
    return None
|
||||
|
||||
|
||||
class TwitterCorpusTrainer(Trainer):
    """Placeholder trainer; it adds nothing to ``Trainer`` yet."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user