You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
404 lines
13 KiB
404 lines
13 KiB
from chatter.chatterbot.storage import StorageAdapter
|
|
|
|
|
|
def get_response_table(response):
    """
    Build a ``Response`` model instance from a response-like object.

    :param response: Object exposing ``text`` and ``occurrence`` attributes.
    :returns: A new ``Response`` model carrying the same text and occurrence.
    """
    # The model is imported at call time rather than at module import time.
    from chatter.chatterbot.ext.sqlalchemy_app.models import Response as ResponseModel

    return ResponseModel(
        occurrence=response.occurrence,
        text=response.text,
    )
|
|
|
|
|
|
class SQLStorageAdapter(StorageAdapter):
    """
    SQLStorageAdapter allows ChatterBot to store conversation data in a
    semi-structured SQL database: virtually any database that
    SQLAlchemy supports.

    Notes:
        Tables may change (and will), so, save your training data.
        There is no data migration (yet).
        Performance test not done yet.
        Tests using other databases not finished.

    All parameters are optional, by default a sqlite database is used.

    It will check if tables are present, if they are not, it will attempt
    to create the required tables.

    :keyword database: File name of a sqlite database. Passing ``None``
        selects an in-memory sqlite database as the default. A non-empty
        value takes precedence over ``database_uri``.
    :type database: str

    :keyword database_uri: eg: "sqlite:///database_test.db". Selects the
        database driver; it is overridden when a non-empty ``database``
        is also supplied.
    :type database_uri: str

    :keyword read_only: False by default. When true, work finished through
        ``_session_finish`` (``update``, ``remove``, ``add_to_conversation``)
        is rolled back instead of committed.
    :type read_only: bool
    """

    def __init__(self, **kwargs):
        """
        Create the engine, ensure the tables exist and prepare a session factory.

        See the class docstring for the recognised keyword arguments.
        """
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        default_uri = "sqlite:///db.sqlite3"

        database_name = self.kwargs.get("database", False)

        # None results in a sqlite in-memory database as the default
        if database_name is None:
            default_uri = "sqlite://"

        self.database_uri = self.kwargs.get(
            "database_uri", default_uri
        )

        # Create a sqlite file if a database name is provided
        if database_name:
            self.database_uri = "sqlite:///" + database_name

        self.engine = create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith('sqlite://'):
            from sqlalchemy import event

            # Bind the listener to this adapter's engine only. Listening on
            # the Engine class would apply the pragmas to every engine in the
            # process and register one more listener per adapter instance.
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute('PRAGMA journal_mode=WAL')
                dbapi_connection.execute('PRAGMA synchronous=NORMAL')

        self.read_only = self.kwargs.get(
            "read_only", False
        )

        if not self.engine.dialect.has_table(self.engine, 'Statement'):
            self.create()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

        # ChatterBot's internal query builder is not yet supported for this adapter
        self.adapter_supports_queries = False

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Statement
        return Statement

    def get_response_model(self):
        """
        Return the response model.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Response
        return Response

    def get_conversation_model(self):
        """
        Return the conversation model.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Conversation
        return Conversation

    def get_tag_model(self):
        """
        Return the tag model.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Tag
        return Tag

    def count(self):
        """
        Return the number of statement entries in the database.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            return session.query(Statement).count()
        finally:
            # Release the session even when the query raises.
            session.close()

    def find(self, statement_text):
        """
        Return the statement whose text equals ``statement_text``.

        :param statement_text: Exact text to look up.
        :returns: The result of ``get_statement()`` on the first matching
            record, or ``None`` when no record matches.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            record = session.query(Statement).filter_by(text=statement_text).first()
            if record:
                return record.get_statement()
            return None
        finally:
            session.close()

    def remove(self, statement_text):
        """
        Remove the statement that matches the input text.

        Does nothing (apart from finishing the session) when no statement
        matches. Whether related response rows are removed too depends on the
        model's relationship configuration, which is not defined here.

        :param statement_text: Exact text of the statement to delete.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            record = session.query(Statement).filter_by(text=statement_text).first()

            # session.delete(None) raises, so only delete an existing row.
            if record is not None:
                session.delete(record)
        except Exception:
            session.close()
            raise

        self._session_finish(session, statement_text)

    def filter(self, **kwargs):
        """
        Return a list of statements from the database.

        With no keyword arguments every statement is returned. Otherwise the
        keyword arguments are processed in order:

        * ``in_response_to`` / ``in_response_to__contains`` start a new query
          on the statement table (see the inline comments for each case);
        * any other key applies a ``LIKE '%value%'`` filter on
          ``Response.statement_text``.

        Only the query state reached at the last keyword is executed.

        :returns: A list of ``get_statement()`` results.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')

        session = self.Session()
        try:
            filter_parameters = kwargs.copy()

            statements = []
            _query = None

            if len(filter_parameters) == 0:
                statements.extend(session.query(Statement).all())
            else:
                for i, fp in enumerate(filter_parameters):
                    _filter = filter_parameters[fp]
                    if fp in ('in_response_to', 'in_response_to__contains'):
                        _response_query = session.query(Statement)
                        if isinstance(_filter, list):
                            if len(_filter) == 0:
                                # An empty list selects statements with no responses.
                                _query = _response_query.filter(
                                    Statement.in_response_to == None  # NOQA Here must use == instead of is
                                )
                            else:
                                # NOTE(review): each pass replaces _query, so
                                # only the last list element is filtered on.
                                # Confirm whether chaining was intended.
                                for f in _filter:
                                    _query = _response_query.filter(
                                        Statement.in_response_to.contains(get_response_table(f)))
                        else:
                            if fp == 'in_response_to__contains':
                                _query = _response_query.join(Response).filter(Response.text == _filter)
                            else:
                                _query = _response_query.filter(Statement.in_response_to == None)  # NOQA
                    else:
                        if _query:
                            _query = _query.filter(Response.statement_text.like('%' + _filter + '%'))
                        else:
                            _response_query = session.query(Response)
                            _query = _response_query.filter(Response.statement_text.like('%' + _filter + '%'))

                    if _query is None:
                        return []
                    # Execute only once the final keyword has been applied.
                    if len(filter_parameters) == i + 1:
                        statements.extend(_query.all())

            results = []

            for statement in statements:
                if isinstance(statement, Response):
                    # Rows from the response table are mapped back to the
                    # statement they belong to.
                    if statement and statement.statement_table:
                        results.append(statement.statement_table.get_statement())
                else:
                    if statement:
                        results.append(statement.get_statement())

            return results
        finally:
            # Covers the early return and any exception raised by a query.
            session.close()

    def update(self, statement):
        """
        Modify an entry in the database, creating it if it does not exist.

        The record's ``extra_data`` is replaced, each tag in
        ``statement.tags`` is fetched or created and appended, and each entry
        of ``statement.in_response_to`` either has its stored occurrence
        incremented or is created. A falsy ``statement`` is ignored.

        :param statement: Object exposing ``text``, ``extra_data``, ``tags``
            and ``in_response_to``.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')
        Tag = self.get_model('tag')

        if not statement:
            return

        session = self.Session()
        try:
            record = session.query(Statement).filter_by(text=statement.text).first()

            # Create a new statement entry if one does not already exist
            if not record:
                record = Statement(text=statement.text)

            record.extra_data = dict(statement.extra_data)

            for _tag in statement.tags:
                tag = session.query(Tag).filter_by(name=_tag).first()

                if not tag:
                    # Create the record
                    tag = Tag(name=_tag)

                record.tags.append(tag)

            # Get or create the response records as needed
            for response in statement.in_response_to:
                _response = session.query(Response).filter_by(
                    text=response.text,
                    statement_text=statement.text
                ).first()

                if _response:
                    _response.occurrence += 1
                else:
                    # Create the record
                    _response = Response(
                        text=response.text,
                        statement_text=statement.text,
                        occurrence=response.occurrence
                    )

                record.in_response_to.append(_response)

            session.add(record)
        except Exception:
            # Do not leak the session when building the record fails.
            session.close()
            raise

        self._session_finish(session, statement.text)

    def create_conversation(self):
        """
        Create a new conversation and return its id.

        NOTE(review): this commits directly and therefore does not honour
        ``read_only``; confirm whether that is intended.
        """
        Conversation = self.get_model('conversation')

        session = self.Session()
        try:
            conversation = Conversation()

            session.add(conversation)
            session.flush()

            # Read the id before commit, while the instance is still loaded.
            session.refresh(conversation)
            conversation_id = conversation.id

            session.commit()
        finally:
            session.close()

        return conversation_id

    def add_to_conversation(self, conversation_id, statement, response):
        """
        Add the statement and response to the conversation.

        Statements that are not stored yet are first written via ``update``
        and then looked up again.

        :param conversation_id: Primary key of the conversation.
        :param statement: Object exposing ``text`` (input statement).
        :param response: Object exposing ``text`` (response statement).
        """
        Statement = self.get_model('statement')
        Conversation = self.get_model('conversation')

        session = self.Session()
        try:
            conversation = session.query(Conversation).get(conversation_id)

            statement_query = session.query(Statement).filter_by(
                text=statement.text
            ).first()
            response_query = session.query(Statement).filter_by(
                text=response.text
            ).first()

            # Make sure the statements exist
            if not statement_query:
                self.update(statement)
                statement_query = session.query(Statement).filter_by(
                    text=statement.text
                ).first()

            if not response_query:
                self.update(response)
                response_query = session.query(Statement).filter_by(
                    text=response.text
                ).first()

            conversation.statements.append(statement_query)
            conversation.statements.append(response_query)

            session.add(conversation)
        except Exception:
            # Do not leak the session when a lookup or append fails.
            session.close()
            raise

        self._session_finish(session)

    def get_latest_response(self, conversation_id):
        """
        Return a recent statement of a conversation, if one exists.

        Statements linked to the conversation are ordered by ``Statement.id``.
        With two or more, the second-to-last is returned; with exactly one,
        that one is returned; otherwise ``None``.

        :param conversation_id: Primary key of the conversation.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            # Load the rows once instead of issuing count() queries and
            # negative-indexing the Query object.
            ordered = session.query(Statement).filter(
                Statement.conversations.any(id=conversation_id)
            ).order_by(Statement.id).all()

            if len(ordered) >= 2:
                return ordered[-2].get_statement()

            # Handle the case of the first statement in the list
            if ordered:
                return ordered[0].get_statement()

            return None
        finally:
            session.close()

    def get_random(self):
        """
        Return a random statement from the database.

        :raises EmptyDatabaseException: When there are no statements.
        """
        import random

        Statement = self.get_model('statement')

        # Check for emptiness before opening a session so nothing is left
        # open when the exception is raised.
        count = self.count()
        if count < 1:
            raise self.EmptyDatabaseException()

        session = self.Session()
        try:
            rand = random.randrange(0, count)
            return session.query(Statement)[rand].get_statement()
        finally:
            session.close()

    def drop(self):
        """
        Drop the tables declared on the models' metadata from this adapter's engine.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Base
        Base.metadata.drop_all(self.engine)

    def create(self):
        """
        Populate the database with the tables.
        """
        from chatter.chatterbot.ext.sqlalchemy_app.models import Base
        Base.metadata.create_all(self.engine)

    def _session_finish(self, session, statement_text=None):
        """
        Commit (or roll back when ``read_only``) and always close the session.

        :param session: The session to finish.
        :param statement_text: Text logged if the commit or rollback raises
            ``InvalidRequestError``.
        """
        from sqlalchemy.exc import InvalidRequestError
        try:
            if not self.read_only:
                session.commit()
            else:
                session.rollback()
        except InvalidRequestError:
            # Log the statement text and the exception
            self.logger.exception(statement_text)
        finally:
            session.close()
|