diff --git a/chatter/trainers.py b/chatter/trainers.py
index 1fe5f62..d8de22c 100644
--- a/chatter/trainers.py
+++ b/chatter/trainers.py
@@ -1,5 +1,6 @@
import asyncio
import csv
+import html
import logging
import os
import pathlib
@@ -56,13 +57,159 @@ class KaggleTrainer(Trainer):
),
)
+ def train(self, *args, **kwargs):
+ """Log an error directing callers to ``asynctrain``; no training is performed and None is returned."""
+ log.error("See asynctrain instead")
-class UbuntuCorpusTrainer2(KaggleTrainer):
+ def asynctrain(self, *args, **kwargs):
+ """Base implementation: always raises ``self.TrainerInitializationException``.
+
+ This is a plain ``def`` (``MovieTrainer`` overrides it with ``async def``), so the
+ exception is raised when the method is called, not when its result is awaited.
+ """
+ raise self.TrainerInitializationException()
+
+
+class SouthParkTrainer(KaggleTrainer):
+    """Kaggle trainer configured for the ``tovarischsukhov/southparklines`` dataset.
+
+    Only the download configuration is defined here; training behaviour is
+    whatever ``KaggleTrainer`` provides.
+    """
+
     def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
         super().__init__(
             chatbot,
             datapath,
-            downloadpath="ubuntu_data_v2",
+            # Dedicated directory following the "kaggle_*" naming used by the
+            # sibling trainers, instead of the Ubuntu corpus directory name
+            # this call was copied from.
+            downloadpath="kaggle_southpark",
+            kaggle_dataset="tovarischsukhov/southparklines",
+            **kwargs,
+        )
+
+
+class MovieTrainer(KaggleTrainer):
+    """Kaggle trainer for the ``Cornell-University/movie-dialog-corpus`` dataset.
+
+    Reads ``movie_lines.tsv`` and ``movie_conversations.tsv`` from
+    ``self.data_directory`` and stores each utterance as a ``Statement``
+    linked to the utterance that preceded it in the same conversation.
+    """
+
+    # File names looked up inside ``self.data_directory``.
+    _DIALOGUE_FILE = "movie_lines.tsv"
+    _CONVERSATION_FILE = "movie_conversations.tsv"
+
+    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+        """Configure the download directory and Kaggle dataset; ``kwargs`` are forwarded."""
+        super().__init__(
+            chatbot,
+            datapath,
+            downloadpath="kaggle_movies",
+            kaggle_dataset="Cornell-University/movie-dialog-corpus",
+            **kwargs,
+        )
+
+    @staticmethod
+    def _parse_dialogue_lines(dialog_lines):
+        """Map each line ID to its cleaned utterance text.
+
+        Rows are tab separated:
+        [lineID, characterID, movieID, character name, text of utterance].
+        Rows with fewer than five fields are logged at debug level and skipped.
+        """
+        lines_dict = {}
+        for line in dialog_lines:
+            # rstrip instead of line[:-1]: the final row may lack a newline.
+            fields = line.rstrip("\n").strip('"').split("\t")
+            if len(fields) < 5:  # Only good lines
+                log.debug(f"Bad line {fields}")
+                continue
+            text = html.unescape("".join(fields[4:]).strip())
+            # Underline tags become "__" markers and doubled quotes collapse.
+            # Replacing the empty string (as before) inserted "__" between
+            # every character of the utterance.
+            lines_dict[fields[0]] = (
+                text.replace("<u>", "__").replace("</u>", "__").replace('""', '"')
+            )
+        return lines_dict
+
+    @staticmethod
+    def _parse_conversations(conv_lines):
+        """Return one list of line IDs per non-empty conversation row.
+
+        Rows are tab separated:
+        [characterID of first, characterID of second, movieID, list of utterances].
+        The last field is a bracketed, quoted list of line IDs.
+        """
+        conversations = []
+        # Every row is processed; slicing off the last element dropped a real
+        # conversation because readlines() yields no trailing empty entry.
+        for line in conv_lines:
+            row = line.rstrip("\n")
+            if not row:
+                continue
+            id_field = row.split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
+            # Filter the empty entries that ", "-separated lists produce.
+            conversations.append([line_id for line_id in id_field.split(",") if line_id])
+        return conversations
+
+    async def run_movie_training(self):
+        """Parse both corpus files and store every utterance as a ``Statement``.
+
+        Within a conversation each statement records the previous utterance's
+        text as ``in_response_to``. Statements are written to storage once per
+        conversation. Line IDs with no matching utterance are skipped.
+        """
+        dialogue_file = self._DIALOGUE_FILE
+        conversation_file = self._CONVERSATION_FILE
+        log.info(f"Beginning dialogue training on {dialogue_file}")
+        start_time = time.time()
+
+        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
+
+        # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
+        with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
+            conv_lines = conv_tsv.readlines()
+        with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
+            dialog_lines = lines_tsv.readlines()
+
+        lines_dict = self._parse_dialogue_lines(dialog_lines)
+        conversations = self._parse_conversations(conv_lines)
+
+        async for line_ids in AsyncIter(conversations):
+            statements = []
+            previous_statement_text = None
+            previous_statement_search_text = ""
+
+            for line_id in line_ids:
+                text = lines_dict.get(line_id)
+                if text is None:
+                    # The utterance was rejected as a bad row or is absent.
+                    # Break the chain so the next statement is not stored as a
+                    # reply to an utterance it did not directly follow.
+                    log.debug(f"Missing line {line_id}")
+                    previous_statement_text = None
+                    previous_statement_search_text = ""
+                    continue
+
+                statement = Statement(
+                    text=text,
+                    in_response_to=previous_statement_text,
+                    conversation="training",
+                )
+
+                for preprocessor in self.chatbot.preprocessors:
+                    statement = preprocessor(statement)
+
+                statement.search_text = tagger.get_text_index_string(statement.text)
+                statement.search_in_response_to = previous_statement_search_text
+
+                previous_statement_text = statement.text
+                previous_statement_search_text = statement.search_text
+
+                statements.append(statement)
+
+            if statements:
+                self.chatbot.storage.create_many(statements)
+
+        print("Training took", time.time() - start_time, "seconds.")
+
+    async def asynctrain(self, *args, **kwargs):
+        """Download the corpus when ``movie_lines.tsv`` is missing, then train.
+
+        Returns:
+            True once ``run_movie_training`` has finished.
+
+        Raises:
+            FileNotFoundError: if ``movie_lines.tsv`` is still absent after the
+                download attempt.
+        """
+        extracted_lines = self.data_directory / self._DIALOGUE_FILE
+
+        # Download the movie dialog corpus if needed
+        if extracted_lines.exists():
+            log.info("Movie dialog already downloaded")
+        else:
+            await self.download(self.kaggle_dataset)
+            if not extracted_lines.exists():
+                raise FileNotFoundError(f"{extracted_lines}")
+
+        await self.run_movie_training()
+
+        return True
+
+
+class UbuntuCorpusTrainer2(KaggleTrainer):
+ def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+ """Configure the trainer with downloadpath "kaggle_ubuntu" and the rtatman/ubuntu-dialogue-corpus dataset; extra kwargs are forwarded to the parent initializer."""
+ super().__init__(
+ chatbot,
+ datapath,
+ downloadpath="kaggle_ubuntu",
kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
**kwargs,
)
@@ -91,6 +238,8 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
if train_301_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
+ return True
+
async def run_dialogue_training(self, extracted_dir, dialogue_file):
log.info(f"Beginning dialogue training on {dialogue_file}")
start_time = time.time()
@@ -120,6 +269,7 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
if count >= save_every:
if statements_from_file:
self.chatbot.storage.create_many(statements_from_file)
+ statements_from_file = []
count = 0
if len(row) > 0:
@@ -147,9 +297,6 @@ class UbuntuCorpusTrainer2(KaggleTrainer):
print("Training took", time.time() - start_time, "seconds.")
- def train(self, *args, **kwargs):
- log.error("See asynctrain instead")
-
class TwitterCorpusTrainer(Trainer):
pass