commit bd45f353a8
@@ -0,0 +1,73 @@
from chatterbot.storage import StorageAdapter, SQLStorageAdapter


class MyDumbSQLStorageAdapter(SQLStorageAdapter):
    def __init__(self, **kwargs):
        # Intentionally bypass SQLStorageAdapter.__init__ (hence the explicit
        # super(SQLStorageAdapter, self) call) so the engine can be created
        # below with custom connect_args.
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        self.database_uri = kwargs.get("database_uri", False)

        # None results in a sqlite in-memory database as the default
        if self.database_uri is None:
            self.database_uri = "sqlite://"

        # Create a file database if the database is not a connection string
        if not self.database_uri:
            self.database_uri = "sqlite:///db.sqlite3"

        self.engine = create_engine(
            self.database_uri, convert_unicode=True, connect_args={"check_same_thread": False}
        )

        if self.database_uri.startswith("sqlite://"):
            from sqlalchemy.engine import Engine
            from sqlalchemy import event

            @event.listens_for(Engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                # Write-ahead logging improves concurrent access for sqlite
                dbapi_connection.execute("PRAGMA journal_mode=WAL")
                dbapi_connection.execute("PRAGMA synchronous=NORMAL")

        if not self.engine.dialect.has_table(self.engine, "Statement"):
            self.create_database()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)


class AsyncSQLStorageAdapter(SQLStorageAdapter):
    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        self.database_uri = kwargs.get("database_uri", False)

        # None results in a sqlite in-memory database as the default
        if self.database_uri is None:
            self.database_uri = "sqlite://"

        # Create a file database if the database is not a connection string
        if not self.database_uri:
            self.database_uri = "sqlite:///db.sqlite3"

    async def initialize(self):
        # from sqlalchemy import create_engine
        from aiomysql.sa import create_engine
        from sqlalchemy.orm import sessionmaker

        self.engine = await create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith("sqlite://"):
            from sqlalchemy.engine import Engine
            from sqlalchemy import event

            @event.listens_for(Engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute("PRAGMA journal_mode=WAL")
                dbapi_connection.execute("PRAGMA synchronous=NORMAL")

        if not self.engine.dialect.has_table(self.engine, "Statement"):
            self.create_database()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
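A minimal sketch of the two-phase lifecycle above: the async adapter is constructed synchronously, and all engine work is deferred until initialize() is awaited. The URI form and standalone setup here are assumptions for illustration (aiomysql ultimately expects MySQL connection details), not part of the commit:

import asyncio

async def main():
    # Assumed URI form; illustrative only
    adapter = AsyncSQLStorageAdapter(database_uri="mysql://user:pass@localhost/chatter")
    await adapter.initialize()  # engine and session factory exist only after this

asyncio.run(main())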
@@ -0,0 +1,351 @@
import asyncio
import csv
import html
import logging
import os
import pathlib
import time
from functools import partial

from chatterbot import utils
from chatterbot.conversation import Statement
from chatterbot.tagging import PosLemmaTagger
from chatterbot.trainers import Trainer
from redbot.core.bot import Red
from dateutil import parser as date_parser
from redbot.core.utils import AsyncIter

log = logging.getLogger("red.fox_v3.chatter.trainers")


class KaggleTrainer(Trainer):
    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        super().__init__(chatbot, **kwargs)

        self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")

        self.kaggle_dataset = kwargs.get(
            "kaggle_dataset",
            "Cornell-University/movie-dialog-corpus",
        )

        # Create the data directory if it does not already exist
        if not os.path.exists(self.data_directory):
            os.makedirs(self.data_directory)

    def is_downloaded(self, file_path):
        """
        Check if the data file is already downloaded.
        """
        if os.path.exists(file_path):
            self.chatbot.logger.info("File is already downloaded")
            return True

        return False

    async def download(self, dataset):
        import kaggle  # This triggers the API token check

        # The Kaggle client is blocking, so run the download in the default
        # executor to keep the event loop responsive.
        await asyncio.get_event_loop().run_in_executor(
            None,
            partial(
                kaggle.api.dataset_download_files,
                dataset=dataset,
                path=self.data_directory,
                quiet=False,
                unzip=True,
            ),
        )

    def train(self, *args, **kwargs):
        log.error("See asynctrain instead")

    def asynctrain(self, *args, **kwargs):
        raise self.TrainerInitializationException()
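The download above follows the standard asyncio pattern for blocking clients: bind the arguments with functools.partial and hand the callable to the loop's default thread pool. A self-contained sketch of the same pattern (slow_fetch is a hypothetical stand-in for the Kaggle client, not part of this commit):

import asyncio
import time
from functools import partial

def slow_fetch(dataset, path):
    # Hypothetical stand-in for a blocking download call
    time.sleep(1)
    return f"{path}/{dataset}"

async def fetch(dataset):
    # run_in_executor schedules the call on a worker thread; awaiting it
    # yields slow_fetch's return value without blocking the event loop.
    return await asyncio.get_event_loop().run_in_executor(
        None, partial(slow_fetch, dataset, path="downloads")
    )

print(asyncio.run(fetch("example/dataset")))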
class SouthParkTrainer(KaggleTrainer):
|
||||||
|
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
chatbot,
|
||||||
|
datapath,
|
||||||
|
downloadpath="ubuntu_data_v2",
|
||||||
|
kaggle_dataset="tovarischsukhov/southparklines",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MovieTrainer(KaggleTrainer):
    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        super().__init__(
            chatbot,
            datapath,
            downloadpath="kaggle_movies",
            kaggle_dataset="Cornell-University/movie-dialog-corpus",
            **kwargs,
        )

    async def run_movie_training(self):
        dialogue_file = "movie_lines.tsv"
        conversation_file = "movie_conversations.tsv"
        log.info(f"Beginning dialogue training on {dialogue_file}")
        start_time = time.time()

        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)

        # movie_lines.tsv: [lineID, characterID, movieID, character name, text of utterance]
        # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot

        with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
            conv_lines = conv_tsv.readlines()
        with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
            dialog_lines = lines_tsv.readlines()

        # trans_dict = str.maketrans({"<u>": "__", "</u>": "__", '""': '"'})

        # Map each lineID to its cleaned utterance text
        lines_dict = {}
        for line in dialog_lines:
            _line = line[:-1].strip('"').split("\t")
            if len(_line) >= 5:  # Only well-formed lines
                lines_dict[_line[0]] = (
                    html.unescape(("".join(_line[4:])).strip())
                    .replace("<u>", "__")
                    .replace("</u>", "__")
                    .replace('""', '"')
                )
            else:
                log.debug(f"Bad line {_line}")

        # Collecting line ids for each conversation
        conv = []
        for line in conv_lines[:-1]:
            _line = line[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
            conv.append(_line.split(","))

        # conversations = csv.reader(conv_tsv, delimiter="\t")
        #
        # reader = csv.reader(lines_tsv, delimiter="\t")
        #
        #
        #
        # lines_dict = {}
        # for row in reader:
        #     try:
        #         lines_dict[row[0].strip('"')] = row[4]
        #     except:
        #         log.exception(f"Bad line: {row}")
        #         pass
        #     else:
        #         # log.info(f"Good line: {row}")
        #         pass
        #
        # # lines_dict = {row[0].strip('"'): row[4] for row in reader_list}

        statements_from_file = []
        save_every = 300
        count = 0

        # movie_conversations.tsv: [characterID of first, characterID of second, movieID, list of utterances]
        async for lines in AsyncIter(conv):
            previous_statement_text = None
            previous_statement_search_text = ""

            for line in lines:
                text = lines_dict[line]
                statement = Statement(
                    text=text,
                    in_response_to=previous_statement_text,
                    conversation="training",
                )

                for preprocessor in self.chatbot.preprocessors:
                    statement = preprocessor(statement)

                statement.search_text = tagger.get_text_index_string(statement.text)
                statement.search_in_response_to = previous_statement_search_text

                previous_statement_text = statement.text
                previous_statement_search_text = statement.search_text

                statements_from_file.append(statement)

            # Flush to storage every `save_every` conversations
            count += 1
            if count >= save_every:
                if statements_from_file:
                    self.chatbot.storage.create_many(statements_from_file)
                    statements_from_file = []
                count = 0

        if statements_from_file:
            self.chatbot.storage.create_many(statements_from_file)

        log.info(f"Training took {time.time() - start_time} seconds.")

    async def asynctrain(self, *args, **kwargs):
        extracted_lines: pathlib.Path = self.data_directory / "movie_lines.tsv"

        # Download and extract the movie dialog corpus if needed
        if not extracted_lines.exists():
            await self.download(self.kaggle_dataset)
        else:
            log.info("Movie dialog already downloaded")
        if not extracted_lines.exists():
            raise FileNotFoundError(f"{extracted_lines}")

        await self.run_movie_training()

        return True

        # train_dialogue = kwargs.get("train_dialogue", True)
        # train_196_dialogue = kwargs.get("train_196", False)
        # train_301_dialogue = kwargs.get("train_301", False)
        #
        # if train_dialogue:
        #     await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
        #
        # if train_196_dialogue:
        #     await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
        #
        # if train_301_dialogue:
        #     await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")

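To make the parsing above concrete, here is an invented pair of rows in the two documented TSV layouts and what the code reduces them to (sample values only, not taken from the corpus):

import html

dialog_row = "L1045\tu0\tm0\tBIANCA\tThey do not!\n"
conv_row = "u0\tu2\tm0\t['L194' 'L195' 'L196']\n"

# movie_lines.tsv row -> lineID mapped to cleaned utterance text
_line = dialog_row[:-1].strip('"').split("\t")
text = (
    html.unescape("".join(_line[4:]).strip())
    .replace("<u>", "__")
    .replace("</u>", "__")
    .replace('""', '"')
)
print(_line[0], "->", text)  # L1045 -> They do not!

# movie_conversations.tsv row -> ordered list of lineIDs
ids = conv_row[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",").split(",")
print(ids)  # ['L194', 'L195', 'L196']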
class UbuntuCorpusTrainer2(KaggleTrainer):
    def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
        super().__init__(
            chatbot,
            datapath,
            downloadpath="kaggle_ubuntu",
            kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
            **kwargs,
        )

    async def asynctrain(self, *args, **kwargs):
        extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"

        # Download and extract the Ubuntu dialog corpus if needed
        if not extracted_dir.exists():
            await self.download(self.kaggle_dataset)
        else:
            log.info("Ubuntu dialogue already downloaded")
        if not extracted_dir.exists():
            raise FileNotFoundError("Did not extract in the expected way")

        train_dialogue = kwargs.get("train_dialogue", True)
        train_196_dialogue = kwargs.get("train_196", False)
        train_301_dialogue = kwargs.get("train_301", False)

        if train_dialogue:
            await self.run_dialogue_training(extracted_dir, "dialogueText.csv")

        if train_196_dialogue:
            await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")

        if train_301_dialogue:
            await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")

        return True

    async def run_dialogue_training(self, extracted_dir, dialogue_file):
        log.info(f"Beginning dialogue training on {dialogue_file}")
        start_time = time.time()

        tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)

        with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
            reader = csv.DictReader(dg)
            # DictReader already consumes the header row; calling next() here
            # would silently drop the first statement of the file.

            last_dialogue_id = None
            previous_statement_text = None
            previous_statement_search_text = ""
            statements_from_file = []

            save_every = 50
            count = 0

            async for row in AsyncIter(reader):
                dialogue_id = row["dialogueID"]
                if dialogue_id != last_dialogue_id:
                    previous_statement_text = None
                    previous_statement_search_text = ""
                    last_dialogue_id = dialogue_id
                    # Flush to storage every `save_every` dialogues
                    count += 1
                    if count >= save_every:
                        if statements_from_file:
                            self.chatbot.storage.create_many(statements_from_file)
                            statements_from_file = []
                        count = 0

                if len(row) > 0:
                    statement = Statement(
                        text=row["text"],
                        in_response_to=previous_statement_text,
                        conversation="training",
                        # created_at=date_parser.parse(row["date"]),
                        persona=row["from"],
                    )

                    for preprocessor in self.chatbot.preprocessors:
                        statement = preprocessor(statement)

                    statement.search_text = tagger.get_text_index_string(statement.text)
                    statement.search_in_response_to = previous_statement_search_text

                    previous_statement_text = statement.text
                    previous_statement_search_text = statement.search_text

                    statements_from_file.append(statement)

        if statements_from_file:
            self.chatbot.storage.create_many(statements_from_file)

        log.info(f"Training took {time.time() - start_time} seconds.")

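A sketch of how one of these trainers might be driven from async code; the chatbot object and the data path are placeholders for what the chatter cog normally supplies:

import pathlib

async def train_ubuntu(chatbot):
    trainer = UbuntuCorpusTrainer2(chatbot, pathlib.Path("data/chatter"))  # placeholder path
    # Downloads the Kaggle dataset on the first run, then streams statements
    # into storage in batches; the extra corpora are opt-in via kwargs.
    return await trainer.asynctrain(train_196=False, train_301=False)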
class TwitterCorpusTrainer(Trainer):
    pass

    # def train(self, *args, **kwargs):
    #     """
    #     Train the chat bot based on the provided list of
    #     statements that represents a single conversation.
    #     """
    #     import twint
    #
    #     c = twint.Config()
    #     c.__dict__.update(kwargs)
    #     twint.run.Search(c)
    #
    #
    #     previous_statement_text = None
    #     previous_statement_search_text = ''
    #
    #     statements_to_create = []
    #
    #     for conversation_count, text in enumerate(conversation):
    #         if self.show_training_progress:
    #             utils.print_progress_bar(
    #                 'List Trainer',
    #                 conversation_count + 1, len(conversation)
    #             )
    #
    #         statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)
    #
    #         statement = self.get_preprocessed_statement(
    #             Statement(
    #                 text=text,
    #                 search_text=statement_search_text,
    #                 in_response_to=previous_statement_text,
    #                 search_in_response_to=previous_statement_search_text,
    #                 conversation='training'
    #             )
    #         )
    #
    #         previous_statement_text = statement.text
    #         previous_statement_search_text = statement_search_text
    #
    #         statements_to_create.append(statement)
    #
    #     self.chatbot.storage.create_many(statements_to_create)