You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Fox-V3/chatter/chatterbot/storage/sql_storage.py

404 lines
13 KiB

from . import StorageAdapter
def get_response_table(response):
    """
    Build a ``Response`` model row from a response-like object.

    :param response: Object exposing ``text`` and ``occurrence`` attributes.
    :returns: A new, unsaved ``Response`` model instance carrying the same
        ``text`` and ``occurrence`` values.
    """
    # The model is imported at call time rather than at module import time.
    from ..ext.sqlalchemy_app.models import Response as ResponseModel

    row = ResponseModel(
        text=response.text,
        occurrence=response.occurrence,
    )
    return row
class SQLStorageAdapter(StorageAdapter):
    """
    Storage adapter that keeps ChatterBot conversation data in a SQL
    database accessed through SQLAlchemy.

    Notes:
        Tables may change (and will), so save your training data.
        There is no data migration (yet).
        Performance test not done yet.
        Tests using other databases not finished.

    All parameters are optional; by default a sqlite database file
    (``db.sqlite3``) is used. On construction the adapter checks whether the
    ``Statement`` table is present and, if it is not, creates the tables.

    :keyword database: File name of a sqlite database. When a non-empty value
        is given it takes precedence over ``database_uri``. Passing ``None``
        explicitly selects an in-memory sqlite database (unless
        ``database_uri`` is given).
    :type database: str

    :keyword database_uri: e.g. ``"sqlite:///database_test.db"``. Selects the
        database driver and location. Used when ``database`` is not given.
    :type database_uri: str

    :keyword read_only: False by default. When true, sessions finished via
        ``_session_finish`` (``update``, ``remove``, ``add_to_conversation``)
        are rolled back instead of committed.
    :type read_only: bool
    """

    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy import event
        from sqlalchemy.orm import sessionmaker

        default_uri = "sqlite:///db.sqlite3"

        database_name = self.kwargs.get("database", False)

        # None results in a sqlite in-memory database as the default
        if database_name is None:
            default_uri = "sqlite://"

        self.database_uri = self.kwargs.get(
            "database_uri", default_uri
        )

        # Create a sqlite file if a database name is provided
        if database_name:
            self.database_uri = "sqlite:///" + database_name

        self.engine = create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith('sqlite://'):
            # Attach the listener to this adapter's engine only. Listening on
            # the Engine class would run the sqlite PRAGMAs on every engine
            # in the process and re-register the hook on each instantiation.
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute('PRAGMA journal_mode=WAL')
                dbapi_connection.execute('PRAGMA synchronous=NORMAL')

        self.read_only = self.kwargs.get(
            "read_only", False
        )

        if not self.engine.dialect.has_table(self.engine, 'Statement'):
            self.create()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

        # ChatterBot's internal query builder is not yet supported for this adapter
        self.adapter_supports_queries = False

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from ..ext.sqlalchemy_app.models import Statement
        return Statement

    def get_response_model(self):
        """
        Return the response model.
        """
        from ..ext.sqlalchemy_app.models import Response
        return Response

    def get_conversation_model(self):
        """
        Return the conversation model.
        """
        from ..ext.sqlalchemy_app.models import Conversation
        return Conversation

    def get_tag_model(self):
        """
        Return the tag model.
        """
        from ..ext.sqlalchemy_app.models import Tag
        return Tag

    def count(self):
        """
        Return the number of statement entries in the database.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            return session.query(Statement).count()
        finally:
            # Close the session even if the query raises
            session.close()

    def find(self, statement_text):
        """
        Return the statement whose text equals ``statement_text``,
        or None if no such statement exists.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            record = session.query(Statement).filter_by(
                text=statement_text
            ).first()

            if record:
                return record.get_statement()
            return None
        finally:
            session.close()

    def remove(self, statement_text):
        """
        Remove the statement that matches the input text.

        Does nothing if no statement matches. Whether related response rows
        are removed as well depends on the cascade rules declared on the
        models, which are not defined in this module.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        record = session.query(Statement).filter_by(
            text=statement_text
        ).first()

        # session.delete(None) raises, so only delete an existing record
        if record is not None:
            session.delete(record)

        self._session_finish(session, statement_text)

    def filter(self, **kwargs):
        """
        Return a list of statements from the database.

        With no keyword arguments every statement is returned.

        The keys ``in_response_to`` and ``in_response_to__contains`` query
        the statement table:

        * an empty list, or any non-list value for ``in_response_to``,
          selects statements that have no responses;
        * a non-empty list selects statements whose responses contain the
          given response;
        * a non-list value for ``in_response_to__contains`` selects
          statements having a response with exactly that text.

        Any other key is matched as a substring against the response's
        ``statement_text``. Only the query that exists after the last
        keyword argument has been processed is executed.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')

        session = self.Session()
        try:
            filter_parameters = kwargs.copy()

            statements = []
            _query = None

            if not filter_parameters:
                statements.extend(session.query(Statement).all())
            else:
                last_index = len(filter_parameters) - 1

                for i, fp in enumerate(filter_parameters):
                    _filter = filter_parameters[fp]

                    if fp in ('in_response_to', 'in_response_to__contains'):
                        _response_query = session.query(Statement)

                        if isinstance(_filter, list):
                            if not _filter:
                                # "== None" is required here: SQLAlchemy
                                # overloads "==" to build the SQL predicate,
                                # whereas "is None" is a plain Python identity
                                # test that always yields False.
                                _query = _response_query.filter(
                                    Statement.in_response_to == None  # noqa: E711
                                )
                            else:
                                # NOTE(review): each iteration replaces
                                # _query, so only the last response in the
                                # list is effectively matched. Kept as-is to
                                # preserve existing results.
                                for f in _filter:
                                    _query = _response_query.filter(
                                        Statement.in_response_to.contains(
                                            get_response_table(f)
                                        )
                                    )
                        else:
                            if fp == 'in_response_to__contains':
                                _query = _response_query.join(Response).filter(
                                    Response.text == _filter
                                )
                            else:
                                _query = _response_query.filter(
                                    Statement.in_response_to == None  # noqa: E711
                                )
                    else:
                        pattern = '%' + _filter + '%'
                        if _query is not None:
                            _query = _query.filter(
                                Response.statement_text.like(pattern)
                            )
                        else:
                            _query = session.query(Response).filter(
                                Response.statement_text.like(pattern)
                            )

                    if _query is None:
                        return []

                    if i == last_index:
                        statements.extend(_query.all())

            results = []

            for statement in statements:
                if isinstance(statement, Response):
                    # Response rows are mapped back to their owning statement
                    if statement and statement.statement_table:
                        results.append(statement.statement_table.get_statement())
                else:
                    if statement:
                        results.append(statement.get_statement())

            return results
        finally:
            # Runs on every exit path, including the early return above
            session.close()

    def update(self, statement):
        """
        Modify an entry in the database.
        Create an entry if one does not exist.

        For every response in ``statement.in_response_to`` the matching
        response record has its occurrence incremented, or a new record is
        created. Nothing happens when ``statement`` is falsy.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')
        Tag = self.get_model('tag')

        if not statement:
            return

        session = self.Session()

        record = session.query(Statement).filter_by(
            text=statement.text
        ).first()

        # Create a new statement entry if one does not already exist
        if not record:
            record = Statement(text=statement.text)

        record.extra_data = dict(statement.extra_data)

        for _tag in statement.tags:
            tag = session.query(Tag).filter_by(name=_tag).first()

            if not tag:
                # Create the record
                tag = Tag(name=_tag)

            record.tags.append(tag)

        # Get or create the response records as needed
        for response in statement.in_response_to:
            _response = session.query(Response).filter_by(
                text=response.text,
                statement_text=statement.text
            ).first()

            if _response:
                _response.occurrence += 1
            else:
                # Create the record
                _response = Response(
                    text=response.text,
                    statement_text=statement.text,
                    occurrence=response.occurrence
                )

            record.in_response_to.append(_response)

        session.add(record)

        # Pass the text along so a failed commit is logged with context
        self._session_finish(session, statement.text)

    def create_conversation(self):
        """
        Create a new conversation and return its id.
        """
        Conversation = self.get_model('conversation')

        session = self.Session()
        try:
            conversation = Conversation()

            session.add(conversation)
            session.flush()

            # Read the generated id before the commit
            session.refresh(conversation)
            conversation_id = conversation.id

            session.commit()
            return conversation_id
        finally:
            session.close()

    def add_to_conversation(self, conversation_id, statement, response):
        """
        Add the statement and response to the conversation.

        Statements that are not stored yet are first saved via ``update``.
        """
        Statement = self.get_model('statement')
        Conversation = self.get_model('conversation')

        session = self.Session()
        conversation = session.query(Conversation).get(conversation_id)

        def lookup(text):
            # Fetch the stored statement row with the given text, if any
            return session.query(Statement).filter_by(text=text).first()

        statement_query = lookup(statement.text)
        response_query = lookup(response.text)

        # Make sure the statements exist
        if not statement_query:
            self.update(statement)
            statement_query = lookup(statement.text)

        if not response_query:
            self.update(response)
            response_query = lookup(response.text)

        conversation.statements.append(statement_query)
        conversation.statements.append(response_query)

        session.add(conversation)
        self._session_finish(session)

    def get_latest_response(self, conversation_id):
        """
        Return the second-to-last statement (ordered by statement id) of the
        conversation, or its only statement when there is exactly one.

        Return None if no statement belongs to the conversation.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            # Load the ordered rows once instead of issuing separate COUNT
            # queries followed by a negative index on the query.
            matches = session.query(Statement).filter(
                Statement.conversations.any(id=conversation_id)
            ).order_by(Statement.id).all()

            if len(matches) >= 2:
                return matches[-2].get_statement()

            # Handle the case of the first statement in the list
            if matches:
                return matches[0].get_statement()

            return None
        finally:
            session.close()

    def get_random(self):
        """
        Return a random statement from the database.

        :raises EmptyDatabaseException: if there are no statements stored.
        """
        import random

        Statement = self.get_model('statement')

        # Check for an empty database before opening a session, so that
        # raising does not leave a session open.
        count = self.count()
        if count < 1:
            raise self.EmptyDatabaseException()

        session = self.Session()
        try:
            rand = random.randrange(0, count)
            stmt = session.query(Statement)[rand]
            return stmt.get_statement()
        finally:
            session.close()

    def drop(self):
        """
        Drop all tables of the model metadata on this adapter's engine.
        """
        from ..ext.sqlalchemy_app.models import Base
        Base.metadata.drop_all(self.engine)

    def create(self):
        """
        Populate the database with the tables.
        """
        from ..ext.sqlalchemy_app.models import Base
        Base.metadata.create_all(self.engine)

    def _session_finish(self, session, statement_text=None):
        """
        Commit the session (or roll it back when the adapter is read-only)
        and close it.

        :param session: The session to finish.
        :param statement_text: Text logged alongside the exception when the
            commit or rollback raises ``InvalidRequestError``.
        """
        from sqlalchemy.exc import InvalidRequestError
        try:
            if not self.read_only:
                session.commit()
            else:
                session.rollback()
        except InvalidRequestError:
            # Log the statement text and the exception
            self.logger.exception(statement_text)
        finally:
            session.close()