diff --git a/chatter/README.md b/chatter/README.md
index d92ad2b..06331b2 100644
--- a/chatter/README.md
+++ b/chatter/README.md
@@ -59,6 +59,35 @@ Install these on your windows machine before attempting the installation:
[Pandoc - Universal Document Converter](https://pandoc.org/installing.html)
## Methods
+### Automatic
+
+This method requires some luck to pull off.
+
+#### Step 1: Add repo and install cog
+
+```
+[p]repo add Fox https://github.com/bobloy/Fox-V3
+[p]cog install Fox chatter
+```
+
+If you get an error at this step, stop and skip to one of the manual methods below.
+
+#### Step 2: Install additional dependencies
+
+Assuming the previous commands had no error, you can now use `pipinstall` to add the remaining dependencies.
+
+NOTE: This method is not the intended use case for `pipinstall` and may stop working in the future.
+
+```
+[p]pipinstall --no-deps chatterbot>=1.1
+```
+
+#### Step 3: Load the cog and get started
+
+```
+[p]load chatter
+```
+
### Windows - Manually
#### Step 1: Built-in Downloader
diff --git a/chatter/chat.py b/chatter/chat.py
index de0e20a..d999d94 100644
--- a/chatter/chat.py
+++ b/chatter/chat.py
@@ -2,8 +2,10 @@ import asyncio
import logging
import os
import pathlib
+from collections import defaultdict
from datetime import datetime, timedelta
-from typing import Optional
+from functools import partial
+from typing import Dict, Optional
import discord
from chatterbot import ChatBot
@@ -15,6 +17,8 @@ from redbot.core.commands import Cog
from redbot.core.data_manager import cog_data_path
from redbot.core.utils.predicates import MessagePredicate
+from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2
+
log = logging.getLogger("red.fox_v3.chatter")
@@ -59,6 +63,7 @@ class Chatter(Cog):
"convo_delta": 15,
"chatchannel": None,
"reply": True,
+ "learning": True,
}
path: pathlib.Path = cog_data_path(self)
self.data_path = path / "database.sqlite3"
@@ -79,6 +84,10 @@ class Chatter(Cog):
self.loop = asyncio.get_event_loop()
+ self._guild_cache = defaultdict(dict)
+
+ self._last_message_per_channel: Dict[Optional[discord.Message]] = defaultdict(lambda: None)
+
async def red_delete_data_for_user(self, **kwargs):
"""Nothing to delete"""
return
@@ -87,7 +96,8 @@ class Chatter(Cog):
return ChatBot(
"ChatterBot",
- storage_adapter="chatterbot.storage.SQLStorageAdapter",
+ # storage_adapter="chatterbot.storage.SQLStorageAdapter",
+ storage_adapter="chatter.storage_adapters.MyDumbSQLStorageAdapter",
database_uri="sqlite:///" + str(self.data_path),
statement_comparison_function=self.similarity_algo,
response_selection_method=get_random_response,
@@ -111,15 +121,7 @@ class Chatter(Cog):
return msg.clean_content
def new_conversation(msg, sent, out_in, delta):
- # if sent is None:
- # return False
-
- # Don't do "too short" processing here. Sometimes people don't respond.
- # if len(out_in) < 2:
- # return False
-
- # print(msg.created_at - sent)
-
+ # Should always be positive numbers
return msg.created_at - sent >= delta
for channel in ctx.guild.text_channels:
@@ -164,6 +166,11 @@ class Chatter(Cog):
return out
+ def _train_twitter(self, *args, **kwargs):
+ trainer = TwitterCorpusTrainer(self.chatbot)
+ trainer.train(*args, **kwargs)
+ return True
+
def _train_ubuntu(self):
trainer = UbuntuCorpusTrainer(
self.chatbot, ubuntu_corpus_data_directory=cog_data_path(self) / "ubuntu_data"
@@ -171,6 +178,30 @@ class Chatter(Cog):
trainer.train()
return True
+ async def _train_movies(self):
+ trainer = MovieTrainer(self.chatbot, cog_data_path(self))
+ return await trainer.asynctrain()
+
+ async def _train_ubuntu2(self, intensity):
+ train_kwarg = {}
+ if intensity == 196:
+ train_kwarg["train_dialogue"] = False
+ train_kwarg["train_196"] = True
+ elif intensity == 301:
+ train_kwarg["train_dialogue"] = False
+ train_kwarg["train_301"] = True
+ elif intensity == 497:
+ train_kwarg["train_dialogue"] = False
+ train_kwarg["train_196"] = True
+ train_kwarg["train_301"] = True
+ elif intensity >= 9000: # NOT 9000!
+ train_kwarg["train_dialogue"] = True
+ train_kwarg["train_196"] = True
+ train_kwarg["train_301"] = True
+
+ trainer = UbuntuCorpusTrainer2(self.chatbot, cog_data_path(self))
+ return await trainer.asynctrain(**train_kwarg)
+
def _train_english(self):
trainer = ChatterBotCorpusTrainer(self.chatbot)
# try:
@@ -196,9 +227,9 @@ class Chatter(Cog):
"""
Base command for this cog. Check help for the commands list.
"""
- pass
+ self._guild_cache[ctx.guild.id] = {} # Clear cache when modifying values
- @checks.admin()
+ @commands.admin()
@chatter.command(name="channel")
async def chatter_channel(
self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None
@@ -218,7 +249,7 @@ class Chatter(Cog):
await self.config.guild(ctx.guild).chatchannel.set(channel.id)
await ctx.maybe_send_embed(f"Chat channel is now {channel.mention}")
- @checks.admin()
+ @commands.admin()
@chatter.command(name="reply")
async def chatter_reply(self, ctx: commands.Context, toggle: Optional[bool] = None):
"""
@@ -231,19 +262,41 @@ class Chatter(Cog):
await self.config.guild(ctx.guild).reply.set(toggle)
if toggle:
- await ctx.send("I will now respond to you if conversation continuity is not present")
+ await ctx.maybe_send_embed(
+ "I will now respond to you if conversation continuity is not present"
+ )
else:
- await ctx.send(
+ await ctx.maybe_send_embed(
"I will not reply to your message if conversation continuity is not present, anymore"
)
- @checks.is_owner()
+ @commands.admin()
+ @chatter.command(name="learning")
+ async def chatter_learning(self, ctx: commands.Context, toggle: Optional[bool] = None):
+ """
+ Toggle the bot learning from its conversations.
+
+ This is on by default.
+ """
+ learning = await self.config.guild(ctx.guild).learning()
+ if toggle is None:
+ toggle = not learning
+ await self.config.guild(ctx.guild).learning.set(toggle)
+
+ if toggle:
+ await ctx.maybe_send_embed("I will now learn from conversations.")
+ else:
+ await ctx.maybe_send_embed("I will no longer learn from conversations.")
+
+ @commands.is_owner()
@chatter.command(name="cleardata")
async def chatter_cleardata(self, ctx: commands.Context, confirm: bool = False):
"""
- This command will erase all training data and reset your configuration settings
+ This command will erase all training data and reset your configuration settings.
- Use `[p]chatter cleardata True`
+ This applies to all guilds.
+
+ Use `[p]chatter cleardata True` to confirm.
"""
if not confirm:
@@ -270,7 +323,7 @@ class Chatter(Cog):
await ctx.tick()
- @checks.is_owner()
+ @commands.is_owner()
@chatter.command(name="algorithm", aliases=["algo"])
async def chatter_algorithm(
self, ctx: commands.Context, algo_number: int, threshold: float = None
@@ -304,7 +357,7 @@ class Chatter(Cog):
await ctx.tick()
- @checks.is_owner()
+ @commands.is_owner()
@chatter.command(name="model")
async def chatter_model(self, ctx: commands.Context, model_number: int):
"""
@@ -342,7 +395,7 @@ class Chatter(Cog):
f"Model has been switched to {self.tagger_language.ISO_639_1}"
)
- @checks.is_owner()
+ @commands.is_owner()
@chatter.command(name="minutes")
async def minutes(self, ctx: commands.Context, minutes: int):
"""
@@ -354,11 +407,11 @@ class Chatter(Cog):
await ctx.send_help()
return
- await self.config.guild(ctx.guild).convo_length.set(minutes)
+ await self.config.guild(ctx.guild).convo_delta.set(minutes)
await ctx.tick()
- @checks.is_owner()
+ @commands.is_owner()
@chatter.command(name="age")
async def age(self, ctx: commands.Context, days: int):
"""
@@ -373,7 +426,16 @@ class Chatter(Cog):
await self.config.guild(ctx.guild).days.set(days)
await ctx.tick()
- @checks.is_owner()
+ @commands.is_owner()
+ @chatter.command(name="kaggle")
+ async def chatter_kaggle(self, ctx: commands.Context):
+ """Register with the kaggle API to download additional datasets for training"""
+ if not await self.check_for_kaggle():
+ await ctx.maybe_send_embed(
+ "[Click here for instructions to setup the kaggle api](https://github.com/Kaggle/kaggle-api#api-credentials)"
+ )
+
+ @commands.is_owner()
@chatter.command(name="backup")
async def backup(self, ctx, backupname):
"""
@@ -395,8 +457,71 @@ class Chatter(Cog):
else:
await ctx.maybe_send_embed("Error occurred :(")
- @checks.is_owner()
- @chatter.command(name="trainubuntu")
+ @commands.is_owner()
+ @chatter.group(name="train")
+ async def chatter_train(self, ctx: commands.Context):
+ """Commands for training the bot"""
+ pass
+
+ @chatter_train.group(name="kaggle")
+ async def chatter_train_kaggle(self, ctx: commands.Context):
+ """
+ Base command for kaggle training sets.
+
+ See `[p]chatter kaggle` for details on how to enable this option
+ """
+ pass
+
+ @chatter_train_kaggle.command(name="ubuntu")
+ async def chatter_train_kaggle_ubuntu(
+ self, ctx: commands.Context, confirmation: bool = False, intensity=0
+ ):
+ """
+ WARNING: Large Download! Trains the bot using *NEW* Ubuntu Dialog Corpus data.
+ """
+
+ if not confirmation:
+ await ctx.maybe_send_embed(
+ "Warning: This command downloads ~800MB and is CPU intensive during training\n"
+ "If you're sure you want to continue, run `[p]chatter train kaggle ubuntu True`"
+ )
+ return
+
+ async with ctx.typing():
+ future = await self._train_ubuntu2(intensity)
+
+ if future:
+ await ctx.maybe_send_embed("Training successful!")
+ else:
+ await ctx.maybe_send_embed("Error occurred :(")
+
+ @chatter_train_kaggle.command(name="movies")
+ async def chatter_train_kaggle_movies(self, ctx: commands.Context, confirmation: bool = False):
+ """
+ WARNING: Language! Trains the bot using Cornell University's "Movie Dialog Corpus".
+
+ This training set contains dialog from a spread of movies with different MPAA.
+ This dialog includes racism, sexism, and any number of sensitive topics.
+
+ Use at your own risk.
+ """
+
+ if not confirmation:
+ await ctx.maybe_send_embed(
+ "Warning: This command downloads ~29MB and is CPU intensive during training\n"
+ "If you're sure you want to continue, run `[p]chatter train kaggle movies True`"
+ )
+ return
+
+ async with ctx.typing():
+ future = await self._train_movies()
+
+ if future:
+ await ctx.maybe_send_embed("Training successful!")
+ else:
+ await ctx.maybe_send_embed("Error occurred :(")
+
+ @chatter_train.command(name="ubuntu")
async def chatter_train_ubuntu(self, ctx: commands.Context, confirmation: bool = False):
"""
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
@@ -404,8 +529,8 @@ class Chatter(Cog):
if not confirmation:
await ctx.maybe_send_embed(
- "Warning: This command downloads ~500MB then eats your CPU for training\n"
- "If you're sure you want to continue, run `[p]chatter trainubuntu True`"
+ "Warning: This command downloads ~500MB and is CPU intensive during training\n"
+ "If you're sure you want to continue, run `[p]chatter train ubuntu True`"
)
return
@@ -413,12 +538,11 @@ class Chatter(Cog):
future = await self.loop.run_in_executor(None, self._train_ubuntu)
if future:
- await ctx.send("Training successful!")
+ await ctx.maybe_send_embed("Training successful!")
else:
- await ctx.send("Error occurred :(")
+ await ctx.maybe_send_embed("Error occurred :(")
- @checks.is_owner()
- @chatter.command(name="trainenglish")
+ @chatter_train.command(name="english")
async def chatter_train_english(self, ctx: commands.Context):
"""
Trains the bot in english
@@ -431,11 +555,26 @@ class Chatter(Cog):
else:
await ctx.maybe_send_embed("Error occurred :(")
- @checks.is_owner()
- @chatter.command()
- async def train(self, ctx: commands.Context, channel: discord.TextChannel):
+ @chatter_train.command(name="list")
+ async def chatter_train_list(self, ctx: commands.Context):
+ """Trains the bot based on an uploaded list.
+
+ Must be a file in the format of a python list: ['prompt', 'response1', 'response2']
+ """
+ if not ctx.message.attachments:
+ await ctx.maybe_send_embed("You must upload a file when using this command")
+ return
+
+ attachment: discord.Attachment = ctx.message.attachments[0]
+
+ a_bytes = await attachment.read()
+
+ await ctx.send("Not yet implemented")
+
+ @chatter_train.command(name="channel")
+ async def chatter_train_channel(self, ctx: commands.Context, channel: discord.TextChannel):
"""
- Trains the bot based on language in this guild
+ Trains the bot based on language in this guild.
"""
await ctx.maybe_send_embed(
@@ -499,15 +638,18 @@ class Chatter(Cog):
# Thank you Cog-Creators
channel: discord.TextChannel = message.channel
- # is_reply = False # this is only useful with in_response_to
+ if not self._guild_cache[guild.id]:
+ self._guild_cache[guild.id] = await self.config.guild(guild).all()
+
+ is_reply = False # this is only useful with in_response_to
if (
message.reference is not None
and isinstance(message.reference.resolved, discord.Message)
and message.reference.resolved.author.id == self.bot.user.id
):
- # is_reply = True # this is only useful with in_response_to
+ is_reply = True # this is only useful with in_response_to
pass # this is a reply to the bot, good to go
- elif guild is not None and channel.id == await self.config.guild(guild).chatchannel():
+ elif guild is not None and channel.id == self._guild_cache[guild.id]["chatchannel"]:
pass # good to go
else:
when_mentionables = commands.when_mentioned(self.bot, message)
@@ -522,15 +664,52 @@ class Chatter(Cog):
text = message.clean_content
- async with channel.typing():
- future = await self.loop.run_in_executor(None, self.chatbot.get_response, text)
+ async with ctx.typing():
+
+ if is_reply:
+ in_response_to = message.reference.resolved.content
+ elif self._last_message_per_channel[ctx.channel.id] is not None:
+ last_m: discord.Message = self._last_message_per_channel[ctx.channel.id]
+ minutes = self._guild_cache[ctx.guild.id]["convo_delta"]
+ if (datetime.utcnow() - last_m.created_at).seconds > minutes * 60:
+ in_response_to = None
+ else:
+ in_response_to = last_m.content
+ else:
+ in_response_to = None
+
+ # Always use generate reponse
+ # Chatterbot tries to learn based on the result it comes up with, which is dumb
+ log.debug("Generating response")
+ Statement = self.chatbot.storage.get_object("statement")
+ future = await self.loop.run_in_executor(
+ None, self.chatbot.generate_response, Statement(text)
+ )
+
+ if in_response_to is not None and self._guild_cache[guild.id]["learning"]:
+ log.debug("learning response")
+ await self.loop.run_in_executor(
+ None,
+ partial(
+ self.chatbot.learn_response,
+ Statement(text),
+ previous_statement=in_response_to,
+ ),
+ )
replying = None
- if await self.config.guild(guild).reply():
+ if self._guild_cache[guild.id]["reply"]:
if message != ctx.channel.last_message:
replying = message
if future and str(future):
- await channel.send(str(future), reference=replying)
+ self._last_message_per_channel[ctx.channel.id] = await channel.send(
+ str(future), reference=replying
+ )
else:
- await channel.send(":thinking:")
+ await ctx.send(":thinking:")
+
+ async def check_for_kaggle(self):
+ """Check whether Kaggle is installed and configured properly"""
+ # TODO: This
+ return False
diff --git a/chatter/info.json b/chatter/info.json
index a3fe0da..fc31e7c 100644
--- a/chatter/info.json
+++ b/chatter/info.json
@@ -17,7 +17,8 @@
"pytz",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz#egg=en_core_web_md",
- "spacy>=2.3,<2.4"
+ "spacy>=2.3,<2.4",
+ "kaggle"
],
"short": "Local Chatbot run on machine learning",
"end_user_data_statement": "This cog only stores anonymous conversations data; no End User Data is stored.",
diff --git a/chatter/storage_adapters.py b/chatter/storage_adapters.py
new file mode 100644
index 0000000..4de2f00
--- /dev/null
+++ b/chatter/storage_adapters.py
@@ -0,0 +1,73 @@
+from chatterbot.storage import StorageAdapter, SQLStorageAdapter
+
+
+class MyDumbSQLStorageAdapter(SQLStorageAdapter):
+ def __init__(self, **kwargs):
+ super(SQLStorageAdapter, self).__init__(**kwargs)
+
+ from sqlalchemy import create_engine
+ from sqlalchemy.orm import sessionmaker
+
+ self.database_uri = kwargs.get("database_uri", False)
+
+ # None results in a sqlite in-memory database as the default
+ if self.database_uri is None:
+ self.database_uri = "sqlite://"
+
+ # Create a file database if the database is not a connection string
+ if not self.database_uri:
+ self.database_uri = "sqlite:///db.sqlite3"
+
+ self.engine = create_engine(
+ self.database_uri, convert_unicode=True, connect_args={"check_same_thread": False}
+ )
+
+ if self.database_uri.startswith("sqlite://"):
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "connect")
+ def set_sqlite_pragma(dbapi_connection, connection_record):
+ dbapi_connection.execute("PRAGMA journal_mode=WAL")
+ dbapi_connection.execute("PRAGMA synchronous=NORMAL")
+
+ if not self.engine.dialect.has_table(self.engine, "Statement"):
+ self.create_database()
+
+ self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
+
+
+class AsyncSQLStorageAdapter(SQLStorageAdapter):
+ def __init__(self, **kwargs):
+ super(SQLStorageAdapter, self).__init__(**kwargs)
+
+ self.database_uri = kwargs.get("database_uri", False)
+
+ # None results in a sqlite in-memory database as the default
+ if self.database_uri is None:
+ self.database_uri = "sqlite://"
+
+ # Create a file database if the database is not a connection string
+ if not self.database_uri:
+ self.database_uri = "sqlite:///db.sqlite3"
+
+ async def initialize(self):
+ # from sqlalchemy import create_engine
+ from aiomysql.sa import create_engine
+ from sqlalchemy.orm import sessionmaker
+
+ self.engine = await create_engine(self.database_uri, convert_unicode=True)
+
+ if self.database_uri.startswith("sqlite://"):
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "connect")
+ def set_sqlite_pragma(dbapi_connection, connection_record):
+ dbapi_connection.execute("PRAGMA journal_mode=WAL")
+ dbapi_connection.execute("PRAGMA synchronous=NORMAL")
+
+ if not self.engine.dialect.has_table(self.engine, "Statement"):
+ self.create_database()
+
+ self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
diff --git a/chatter/trainers.py b/chatter/trainers.py
new file mode 100644
index 0000000..3cc92da
--- /dev/null
+++ b/chatter/trainers.py
@@ -0,0 +1,351 @@
+import asyncio
+import csv
+import html
+import logging
+import os
+import pathlib
+import time
+from functools import partial
+
+from chatterbot import utils
+from chatterbot.conversation import Statement
+from chatterbot.tagging import PosLemmaTagger
+from chatterbot.trainers import Trainer
+from redbot.core.bot import Red
+from dateutil import parser as date_parser
+from redbot.core.utils import AsyncIter
+
+log = logging.getLogger("red.fox_v3.chatter.trainers")
+
+
+class KaggleTrainer(Trainer):
+ def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+ super().__init__(chatbot, **kwargs)
+
+ self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")
+
+ self.kaggle_dataset = kwargs.get(
+ "kaggle_dataset",
+ "Cornell-University/movie-dialog-corpus",
+ )
+
+ # Create the data directory if it does not already exist
+ if not os.path.exists(self.data_directory):
+ os.makedirs(self.data_directory)
+
+ def is_downloaded(self, file_path):
+ """
+ Check if the data file is already downloaded.
+ """
+ if os.path.exists(file_path):
+ self.chatbot.logger.info("File is already downloaded")
+ return True
+
+ return False
+
+ async def download(self, dataset):
+ import kaggle # This triggers the API token check
+
+ future = await asyncio.get_event_loop().run_in_executor(
+ None,
+ partial(
+ kaggle.api.dataset_download_files,
+ dataset=dataset,
+ path=self.data_directory,
+ quiet=False,
+ unzip=True,
+ ),
+ )
+
+ def train(self, *args, **kwargs):
+ log.error("See asynctrain instead")
+
+ def asynctrain(self, *args, **kwargs):
+ raise self.TrainerInitializationException()
+
+
+class SouthParkTrainer(KaggleTrainer):
+ def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+ super().__init__(
+ chatbot,
+ datapath,
+ downloadpath="ubuntu_data_v2",
+ kaggle_dataset="tovarischsukhov/southparklines",
+ **kwargs,
+ )
+
+
+class MovieTrainer(KaggleTrainer):
+ def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+ super().__init__(
+ chatbot,
+ datapath,
+ downloadpath="kaggle_movies",
+ kaggle_dataset="Cornell-University/movie-dialog-corpus",
+ **kwargs,
+ )
+
+ async def run_movie_training(self):
+ dialogue_file = "movie_lines.tsv"
+ conversation_file = "movie_conversations.tsv"
+ log.info(f"Beginning dialogue training on {dialogue_file}")
+ start_time = time.time()
+
+ tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
+
+ # [lineID, characterID, movieID, character name, text of utterance]
+ # File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
+
+ with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
+ conv_lines = conv_tsv.readlines()
+ with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
+ dialog_lines = lines_tsv.readlines()
+
+ # trans_dict = str.maketrans({"": "__", "": "__", '""': '"'})
+
+ lines_dict = {}
+ for line in dialog_lines:
+ _line = line[:-1].strip('"').split("\t")
+ if len(_line) >= 5: # Only good lines
+ lines_dict[_line[0]] = (
+ html.unescape(("".join(_line[4:])).strip())
+ .replace("", "__")
+ .replace("", "__")
+ .replace('""', '"')
+ )
+ else:
+ log.debug(f"Bad line {_line}")
+
+ # collecting line ids for each conversation
+ conv = []
+ for line in conv_lines[:-1]:
+ _line = line[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
+ conv.append(_line.split(","))
+
+ # conversations = csv.reader(conv_tsv, delimiter="\t")
+ #
+ # reader = csv.reader(lines_tsv, delimiter="\t")
+ #
+ #
+ #
+ # lines_dict = {}
+ # for row in reader:
+ # try:
+ # lines_dict[row[0].strip('"')] = row[4]
+ # except:
+ # log.exception(f"Bad line: {row}")
+ # pass
+ # else:
+ # # log.info(f"Good line: {row}")
+ # pass
+ #
+ # # lines_dict = {row[0].strip('"'): row[4] for row in reader_list}
+
+ statements_from_file = []
+ save_every = 300
+ count = 0
+
+ # [characterID of first, characterID of second, movieID, list of utterances]
+ async for lines in AsyncIter(conv):
+ previous_statement_text = None
+ previous_statement_search_text = ""
+
+ for line in lines:
+ text = lines_dict[line]
+ statement = Statement(
+ text=text,
+ in_response_to=previous_statement_text,
+ conversation="training",
+ )
+
+ for preprocessor in self.chatbot.preprocessors:
+ statement = preprocessor(statement)
+
+ statement.search_text = tagger.get_text_index_string(statement.text)
+ statement.search_in_response_to = previous_statement_search_text
+
+ previous_statement_text = statement.text
+ previous_statement_search_text = statement.search_text
+
+ statements_from_file.append(statement)
+
+ count += 1
+ if count >= save_every:
+ if statements_from_file:
+ self.chatbot.storage.create_many(statements_from_file)
+ statements_from_file = []
+ count = 0
+
+ if statements_from_file:
+ self.chatbot.storage.create_many(statements_from_file)
+
+ log.info(f"Training took {time.time() - start_time} seconds.")
+
+ async def asynctrain(self, *args, **kwargs):
+ extracted_lines = self.data_directory / "movie_lines.tsv"
+ extracted_lines: pathlib.Path
+
+ # Download and extract the Ubuntu dialog corpus if needed
+ if not extracted_lines.exists():
+ await self.download(self.kaggle_dataset)
+ else:
+ log.info("Movie dialog already downloaded")
+ if not extracted_lines.exists():
+ raise FileNotFoundError(f"{extracted_lines}")
+
+ await self.run_movie_training()
+
+ return True
+
+ # train_dialogue = kwargs.get("train_dialogue", True)
+ # train_196_dialogue = kwargs.get("train_196", False)
+ # train_301_dialogue = kwargs.get("train_301", False)
+ #
+ # if train_dialogue:
+ # await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
+ #
+ # if train_196_dialogue:
+ # await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
+ #
+ # if train_301_dialogue:
+ # await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
+
+
+class UbuntuCorpusTrainer2(KaggleTrainer):
+ def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
+ super().__init__(
+ chatbot,
+ datapath,
+ downloadpath="kaggle_ubuntu",
+ kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
+ **kwargs,
+ )
+
+ async def asynctrain(self, *args, **kwargs):
+ extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"
+
+ # Download and extract the Ubuntu dialog corpus if needed
+ if not extracted_dir.exists():
+ await self.download(self.kaggle_dataset)
+ else:
+ log.info("Ubuntu dialogue already downloaded")
+ if not extracted_dir.exists():
+ raise FileNotFoundError("Did not extract in the expected way")
+
+ train_dialogue = kwargs.get("train_dialogue", True)
+ train_196_dialogue = kwargs.get("train_196", False)
+ train_301_dialogue = kwargs.get("train_301", False)
+
+ if train_dialogue:
+ await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
+
+ if train_196_dialogue:
+ await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
+
+ if train_301_dialogue:
+ await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
+
+ return True
+
+ async def run_dialogue_training(self, extracted_dir, dialogue_file):
+ log.info(f"Beginning dialogue training on {dialogue_file}")
+ start_time = time.time()
+
+ tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
+
+ with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
+ reader = csv.DictReader(dg)
+
+ next(reader) # Skip the header
+
+ last_dialogue_id = None
+ previous_statement_text = None
+ previous_statement_search_text = ""
+ statements_from_file = []
+
+ save_every = 50
+ count = 0
+
+ async for row in AsyncIter(reader):
+ dialogue_id = row["dialogueID"]
+ if dialogue_id != last_dialogue_id:
+ previous_statement_text = None
+ previous_statement_search_text = ""
+ last_dialogue_id = dialogue_id
+ count += 1
+ if count >= save_every:
+ if statements_from_file:
+ self.chatbot.storage.create_many(statements_from_file)
+ statements_from_file = []
+ count = 0
+
+ if len(row) > 0:
+ statement = Statement(
+ text=row["text"],
+ in_response_to=previous_statement_text,
+ conversation="training",
+ # created_at=date_parser.parse(row["date"]),
+ persona=row["from"],
+ )
+
+ for preprocessor in self.chatbot.preprocessors:
+ statement = preprocessor(statement)
+
+ statement.search_text = tagger.get_text_index_string(statement.text)
+ statement.search_in_response_to = previous_statement_search_text
+
+ previous_statement_text = statement.text
+ previous_statement_search_text = statement.search_text
+
+ statements_from_file.append(statement)
+
+ if statements_from_file:
+ self.chatbot.storage.create_many(statements_from_file)
+
+ log.info(f"Training took {time.time() - start_time} seconds.")
+
+
+class TwitterCorpusTrainer(Trainer):
+ pass
+ # def train(self, *args, **kwargs):
+ # """
+ # Train the chat bot based on the provided list of
+ # statements that represents a single conversation.
+ # """
+ # import twint
+ #
+ # c = twint.Config()
+ # c.__dict__.update(kwargs)
+ # twint.run.Search(c)
+ #
+ #
+ # previous_statement_text = None
+ # previous_statement_search_text = ''
+ #
+ # statements_to_create = []
+ #
+ # for conversation_count, text in enumerate(conversation):
+ # if self.show_training_progress:
+ # utils.print_progress_bar(
+ # 'List Trainer',
+ # conversation_count + 1, len(conversation)
+ # )
+ #
+ # statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)
+ #
+ # statement = self.get_preprocessed_statement(
+ # Statement(
+ # text=text,
+ # search_text=statement_search_text,
+ # in_response_to=previous_statement_text,
+ # search_in_response_to=previous_statement_search_text,
+ # conversation='training'
+ # )
+ # )
+ #
+ # previous_statement_text = statement.text
+ # previous_statement_search_text = statement_search_text
+ #
+ # statements_to_create.append(statement)
+ #
+ # self.chatbot.storage.create_many(statements_to_create)