You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
398 lines
11 KiB
398 lines
11 KiB
from chatter.chatterbot.storage import StorageAdapter
|
|
|
|
|
|
class Query(object):
    """
    Builder for MongoDB query documents.

    Each builder method returns a new ``Query`` and leaves the instance
    it was called on unchanged, so a base query can be reused safely.
    """

    def __init__(self, query=None):
        """
        :param query: An existing query document to wrap. When omitted
            (or ``None``) the query starts out empty.
        :type query: dict
        """
        if query is None:
            self.query = {}
        else:
            self.query = query

    def _clone(self):
        """
        Return a deep copy of the wrapped query document.

        A shallow copy is not sufficient: the nested ``text`` and
        ``in_response_to`` dictionaries are modified in place by the
        builder methods, so sharing them between instances would leak
        changes back into the query they were derived from.
        """
        from copy import deepcopy

        return deepcopy(self.query)

    def value(self):
        """
        Return a copy of the query document built so far.
        """
        return self._clone()

    def raw(self, data):
        """
        Return a new query with the key/value pairs of ``data`` merged in.

        :param data: Query fields to add; existing keys are overwritten.
        :type data: dict
        """
        query = self._clone()

        query.update(data)

        return Query(query)

    def statement_text_equals(self, statement_text):
        """
        Return a new query that matches on an exact ``text`` value.
        """
        query = self._clone()

        query['text'] = statement_text

        return Query(query)

    def statement_text_not_in(self, statements):
        """
        Return a new query that excludes the given ``text`` values.

        :param statements: Text values to append to the ``$nin`` list.
        """
        query = self._clone()

        if 'text' not in query:
            query['text'] = {}

        if '$nin' not in query['text']:
            query['text']['$nin'] = []

        query['text']['$nin'].extend(statements)

        return Query(query)

    def statement_response_list_contains(self, statement_text):
        """
        Return a new query with an ``$elemMatch`` condition on
        ``in_response_to`` for an element whose ``text`` equals
        ``statement_text``.
        """
        query = self._clone()

        if 'in_response_to' not in query:
            query['in_response_to'] = {}

        if '$elemMatch' not in query['in_response_to']:
            query['in_response_to']['$elemMatch'] = {}

        query['in_response_to']['$elemMatch']['text'] = statement_text

        return Query(query)

    def statement_response_list_equals(self, response_list):
        """
        Return a new query whose ``in_response_to`` field is set to
        ``response_list``.
        """
        query = self._clone()

        query['in_response_to'] = response_list

        return Query(query)
