You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Fox-V3/chatter/trainers.py

344 lines
12 KiB

import asyncio
import csv
import html
import logging
import os
import pathlib
import time
from functools import partial
from chatterbot import utils
from chatterbot.conversation import Statement
from chatterbot.tagging import PosLemmaTagger
from chatterbot.trainers import Trainer
from redbot.core.bot import Red
from dateutil import parser as date_parser
from redbot.core.utils import AsyncIter
log = logging.getLogger("red.fox_v3.chatter.trainers")
class KaggleTrainer(Trainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(chatbot, **kwargs)
self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")
self.kaggle_dataset = kwargs.get(
"kaggle_dataset",
"Cornell-University/movie-dialog-corpus",
)
# Create the data directory if it does not already exist
if not os.path.exists(self.data_directory):
os.makedirs(self.data_directory)
def is_downloaded(self, file_path):
"""
Check if the data file is already downloaded.
"""
if os.path.exists(file_path):
self.chatbot.logger.info("File is already downloaded")
return True
return False
async def download(self, dataset):
import kaggle # This triggers the API token check
future = await asyncio.get_event_loop().run_in_executor(
None,
partial(
kaggle.api.dataset_download_files,
dataset=dataset,
path=self.data_directory,
quiet=False,
unzip=True,
),
)
def train(self, *args, **kwargs):
log.error("See asynctrain instead")
def asynctrain(self, *args, **kwargs):
raise self.TrainerInitializationException()
class SouthParkTrainer(KaggleTrainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(
chatbot,
datapath,
downloadpath="ubuntu_data_v2",
kaggle_dataset="tovarischsukhov/southparklines",
**kwargs,
)
class MovieTrainer(KaggleTrainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(
chatbot,
datapath,
downloadpath="kaggle_movies",
kaggle_dataset="Cornell-University/movie-dialog-corpus",
**kwargs,
)
async def run_movie_training(self):
dialogue_file = "movie_lines.tsv"
conversation_file = "movie_conversations.tsv"
log.info(f"Beginning dialogue training on {dialogue_file}")
start_time = time.time()
tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
# [lineID, characterID, movieID, character name, text of utterance]
# File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
conv_lines = conv_tsv.readlines()
with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
dialog_lines = lines_tsv.readlines()
# trans_dict = str.maketrans({"<u>": "__", "</u>": "__", '""': '"'})
lines_dict = {}
for line in dialog_lines:
_line = line[:-1].strip('"').split("\t")
if len(_line) >= 5: # Only good lines
lines_dict[_line[0]] = (
html.unescape(("".join(_line[4:])).strip())
.replace("<u>", "__")
.replace("</u>", "__")
.replace('""', '"')
)
else:
log.debug(f"Bad line {_line}")
# collecting line ids for each conversation
conv = []
for line in conv_lines[:-1]:
_line = line[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
conv.append(_line.split(","))
# conversations = csv.reader(conv_tsv, delimiter="\t")
#
# reader = csv.reader(lines_tsv, delimiter="\t")
#
#
#
# lines_dict = {}
# for row in reader:
# try:
# lines_dict[row[0].strip('"')] = row[4]
# except:
# log.exception(f"Bad line: {row}")
# pass
# else:
# # log.info(f"Good line: {row}")
# pass
#
# # lines_dict = {row[0].strip('"'): row[4] for row in reader_list}
statements_from_file = []
# [characterID of first, characterID of second, movieID, list of utterances]
async for lines in AsyncIter(conv):
previous_statement_text = None
previous_statement_search_text = ""
for line in lines:
text = lines_dict[line]
statement = Statement(
text=text,
in_response_to=previous_statement_text,
conversation="training",
)
for preprocessor in self.chatbot.preprocessors:
statement = preprocessor(statement)
statement.search_text = tagger.get_text_index_string(statement.text)
statement.search_in_response_to = previous_statement_search_text
previous_statement_text = statement.text
previous_statement_search_text = statement.search_text
statements_from_file.append(statement)
if statements_from_file:
self.chatbot.storage.create_many(statements_from_file)
statements_from_file = []
log.info("Training took", time.time() - start_time, "seconds.")
async def asynctrain(self, *args, **kwargs):
extracted_lines = self.data_directory / "movie_lines.tsv"
extracted_lines: pathlib.Path
# Download and extract the Ubuntu dialog corpus if needed
if not extracted_lines.exists():
await self.download(self.kaggle_dataset)
else:
log.info("Movie dialog already downloaded")
if not extracted_lines.exists():
raise FileNotFoundError(f"{extracted_lines}")
await self.run_movie_training()
return True
# train_dialogue = kwargs.get("train_dialogue", True)
# train_196_dialogue = kwargs.get("train_196", False)
# train_301_dialogue = kwargs.get("train_301", False)
#
# if train_dialogue:
# await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
#
# if train_196_dialogue:
# await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
#
# if train_301_dialogue:
# await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
class UbuntuCorpusTrainer2(KaggleTrainer):
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
super().__init__(
chatbot,
datapath,
downloadpath="kaggle_ubuntu",
kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
**kwargs,
)
async def asynctrain(self, *args, **kwargs):
extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"
# Download and extract the Ubuntu dialog corpus if needed
if not extracted_dir.exists():
await self.download(self.kaggle_dataset)
else:
log.info("Ubuntu dialogue already downloaded")
if not extracted_dir.exists():
raise FileNotFoundError("Did not extract in the expected way")
train_dialogue = kwargs.get("train_dialogue", True)
train_196_dialogue = kwargs.get("train_196", False)
train_301_dialogue = kwargs.get("train_301", False)
if train_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
if train_196_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
if train_301_dialogue:
await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
return True
async def run_dialogue_training(self, extracted_dir, dialogue_file):
log.info(f"Beginning dialogue training on {dialogue_file}")
start_time = time.time()
tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
reader = csv.DictReader(dg)
next(reader) # Skip the header
last_dialogue_id = None
previous_statement_text = None
previous_statement_search_text = ""
statements_from_file = []
save_every = 50
count = 0
async for row in AsyncIter(reader):
dialogue_id = row["dialogueID"]
if dialogue_id != last_dialogue_id:
previous_statement_text = None
previous_statement_search_text = ""
last_dialogue_id = dialogue_id
count += 1
if count >= save_every:
if statements_from_file:
self.chatbot.storage.create_many(statements_from_file)
statements_from_file = []
count = 0
if len(row) > 0:
statement = Statement(
text=row["text"],
in_response_to=previous_statement_text,
conversation="training",
# created_at=date_parser.parse(row["date"]),
persona=row["from"],
)
for preprocessor in self.chatbot.preprocessors:
statement = preprocessor(statement)
statement.search_text = tagger.get_text_index_string(statement.text)
statement.search_in_response_to = previous_statement_search_text
previous_statement_text = statement.text
previous_statement_search_text = statement.search_text
statements_from_file.append(statement)
if statements_from_file:
self.chatbot.storage.create_many(statements_from_file)
log.info("Training took", time.time() - start_time, "seconds.")
class TwitterCorpusTrainer(Trainer):
pass
# def train(self, *args, **kwargs):
# """
# Train the chat bot based on the provided list of
# statements that represents a single conversation.
# """
# import twint
#
# c = twint.Config()
# c.__dict__.update(kwargs)
# twint.run.Search(c)
#
#
# previous_statement_text = None
# previous_statement_search_text = ''
#
# statements_to_create = []
#
# for conversation_count, text in enumerate(conversation):
# if self.show_training_progress:
# utils.print_progress_bar(
# 'List Trainer',
# conversation_count + 1, len(conversation)
# )
#
# statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)
#
# statement = self.get_preprocessed_statement(
# Statement(
# text=text,
# search_text=statement_search_text,
# in_response_to=previous_statement_text,
# search_in_response_to=previous_statement_search_text,
# conversation='training'
# )
# )
#
# previous_statement_text = statement.text
# previous_statement_search_text = statement_search_text
#
# statements_to_create.append(statement)
#
# self.chatbot.storage.create_many(statements_to_create)