commit ea126db0c5
@ -59,6 +59,35 @@ Install these on your windows machine before attempting the installation:

[Pandoc - Universal Document Converter](https://pandoc.org/installing.html)

## Methods

### Automatic

This method requires some luck to pull off.

#### Step 1: Add repo and install cog

```
[p]repo add Fox https://github.com/bobloy/Fox-V3
[p]cog install Fox chatter
```

If you get an error at this step, stop and skip to one of the manual methods below.

#### Step 2: Install additional dependencies

Assuming the previous commands completed without error, you can now use `pipinstall` to add the remaining dependencies.

NOTE: This is not the intended use case for `pipinstall` and it may stop working in the future.

```
[p]pipinstall --no-deps chatterbot>=1.1
```
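
If `pipinstall` ever stops working, the same dependency can be installed by hand from the host machine into whichever environment the bot runs in (a rough sketch, assuming the bot's virtualenv is activated):

```
python -m pip install --no-deps "chatterbot>=1.1"
```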

#### Step 3: Load the cog and get started

```
[p]load chatter
```
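
Once the cog is loaded, the training commands introduced in this commit live under the `[p]chatter train` group, for example:

```
[p]chatter train english
[p]chatter train kaggle movies True
[p]chatter train kaggle ubuntu True
```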

### Windows - Manually

#### Step 1: Built-in Downloader

chatter/chat.py (269 changed lines)
@ -2,8 +2,10 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
||||
import discord
|
||||
from chatterbot import ChatBot
|
||||
@ -15,6 +17,8 @@ from redbot.core.commands import Cog
|
||||
from redbot.core.data_manager import cog_data_path
|
||||
from redbot.core.utils.predicates import MessagePredicate
|
||||
|
||||
from chatter.trainers import MovieTrainer, TwitterCorpusTrainer, UbuntuCorpusTrainer2
|
||||
|
||||
log = logging.getLogger("red.fox_v3.chatter")
|
||||
|
||||
|
||||
@ -59,6 +63,7 @@ class Chatter(Cog):
|
||||
"convo_delta": 15,
|
||||
"chatchannel": None,
|
||||
"reply": True,
|
||||
"learning": True,
|
||||
}
|
||||
path: pathlib.Path = cog_data_path(self)
|
||||
self.data_path = path / "database.sqlite3"
|
||||
@ -79,6 +84,10 @@ class Chatter(Cog):
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self._guild_cache = defaultdict(dict)
|
||||
|
||||
self._last_message_per_channel: Dict[int, Optional[discord.Message]] = defaultdict(lambda: None)
|
||||
|
||||
async def red_delete_data_for_user(self, **kwargs):
|
||||
"""Nothing to delete"""
|
||||
return
|
||||
@ -87,7 +96,8 @@ class Chatter(Cog):
|
||||
|
||||
return ChatBot(
|
||||
"ChatterBot",
|
||||
storage_adapter="chatterbot.storage.SQLStorageAdapter",
|
||||
# storage_adapter="chatterbot.storage.SQLStorageAdapter",
|
||||
storage_adapter="chatter.storage_adapters.MyDumbSQLStorageAdapter",
|
||||
database_uri="sqlite:///" + str(self.data_path),
|
||||
statement_comparison_function=self.similarity_algo,
|
||||
response_selection_method=get_random_response,
|
||||
@ -111,15 +121,7 @@ class Chatter(Cog):
|
||||
return msg.clean_content
|
||||
|
||||
def new_conversation(msg, sent, out_in, delta):
|
||||
# if sent is None:
|
||||
# return False
|
||||
|
||||
# Don't do "too short" processing here. Sometimes people don't respond.
|
||||
# if len(out_in) < 2:
|
||||
# return False
|
||||
|
||||
# print(msg.created_at - sent)
|
||||
|
||||
# Should always be positive numbers
|
||||
return msg.created_at - sent >= delta
|
||||
|
||||
for channel in ctx.guild.text_channels:
|
||||
@ -164,6 +166,11 @@ class Chatter(Cog):
|
||||
|
||||
return out
|
||||
|
||||
def _train_twitter(self, *args, **kwargs):
|
||||
trainer = TwitterCorpusTrainer(self.chatbot)
|
||||
trainer.train(*args, **kwargs)
|
||||
return True
|
||||
|
||||
def _train_ubuntu(self):
|
||||
trainer = UbuntuCorpusTrainer(
|
||||
self.chatbot, ubuntu_corpus_data_directory=cog_data_path(self) / "ubuntu_data"
|
||||
@ -171,6 +178,30 @@ class Chatter(Cog):
|
||||
trainer.train()
|
||||
return True
|
||||
|
||||
async def _train_movies(self):
|
||||
trainer = MovieTrainer(self.chatbot, cog_data_path(self))
|
||||
return await trainer.asynctrain()
|
||||
|
||||
async def _train_ubuntu2(self, intensity):
|
||||
train_kwarg = {}
|
||||
if intensity == 196:
|
||||
train_kwarg["train_dialogue"] = False
|
||||
train_kwarg["train_196"] = True
|
||||
elif intensity == 301:
|
||||
train_kwarg["train_dialogue"] = False
|
||||
train_kwarg["train_301"] = True
|
||||
elif intensity == 497:
|
||||
train_kwarg["train_dialogue"] = False
|
||||
train_kwarg["train_196"] = True
|
||||
train_kwarg["train_301"] = True
|
||||
elif intensity >= 9000: # NOT 9000!
|
||||
train_kwarg["train_dialogue"] = True
|
||||
train_kwarg["train_196"] = True
|
||||
train_kwarg["train_301"] = True
|
||||
|
||||
trainer = UbuntuCorpusTrainer2(self.chatbot, cog_data_path(self))
|
||||
return await trainer.asynctrain(**train_kwarg)
|
||||
|
||||
def _train_english(self):
|
||||
trainer = ChatterBotCorpusTrainer(self.chatbot)
|
||||
# try:
|
||||
@ -196,9 +227,9 @@ class Chatter(Cog):
|
||||
"""
|
||||
Base command for this cog. Check help for the commands list.
|
||||
"""
|
||||
pass
|
||||
self._guild_cache[ctx.guild.id] = {} # Clear cache when modifying values
|
||||
|
||||
@checks.admin()
|
||||
@commands.admin()
|
||||
@chatter.command(name="channel")
|
||||
async def chatter_channel(
|
||||
self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None
|
||||
@ -218,7 +249,7 @@ class Chatter(Cog):
|
||||
await self.config.guild(ctx.guild).chatchannel.set(channel.id)
|
||||
await ctx.maybe_send_embed(f"Chat channel is now {channel.mention}")
|
||||
|
||||
@checks.admin()
|
||||
@commands.admin()
|
||||
@chatter.command(name="reply")
|
||||
async def chatter_reply(self, ctx: commands.Context, toggle: Optional[bool] = None):
|
||||
"""
|
||||
@ -231,19 +262,41 @@ class Chatter(Cog):
|
||||
await self.config.guild(ctx.guild).reply.set(toggle)
|
||||
|
||||
if toggle:
|
||||
await ctx.send("I will now respond to you if conversation continuity is not present")
|
||||
await ctx.maybe_send_embed(
|
||||
"I will now respond to you if conversation continuity is not present"
|
||||
)
|
||||
else:
|
||||
await ctx.send(
|
||||
await ctx.maybe_send_embed(
|
||||
"I will not reply to your message if conversation continuity is not present, anymore"
|
||||
)
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.admin()
|
||||
@chatter.command(name="learning")
|
||||
async def chatter_learning(self, ctx: commands.Context, toggle: Optional[bool] = None):
|
||||
"""
|
||||
Toggle the bot learning from its conversations.
|
||||
|
||||
This is on by default.
|
||||
"""
|
||||
learning = await self.config.guild(ctx.guild).learning()
|
||||
if toggle is None:
|
||||
toggle = not learning
|
||||
await self.config.guild(ctx.guild).learning.set(toggle)
|
||||
|
||||
if toggle:
|
||||
await ctx.maybe_send_embed("I will now learn from conversations.")
|
||||
else:
|
||||
await ctx.maybe_send_embed("I will no longer learn from conversations.")
|
||||
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="cleardata")
|
||||
async def chatter_cleardata(self, ctx: commands.Context, confirm: bool = False):
|
||||
"""
|
||||
This command will erase all training data and reset your configuration settings
|
||||
This command will erase all training data and reset your configuration settings.
|
||||
|
||||
Use `[p]chatter cleardata True`
|
||||
This applies to all guilds.
|
||||
|
||||
Use `[p]chatter cleardata True` to confirm.
|
||||
"""
|
||||
|
||||
if not confirm:
|
||||
@ -270,7 +323,7 @@ class Chatter(Cog):
|
||||
|
||||
await ctx.tick()
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="algorithm", aliases=["algo"])
|
||||
async def chatter_algorithm(
|
||||
self, ctx: commands.Context, algo_number: int, threshold: float = None
|
||||
@ -304,7 +357,7 @@ class Chatter(Cog):
|
||||
|
||||
await ctx.tick()
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="model")
|
||||
async def chatter_model(self, ctx: commands.Context, model_number: int):
|
||||
"""
|
||||
@ -342,7 +395,7 @@ class Chatter(Cog):
|
||||
f"Model has been switched to {self.tagger_language.ISO_639_1}"
|
||||
)
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="minutes")
|
||||
async def minutes(self, ctx: commands.Context, minutes: int):
|
||||
"""
|
||||
@ -354,11 +407,11 @@ class Chatter(Cog):
|
||||
await ctx.send_help()
|
||||
return
|
||||
|
||||
await self.config.guild(ctx.guild).convo_length.set(minutes)
|
||||
await self.config.guild(ctx.guild).convo_delta.set(minutes)
|
||||
|
||||
await ctx.tick()
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="age")
|
||||
async def age(self, ctx: commands.Context, days: int):
|
||||
"""
|
||||
@ -373,7 +426,16 @@ class Chatter(Cog):
|
||||
await self.config.guild(ctx.guild).days.set(days)
|
||||
await ctx.tick()
|
||||
|
||||
@checks.is_owner()
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="kaggle")
|
||||
async def chatter_kaggle(self, ctx: commands.Context):
|
||||
"""Register with the kaggle API to download additional datasets for training"""
|
||||
if not await self.check_for_kaggle():
|
||||
await ctx.maybe_send_embed(
|
||||
"[Click here for instructions to setup the kaggle api](https://github.com/Kaggle/kaggle-api#api-credentials)"
|
||||
)
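# Note: the kaggle package reads credentials from ~/.kaggle/kaggle.json
# (a JSON file containing {"username": ..., "key": ...}) or from the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables; see the link above.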
|
||||
|
||||
@commands.is_owner()
|
||||
@chatter.command(name="backup")
|
||||
async def backup(self, ctx, backupname):
|
||||
"""
|
||||
@ -395,8 +457,71 @@ class Chatter(Cog):
|
||||
else:
|
||||
await ctx.maybe_send_embed("Error occurred :(")
|
||||
|
||||
@checks.is_owner()
|
||||
@chatter.command(name="trainubuntu")
|
||||
@commands.is_owner()
|
||||
@chatter.group(name="train")
|
||||
async def chatter_train(self, ctx: commands.Context):
|
||||
"""Commands for training the bot"""
|
||||
pass
|
||||
|
||||
@chatter_train.group(name="kaggle")
|
||||
async def chatter_train_kaggle(self, ctx: commands.Context):
|
||||
"""
|
||||
Base command for kaggle training sets.
|
||||
|
||||
See `[p]chatter kaggle` for details on how to enable this option
|
||||
"""
|
||||
pass
|
||||
|
||||
@chatter_train_kaggle.command(name="ubuntu")
|
||||
async def chatter_train_kaggle_ubuntu(
|
||||
self, ctx: commands.Context, confirmation: bool = False, intensity=0
|
||||
):
|
||||
"""
|
||||
WARNING: Large Download! Trains the bot using *NEW* Ubuntu Dialog Corpus data.
|
||||
"""
|
||||
|
||||
if not confirmation:
|
||||
await ctx.maybe_send_embed(
|
||||
"Warning: This command downloads ~800MB and is CPU intensive during training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter train kaggle ubuntu True`"
|
||||
)
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
future = await self._train_ubuntu2(intensity)
|
||||
|
||||
if future:
|
||||
await ctx.maybe_send_embed("Training successful!")
|
||||
else:
|
||||
await ctx.maybe_send_embed("Error occurred :(")
|
||||
|
||||
@chatter_train_kaggle.command(name="movies")
|
||||
async def chatter_train_kaggle_movies(self, ctx: commands.Context, confirmation: bool = False):
|
||||
"""
|
||||
WARNING: Language! Trains the bot using Cornell University's "Movie Dialog Corpus".
|
||||
|
||||
This training set contains dialog from a spread of movies with different MPAA ratings.
|
||||
This dialog includes racism, sexism, and any number of sensitive topics.
|
||||
|
||||
Use at your own risk.
|
||||
"""
|
||||
|
||||
if not confirmation:
|
||||
await ctx.maybe_send_embed(
|
||||
"Warning: This command downloads ~29MB and is CPU intensive during training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter train kaggle movies True`"
|
||||
)
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
future = await self._train_movies()
|
||||
|
||||
if future:
|
||||
await ctx.maybe_send_embed("Training successful!")
|
||||
else:
|
||||
await ctx.maybe_send_embed("Error occurred :(")
|
||||
|
||||
@chatter_train.command(name="ubuntu")
|
||||
async def chatter_train_ubuntu(self, ctx: commands.Context, confirmation: bool = False):
|
||||
"""
|
||||
WARNING: Large Download! Trains the bot using Ubuntu Dialog Corpus data.
|
||||
@ -404,8 +529,8 @@ class Chatter(Cog):
|
||||
|
||||
if not confirmation:
|
||||
await ctx.maybe_send_embed(
|
||||
"Warning: This command downloads ~500MB then eats your CPU for training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter trainubuntu True`"
|
||||
"Warning: This command downloads ~500MB and is CPU intensive during training\n"
|
||||
"If you're sure you want to continue, run `[p]chatter train ubuntu True`"
|
||||
)
|
||||
return
|
||||
|
||||
@ -413,12 +538,11 @@ class Chatter(Cog):
|
||||
future = await self.loop.run_in_executor(None, self._train_ubuntu)
|
||||
|
||||
if future:
|
||||
await ctx.send("Training successful!")
|
||||
await ctx.maybe_send_embed("Training successful!")
|
||||
else:
|
||||
await ctx.send("Error occurred :(")
|
||||
await ctx.maybe_send_embed("Error occurred :(")
|
||||
|
||||
@checks.is_owner()
|
||||
@chatter.command(name="trainenglish")
|
||||
@chatter_train.command(name="english")
|
||||
async def chatter_train_english(self, ctx: commands.Context):
|
||||
"""
|
||||
Trains the bot in English
|
||||
@ -431,11 +555,26 @@ class Chatter(Cog):
|
||||
else:
|
||||
await ctx.maybe_send_embed("Error occurred :(")
|
||||
|
||||
@checks.is_owner()
|
||||
@chatter.command()
|
||||
async def train(self, ctx: commands.Context, channel: discord.TextChannel):
|
||||
@chatter_train.command(name="list")
|
||||
async def chatter_train_list(self, ctx: commands.Context):
|
||||
"""Trains the bot based on an uploaded list.
|
||||
|
||||
Must be a file in the format of a python list: ['prompt', 'response1', 'response2']
|
||||
"""
|
||||
Trains the bot based on language in this guild
|
||||
if not ctx.message.attachments:
|
||||
await ctx.maybe_send_embed("You must upload a file when using this command")
|
||||
return
|
||||
|
||||
attachment: discord.Attachment = ctx.message.attachments[0]
|
||||
|
||||
a_bytes = await attachment.read()
|
||||
|
||||
await ctx.send("Not yet implemented")
|
||||
|
||||
@chatter_train.command(name="channel")
|
||||
async def chatter_train_channel(self, ctx: commands.Context, channel: discord.TextChannel):
|
||||
"""
|
||||
Trains the bot based on language in this guild.
|
||||
"""
|
||||
|
||||
await ctx.maybe_send_embed(
|
||||
@ -499,15 +638,18 @@ class Chatter(Cog):
|
||||
# Thank you Cog-Creators
|
||||
channel: discord.TextChannel = message.channel
|
||||
|
||||
# is_reply = False # this is only useful with in_response_to
|
||||
if not self._guild_cache[guild.id]:
|
||||
self._guild_cache[guild.id] = await self.config.guild(guild).all()
|
||||
|
||||
is_reply = False # this is only useful with in_response_to
|
||||
if (
|
||||
message.reference is not None
|
||||
and isinstance(message.reference.resolved, discord.Message)
|
||||
and message.reference.resolved.author.id == self.bot.user.id
|
||||
):
|
||||
# is_reply = True # this is only useful with in_response_to
|
||||
is_reply = True # this is only useful with in_response_to
|
||||
pass # this is a reply to the bot, good to go
|
||||
elif guild is not None and channel.id == await self.config.guild(guild).chatchannel():
|
||||
elif guild is not None and channel.id == self._guild_cache[guild.id]["chatchannel"]:
|
||||
pass # good to go
|
||||
else:
|
||||
when_mentionables = commands.when_mentioned(self.bot, message)
|
||||
@ -522,15 +664,52 @@ class Chatter(Cog):
|
||||
|
||||
text = message.clean_content
|
||||
|
||||
async with channel.typing():
|
||||
future = await self.loop.run_in_executor(None, self.chatbot.get_response, text)
|
||||
async with ctx.typing():
|
||||
|
||||
if is_reply:
|
||||
in_response_to = message.reference.resolved.content
|
||||
elif self._last_message_per_channel[ctx.channel.id] is not None:
|
||||
last_m: discord.Message = self._last_message_per_channel[ctx.channel.id]
|
||||
minutes = self._guild_cache[ctx.guild.id]["convo_delta"]
|
||||
if (datetime.utcnow() - last_m.created_at).seconds > minutes * 60:
|
||||
in_response_to = None
|
||||
else:
|
||||
in_response_to = last_m.content
|
||||
else:
|
||||
in_response_to = None
|
||||
|
||||
# Always use generate_response
|
||||
# Chatterbot tries to learn based on the result it comes up with, which is dumb
|
||||
log.debug("Generating response")
|
||||
Statement = self.chatbot.storage.get_object("statement")
|
||||
future = await self.loop.run_in_executor(
|
||||
None, self.chatbot.generate_response, Statement(text)
|
||||
)
|
||||
|
||||
if in_response_to is not None and self._guild_cache[guild.id]["learning"]:
|
||||
log.debug("learning response")
|
||||
await self.loop.run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self.chatbot.learn_response,
|
||||
Statement(text),
|
||||
previous_statement=in_response_to,
|
||||
),
|
||||
)
|
||||
|
||||
replying = None
|
||||
if await self.config.guild(guild).reply():
|
||||
if self._guild_cache[guild.id]["reply"]:
|
||||
if message != ctx.channel.last_message:
|
||||
replying = message
|
||||
|
||||
if future and str(future):
|
||||
await channel.send(str(future), reference=replying)
|
||||
self._last_message_per_channel[ctx.channel.id] = await channel.send(
|
||||
str(future), reference=replying
|
||||
)
|
||||
else:
|
||||
await channel.send(":thinking:")
|
||||
await ctx.send(":thinking:")
|
||||
|
||||
async def check_for_kaggle(self):
|
||||
"""Check whether Kaggle is installed and configured properly"""
|
||||
# TODO: This
|
||||
return False
|
||||
|
@@ -17,7 +17,8 @@
"pytz",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm",
"https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz#egg=en_core_web_md",
"spacy>=2.3,<2.4"
"spacy>=2.3,<2.4",
"kaggle"
],
"short": "Local Chatbot run on machine learning",
"end_user_data_statement": "This cog only stores anonymous conversations data; no End User Data is stored.",
chatter/storage_adapters.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from chatterbot.storage import StorageAdapter, SQLStorageAdapter


class MyDumbSQLStorageAdapter(SQLStorageAdapter):
    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        self.database_uri = kwargs.get("database_uri", False)

        # None results in a sqlite in-memory database as the default
        if self.database_uri is None:
            self.database_uri = "sqlite://"

        # Create a file database if the database is not a connection string
        if not self.database_uri:
            self.database_uri = "sqlite:///db.sqlite3"

        self.engine = create_engine(
            self.database_uri, convert_unicode=True, connect_args={"check_same_thread": False}
        )

        if self.database_uri.startswith("sqlite://"):
            from sqlalchemy.engine import Engine
            from sqlalchemy import event

            @event.listens_for(Engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute("PRAGMA journal_mode=WAL")
                dbapi_connection.execute("PRAGMA synchronous=NORMAL")

        if not self.engine.dialect.has_table(self.engine, "Statement"):
            self.create_database()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)


class AsyncSQLStorageAdapter(SQLStorageAdapter):
    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        self.database_uri = kwargs.get("database_uri", False)

        # None results in a sqlite in-memory database as the default
        if self.database_uri is None:
            self.database_uri = "sqlite://"

        # Create a file database if the database is not a connection string
        if not self.database_uri:
            self.database_uri = "sqlite:///db.sqlite3"

    async def initialize(self):
        # from sqlalchemy import create_engine
        from aiomysql.sa import create_engine
        from sqlalchemy.orm import sessionmaker

        self.engine = await create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith("sqlite://"):
            from sqlalchemy.engine import Engine
            from sqlalchemy import event

            @event.listens_for(Engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute("PRAGMA journal_mode=WAL")
                dbapi_connection.execute("PRAGMA synchronous=NORMAL")

        if not self.engine.dialect.has_table(self.engine, "Statement"):
            self.create_database()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
chatter/trainers.py (new file, 351 lines)
@@ -0,0 +1,351 @@
|
||||
import asyncio
|
||||
import csv
|
||||
import html
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from chatterbot import utils
|
||||
from chatterbot.conversation import Statement
|
||||
from chatterbot.tagging import PosLemmaTagger
|
||||
from chatterbot.trainers import Trainer
|
||||
from redbot.core.bot import Red
|
||||
from dateutil import parser as date_parser
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
log = logging.getLogger("red.fox_v3.chatter.trainers")
|
||||
|
||||
|
||||
class KaggleTrainer(Trainer):
|
||||
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
|
||||
super().__init__(chatbot, **kwargs)
|
||||
|
||||
self.data_directory = datapath / kwargs.get("downloadpath", "kaggle_download")
|
||||
|
||||
self.kaggle_dataset = kwargs.get(
|
||||
"kaggle_dataset",
|
||||
"Cornell-University/movie-dialog-corpus",
|
||||
)
|
||||
|
||||
# Create the data directory if it does not already exist
|
||||
if not os.path.exists(self.data_directory):
|
||||
os.makedirs(self.data_directory)
|
||||
|
||||
def is_downloaded(self, file_path):
|
||||
"""
|
||||
Check if the data file is already downloaded.
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
self.chatbot.logger.info("File is already downloaded")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def download(self, dataset):
|
||||
import kaggle # This triggers the API token check
|
||||
|
||||
future = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
kaggle.api.dataset_download_files,
|
||||
dataset=dataset,
|
||||
path=self.data_directory,
|
||||
quiet=False,
|
||||
unzip=True,
|
||||
),
|
||||
)
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
log.error("See asynctrain instead")
|
||||
|
||||
def asynctrain(self, *args, **kwargs):
|
||||
raise self.TrainerInitializationException()
|
||||
|
||||
|
||||
class SouthParkTrainer(KaggleTrainer):
|
||||
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
|
||||
super().__init__(
|
||||
chatbot,
|
||||
datapath,
|
||||
downloadpath="ubuntu_data_v2",
|
||||
kaggle_dataset="tovarischsukhov/southparklines",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MovieTrainer(KaggleTrainer):
|
||||
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
|
||||
super().__init__(
|
||||
chatbot,
|
||||
datapath,
|
||||
downloadpath="kaggle_movies",
|
||||
kaggle_dataset="Cornell-University/movie-dialog-corpus",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def run_movie_training(self):
|
||||
dialogue_file = "movie_lines.tsv"
|
||||
conversation_file = "movie_conversations.tsv"
|
||||
log.info(f"Beginning dialogue training on {dialogue_file}")
|
||||
start_time = time.time()
|
||||
|
||||
tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
|
||||
|
||||
# [lineID, characterID, movieID, character name, text of utterance]
|
||||
# File parsing from https://www.kaggle.com/mushaya/conversation-chatbot
|
||||
|
||||
with open(self.data_directory / conversation_file, "r", encoding="utf-8-sig") as conv_tsv:
|
||||
conv_lines = conv_tsv.readlines()
|
||||
with open(self.data_directory / dialogue_file, "r", encoding="utf-8-sig") as lines_tsv:
|
||||
dialog_lines = lines_tsv.readlines()
|
||||
|
||||
# trans_dict = str.maketrans({"<u>": "__", "</u>": "__", '""': '"'})
|
||||
|
||||
lines_dict = {}
|
||||
for line in dialog_lines:
|
||||
_line = line[:-1].strip('"').split("\t")
|
||||
if len(_line) >= 5: # Only good lines
|
||||
lines_dict[_line[0]] = (
|
||||
html.unescape(("".join(_line[4:])).strip())
|
||||
.replace("<u>", "__")
|
||||
.replace("</u>", "__")
|
||||
.replace('""', '"')
|
||||
)
|
||||
else:
|
||||
log.debug(f"Bad line {_line}")
|
||||
|
||||
# collecting line ids for each conversation
|
||||
conv = []
|
||||
for line in conv_lines[:-1]:
|
||||
_line = line[:-1].split("\t")[-1][1:-1].replace("'", "").replace(" ", ",")
|
||||
conv.append(_line.split(","))
|
||||
|
||||
# conversations = csv.reader(conv_tsv, delimiter="\t")
|
||||
#
|
||||
# reader = csv.reader(lines_tsv, delimiter="\t")
|
||||
#
|
||||
#
|
||||
#
|
||||
# lines_dict = {}
|
||||
# for row in reader:
|
||||
# try:
|
||||
# lines_dict[row[0].strip('"')] = row[4]
|
||||
# except:
|
||||
# log.exception(f"Bad line: {row}")
|
||||
# pass
|
||||
# else:
|
||||
# # log.info(f"Good line: {row}")
|
||||
# pass
|
||||
#
|
||||
# # lines_dict = {row[0].strip('"'): row[4] for row in reader_list}
|
||||
|
||||
statements_from_file = []
|
||||
save_every = 300
|
||||
count = 0
|
||||
|
||||
# [characterID of first, characterID of second, movieID, list of utterances]
|
||||
async for lines in AsyncIter(conv):
|
||||
previous_statement_text = None
|
||||
previous_statement_search_text = ""
|
||||
|
||||
for line in lines:
|
||||
text = lines_dict[line]
|
||||
statement = Statement(
|
||||
text=text,
|
||||
in_response_to=previous_statement_text,
|
||||
conversation="training",
|
||||
)
|
||||
|
||||
for preprocessor in self.chatbot.preprocessors:
|
||||
statement = preprocessor(statement)
|
||||
|
||||
statement.search_text = tagger.get_text_index_string(statement.text)
|
||||
statement.search_in_response_to = previous_statement_search_text
|
||||
|
||||
previous_statement_text = statement.text
|
||||
previous_statement_search_text = statement.search_text
|
||||
|
||||
statements_from_file.append(statement)
|
||||
|
||||
count += 1
|
||||
if count >= save_every:
|
||||
if statements_from_file:
|
||||
self.chatbot.storage.create_many(statements_from_file)
|
||||
statements_from_file = []
|
||||
count = 0
|
||||
|
||||
if statements_from_file:
|
||||
self.chatbot.storage.create_many(statements_from_file)
|
||||
|
||||
log.info(f"Training took {time.time() - start_time} seconds.")
|
||||
|
||||
async def asynctrain(self, *args, **kwargs):
|
||||
extracted_lines = self.data_directory / "movie_lines.tsv"
|
||||
extracted_lines: pathlib.Path
|
||||
|
||||
# Download and extract the movie dialog corpus if needed
|
||||
if not extracted_lines.exists():
|
||||
await self.download(self.kaggle_dataset)
|
||||
else:
|
||||
log.info("Movie dialog already downloaded")
|
||||
if not extracted_lines.exists():
|
||||
raise FileNotFoundError(f"{extracted_lines}")
|
||||
|
||||
await self.run_movie_training()
|
||||
|
||||
return True
|
||||
|
||||
# train_dialogue = kwargs.get("train_dialogue", True)
|
||||
# train_196_dialogue = kwargs.get("train_196", False)
|
||||
# train_301_dialogue = kwargs.get("train_301", False)
|
||||
#
|
||||
# if train_dialogue:
|
||||
# await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
|
||||
#
|
||||
# if train_196_dialogue:
|
||||
# await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
|
||||
#
|
||||
# if train_301_dialogue:
|
||||
# await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
|
||||
|
||||
|
||||
class UbuntuCorpusTrainer2(KaggleTrainer):
|
||||
def __init__(self, chatbot, datapath: pathlib.Path, **kwargs):
|
||||
super().__init__(
|
||||
chatbot,
|
||||
datapath,
|
||||
downloadpath="kaggle_ubuntu",
|
||||
kaggle_dataset="rtatman/ubuntu-dialogue-corpus",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def asynctrain(self, *args, **kwargs):
|
||||
extracted_dir = self.data_directory / "Ubuntu-dialogue-corpus"
|
||||
|
||||
# Download and extract the Ubuntu dialog corpus if needed
|
||||
if not extracted_dir.exists():
|
||||
await self.download(self.kaggle_dataset)
|
||||
else:
|
||||
log.info("Ubuntu dialogue already downloaded")
|
||||
if not extracted_dir.exists():
|
||||
raise FileNotFoundError("Did not extract in the expected way")
|
||||
|
||||
train_dialogue = kwargs.get("train_dialogue", True)
|
||||
train_196_dialogue = kwargs.get("train_196", False)
|
||||
train_301_dialogue = kwargs.get("train_301", False)
|
||||
|
||||
if train_dialogue:
|
||||
await self.run_dialogue_training(extracted_dir, "dialogueText.csv")
|
||||
|
||||
if train_196_dialogue:
|
||||
await self.run_dialogue_training(extracted_dir, "dialogueText_196.csv")
|
||||
|
||||
if train_301_dialogue:
|
||||
await self.run_dialogue_training(extracted_dir, "dialogueText_301.csv")
|
||||
|
||||
return True
|
||||
|
||||
async def run_dialogue_training(self, extracted_dir, dialogue_file):
|
||||
log.info(f"Beginning dialogue training on {dialogue_file}")
|
||||
start_time = time.time()
|
||||
|
||||
tagger = PosLemmaTagger(language=self.chatbot.storage.tagger.language)
|
||||
|
||||
with open(extracted_dir / dialogue_file, "r", encoding="utf-8") as dg:
|
||||
reader = csv.DictReader(dg)

# csv.DictReader already consumes the header row, so there is no need to skip it manually
|
||||
|
||||
last_dialogue_id = None
|
||||
previous_statement_text = None
|
||||
previous_statement_search_text = ""
|
||||
statements_from_file = []
|
||||
|
||||
save_every = 50
|
||||
count = 0
|
||||
|
||||
async for row in AsyncIter(reader):
|
||||
dialogue_id = row["dialogueID"]
|
||||
if dialogue_id != last_dialogue_id:
|
||||
previous_statement_text = None
|
||||
previous_statement_search_text = ""
|
||||
last_dialogue_id = dialogue_id
|
||||
count += 1
|
||||
if count >= save_every:
|
||||
if statements_from_file:
|
||||
self.chatbot.storage.create_many(statements_from_file)
|
||||
statements_from_file = []
|
||||
count = 0
|
||||
|
||||
if len(row) > 0:
|
||||
statement = Statement(
|
||||
text=row["text"],
|
||||
in_response_to=previous_statement_text,
|
||||
conversation="training",
|
||||
# created_at=date_parser.parse(row["date"]),
|
||||
persona=row["from"],
|
||||
)
|
||||
|
||||
for preprocessor in self.chatbot.preprocessors:
|
||||
statement = preprocessor(statement)
|
||||
|
||||
statement.search_text = tagger.get_text_index_string(statement.text)
|
||||
statement.search_in_response_to = previous_statement_search_text
|
||||
|
||||
previous_statement_text = statement.text
|
||||
previous_statement_search_text = statement.search_text
|
||||
|
||||
statements_from_file.append(statement)
|
||||
|
||||
if statements_from_file:
|
||||
self.chatbot.storage.create_many(statements_from_file)
|
||||
|
||||
log.info(f"Training took {time.time() - start_time} seconds.")
|
||||
|
||||
|
||||
class TwitterCorpusTrainer(Trainer):
|
||||
pass
|
||||
# def train(self, *args, **kwargs):
|
||||
# """
|
||||
# Train the chat bot based on the provided list of
|
||||
# statements that represents a single conversation.
|
||||
# """
|
||||
# import twint
|
||||
#
|
||||
# c = twint.Config()
|
||||
# c.__dict__.update(kwargs)
|
||||
# twint.run.Search(c)
|
||||
#
|
||||
#
|
||||
# previous_statement_text = None
|
||||
# previous_statement_search_text = ''
|
||||
#
|
||||
# statements_to_create = []
|
||||
#
|
||||
# for conversation_count, text in enumerate(conversation):
|
||||
# if self.show_training_progress:
|
||||
# utils.print_progress_bar(
|
||||
# 'List Trainer',
|
||||
# conversation_count + 1, len(conversation)
|
||||
# )
|
||||
#
|
||||
# statement_search_text = self.chatbot.storage.tagger.get_text_index_string(text)
|
||||
#
|
||||
# statement = self.get_preprocessed_statement(
|
||||
# Statement(
|
||||
# text=text,
|
||||
# search_text=statement_search_text,
|
||||
# in_response_to=previous_statement_text,
|
||||
# search_in_response_to=previous_statement_search_text,
|
||||
# conversation='training'
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# previous_statement_text = statement.text
|
||||
# previous_statement_search_text = statement_search_text
|
||||
#
|
||||
# statements_to_create.append(statement)
|
||||
#
|
||||
# self.chatbot.storage.create_many(statements_to_create)
|