|
|
|
|
|
|
class MongoDatabaseAdapter(StorageAdapter):
    """
    The MongoDatabaseAdapter is an interface that allows
    ChatterBot to store statements in a MongoDB database.

    :keyword database: The name of the database you wish to connect to.
    :type database: str

    .. code-block:: python

       database='chatterbot-database'

    :keyword database_uri: The URI of a remote instance of MongoDB.
    :type database_uri: str

    .. code-block:: python

       database_uri='mongodb://example.com:8100/'
    """

    def __init__(self, **kwargs):
        super(MongoDatabaseAdapter, self).__init__(**kwargs)
        from pymongo import MongoClient
        from pymongo.errors import OperationFailure

        self.database_name = self.kwargs.get(
            'database', 'chatterbot-database'
        )
        self.database_uri = self.kwargs.get(
            'database_uri', 'mongodb://localhost:27017/'
        )

        # Connect using the configured URI (a local instance by default)
        self.client = MongoClient(self.database_uri)

        # Increase the sort buffer to 42M if possible; a server that
        # rejects the command is deliberately tolerated.
        try:
            self.client.admin.command({'setParameter': 1, 'internalQueryExecMaxBlockingSortBytes': 44040192})
        except OperationFailure:
            pass

        # Specify the name of the database
        self.database = self.client[self.database_name]

        # The mongo collection of statement documents
        self.statements = self.database['statements']

        # The mongo collection of conversation documents
        self.conversations = self.database['conversations']

        # Set a requirement for the text attribute to be unique
        self.statements.create_index('text', unique=True)

        self.base_query = Query()

    def get_statement_model(self):
        """
        Return the class for the statement model.
        """
        from chatter.chatterbot.conversation import Statement

        # Create a storage-aware statement. Note that this sets the
        # attribute on the imported class itself, not on a copy.
        statement = Statement
        statement.storage = self

        return statement

    def get_response_model(self):
        """
        Return the class for the response model.
        """
        from chatter.chatterbot.conversation import Response

        # Create a storage-aware response. Note that this sets the
        # attribute on the imported class itself, not on a copy.
        response = Response
        response.storage = self

        return response

    def count(self):
        """
        Return the number of documents in the statements collection.
        """
        # NOTE(review): Collection.count() is deprecated in newer PyMongo
        # releases (count_documents is the replacement) - confirm which
        # driver version this adapter is expected to run against.
        return self.statements.count()

    def find(self, statement_text):
        """
        Return the Statement whose text equals ``statement_text``,
        or None if no matching document exists.
        """
        Statement = self.get_model('statement')
        query = self.base_query.statement_text_equals(statement_text)

        values = self.statements.find_one(query.value())

        if not values:
            return None

        # The text is passed positionally below, so drop it from the kwargs
        del values['text']

        # Build the objects for the response list
        values['in_response_to'] = self.deserialize_responses(
            values.get('in_response_to', [])
        )

        return Statement(statement_text, **values)

    def deserialize_responses(self, response_list):
        """
        Takes the list of response items and returns
        the list converted to Response objects.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')
        proxy_statement = Statement('')

        for response in response_list:
            # Work on a copy so the caller's dictionaries are not modified
            response_data = dict(response)
            text = response_data.pop('text')

            proxy_statement.add_response(
                Response(text, **response_data)
            )

        return proxy_statement.in_response_to

    def mongo_to_object(self, statement_data):
        """
        Return Statement object when given data
        returned from Mongo DB.
        """
        Statement = self.get_model('statement')

        # Work on a copy so the caller's dictionary is not modified
        statement_data = dict(statement_data)
        statement_text = statement_data.pop('text')

        statement_data['in_response_to'] = self.deserialize_responses(
            statement_data.get('in_response_to', [])
        )

        return Statement(statement_text, **statement_data)

    def filter(self, **kwargs):
        """
        Returns a list of statements in the database
        that match the parameters specified.

        Recognised keywords are ``order_by``, ``in_response_to`` and
        ``in_response_to__contains``; any remaining keyword arguments
        are merged into the Mongo query unchanged.
        """
        import pymongo

        query = self.base_query

        order_by = kwargs.pop('order_by', None)

        # Convert the response texts into the stored document form
        if 'in_response_to' in kwargs:
            serialized_responses = [
                {'text': response} for response in kwargs['in_response_to']
            ]

            query = query.statement_response_list_equals(serialized_responses)
            del kwargs['in_response_to']

        if 'in_response_to__contains' in kwargs:
            query = query.statement_response_list_contains(
                kwargs['in_response_to__contains']
            )
            del kwargs['in_response_to__contains']

        query = query.raw(kwargs)

        matches = self.statements.find(query.value())

        if order_by:

            direction = pymongo.ASCENDING

            # Sort so that newer datetimes appear first
            if order_by == 'created_at':
                direction = pymongo.DESCENDING

            matches = matches.sort(order_by, direction)

        return [self.mongo_to_object(match) for match in matches]

    def update(self, statement):
        """
        Upsert the statement, plus one document per response it lists,
        in a single unordered bulk write. Returns the statement.

        Bulk write errors are logged rather than raised.
        """
        from pymongo import UpdateOne
        from pymongo.errors import BulkWriteError

        data = statement.serialize()

        operations = []

        update_operation = UpdateOne(
            {'text': statement.text},
            {'$set': data},
            upsert=True
        )
        operations.append(update_operation)

        # Make sure that an entry for each response is saved
        for response_dict in data.get('in_response_to', []):
            response_text = response_dict.get('text')

            # NOTE(review): an earlier comment here described $setOnInsert,
            # but the operation uses $set, so a document that already
            # exists is updated as well - confirm which is intended.
            update_operation = UpdateOne(
                {'text': response_text},
                {'$set': response_dict},
                upsert=True
            )
            operations.append(update_operation)

        try:
            self.statements.bulk_write(operations, ordered=False)
        except BulkWriteError as bwe:
            # Log the details of a bulk write error
            self.logger.error(str(bwe.details))

        return statement

    def create_conversation(self):
        """
        Create a new conversation and return its id.
        """
        conversation_id = self.conversations.insert_one({}).inserted_id
        return conversation_id

    def get_latest_response(self, conversation_id):
        """
        Returns the latest response in a conversation if it exists.
        Returns None if a matching conversation cannot be found, or if
        fewer than two statements are recorded for it.
        """
        from pymongo import DESCENDING

        statements = list(self.statements.find({
            'conversations.id': conversation_id
        }).sort('conversations.created_at', DESCENDING))

        # The lookup below takes the second-to-last entry, so at least
        # two documents are needed; indexing a shorter list would raise
        # an IndexError.
        if len(statements) < 2:
            return None

        return self.mongo_to_object(statements[-2])

    def add_to_conversation(self, conversation_id, statement, response):
        """
        Add the statement and response to the conversation.
        """
        from datetime import datetime, timedelta

        self.statements.update_one(
            {
                'text': statement.text
            },
            {
                '$push': {
                    'conversations': {
                        'id': conversation_id,
                        'created_at': datetime.utcnow()
                    }
                }
            }
        )
        self.statements.update_one(
            {
                'text': response.text
            },
            {
                '$push': {
                    'conversations': {
                        'id': conversation_id,
                        # Force the response to be at least one millisecond after the input statement
                        'created_at': datetime.utcnow() + timedelta(milliseconds=1)
                    }
                }
            }
        )

    def get_random(self):
        """
        Returns a random statement from the database

        Raises ``EmptyDatabaseException`` when there are no statements.
        """
        from random import randint

        count = self.count()

        if count < 1:
            raise self.EmptyDatabaseException()

        # Pick a random offset and fetch the single document found there
        random_integer = randint(0, count - 1)

        statements = self.statements.find().limit(1).skip(random_integer)

        return self.mongo_to_object(list(statements)[0])

    def remove(self, statement_text):
        """
        Removes the statement that matches the input text.
        Removes any responses from statements if the response text matches the
        input text.
        """
        for statement in self.filter(in_response_to__contains=statement_text):
            statement.remove_response(statement_text)
            self.update(statement)

        self.statements.delete_one({'text': statement_text})

    def get_response_statements(self):
        """
        Return only statements that are in response to another statement.
        A statement must exist which lists the closest matching statement in the
        in_response_to field. Otherwise, the logic adapter may find a closest
        matching statement that does not have a known response.
        """
        response_query = self.statements.aggregate([{'$group': {'_id': '$in_response_to.text'}}])

        responses = []
        for r in response_query:
            try:
                responses.extend(r['_id'])
            except TypeError:
                # Groups whose _id is not iterable contribute no texts
                pass

        _statement_query = {
            'text': {
                '$in': responses
            }
        }

        _statement_query.update(self.base_query.value())
        statement_query = self.statements.find(_statement_query)

        return [
            self.mongo_to_object(statement) for statement in statement_query
        ]

    def drop(self):
        """
        Remove the database.
        """
        self.client.drop_database(self.database_name)
|