from . import StorageAdapter


def get_response_table(response):
    """
    Build a ``Response`` model instance from a response object.

    Only the ``text`` and ``occurrence`` attributes of ``response`` are
    copied onto the new model instance.
    """
    from ..ext.sqlalchemy_app.models import Response
    return Response(text=response.text, occurrence=response.occurrence)


class SQLStorageAdapter(StorageAdapter):
    """
    SQLStorageAdapter allows ChatterBot to store conversation data in a
    relational database; virtually any database that SQLAlchemy supports
    can be used.

    Notes:
        Tables may change (and will), so, save your training data.
        There is no data migration (yet).
        Performance test not done yet.
        Tests using other databases not finished.

    All parameters are optional, by default a sqlite database
    (``sqlite:///db.sqlite3``) is used.

    It will check if tables are present, if they are not, it will attempt
    to create the required tables.

    :keyword database: File name of a sqlite database. When a non-empty
        value is given it takes precedence and ``database_uri`` is
        replaced by ``"sqlite:///" + database``. Passing ``None``
        explicitly (with no ``database_uri``) selects an in-memory sqlite
        database.
    :type database: str

    :keyword database_uri: eg: ``sqlite:///database_test.db``. Selects the
        database driver and location. Only used when ``database`` is not
        provided (or is empty).
    :type database_uri: str

    :keyword read_only: False by default. When True, the write operations
        that finish through ``_session_finish`` (``remove``, ``update``,
        ``add_to_conversation``) are rolled back instead of committed.
    :type read_only: bool
    """

    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        default_uri = "sqlite:///db.sqlite3"

        database_name = self.kwargs.get("database", False)

        # None results in a sqlite in-memory database as the default
        if database_name is None:
            default_uri = "sqlite://"

        self.database_uri = self.kwargs.get(
            "database_uri", default_uri
        )

        # Create a sqlite file if a database name is provided
        if database_name:
            self.database_uri = "sqlite:///" + database_name

        self.engine = create_engine(self.database_uri, convert_unicode=True)

        if self.database_uri.startswith('sqlite://'):
            from sqlalchemy import event

            # Register the listener on this adapter's engine only. Listening
            # on the Engine class would run these sqlite PRAGMAs on every
            # engine in the process and add one more listener per adapter.
            @event.listens_for(self.engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                dbapi_connection.execute('PRAGMA journal_mode=WAL')
                dbapi_connection.execute('PRAGMA synchronous=NORMAL')

        self.read_only = self.kwargs.get(
            "read_only", False
        )

        if not self.engine.dialect.has_table(self.engine, 'Statement'):
            self.create()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

        # ChatterBot's internal query builder is not yet supported for this adapter
        self.adapter_supports_queries = False

    def get_statement_model(self):
        """
        Return the statement model.
        """
        from ..ext.sqlalchemy_app.models import Statement
        return Statement

    def get_response_model(self):
        """
        Return the response model.
        """
        from ..ext.sqlalchemy_app.models import Response
        return Response

    def get_conversation_model(self):
        """
        Return the conversation model.
        """
        from ..ext.sqlalchemy_app.models import Conversation
        return Conversation

    def get_tag_model(self):
        """
        Return the tag model.
        """
        from ..ext.sqlalchemy_app.models import Tag
        return Tag

    def count(self):
        """
        Return the number of entries in the database.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            return session.query(Statement).count()
        finally:
            session.close()

    def find(self, statement_text):
        """
        Returns a statement if it exists otherwise None
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            record = session.query(Statement).filter_by(
                text=statement_text
            ).first()

            # Convert while the session is still open
            if record:
                return record.get_statement()
            return None
        finally:
            session.close()

    def remove(self, statement_text):
        """
        Removes the statement that matches the input text.

        Does nothing when no statement with that text exists. The deletion
        is committed or rolled back by ``_session_finish`` depending on
        ``read_only``.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        record = session.query(Statement).filter_by(
            text=statement_text
        ).first()

        if record is None:
            # session.delete(None) would raise and leave the session open
            session.close()
            return

        session.delete(record)

        self._session_finish(session, statement_text)

    def filter(self, **kwargs):
        """
        Returns a list of objects from the database.
        The kwargs parameter can contain any number
        of attributes. Only objects which contain
        all listed attributes and in which all values
        match for all listed attributes will be returned.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')

        filter_parameters = kwargs.copy()

        session = self.Session()
        try:
            statements = []
            _query = None

            if not filter_parameters:
                statements.extend(session.query(Statement).all())
            else:
                last_index = len(filter_parameters) - 1

                for i, fp in enumerate(filter_parameters):
                    _filter = filter_parameters[fp]

                    if fp in ['in_response_to', 'in_response_to__contains']:
                        _response_query = session.query(Statement)

                        if isinstance(_filter, list):
                            if len(_filter) == 0:
                                _query = _response_query.filter(
                                    Statement.in_response_to == None  # NOQA Here must use == instead of is
                                )
                            else:
                                # NOTE(review): every iteration starts again
                                # from _response_query, so only the last
                                # element of the list ends up in the query.
                                # Kept as-is; confirm whether the filters
                                # were meant to be chained.
                                for f in _filter:
                                    _query = _response_query.filter(
                                        Statement.in_response_to.contains(
                                            get_response_table(f)
                                        )
                                    )
                        else:
                            if fp == 'in_response_to__contains':
                                _query = _response_query.join(Response).filter(
                                    Response.text == _filter
                                )
                            else:
                                _query = _response_query.filter(
                                    Statement.in_response_to == None  # NOQA
                                )
                    else:
                        pattern = '%' + _filter + '%'
                        if _query is not None:
                            _query = _query.filter(
                                Response.statement_text.like(pattern)
                            )
                        else:
                            _query = session.query(Response).filter(
                                Response.statement_text.like(pattern)
                            )

                    if _query is None:
                        return []

                    # Only the query built after the last parameter is run
                    if i == last_index:
                        statements.extend(_query.all())

            results = []

            # Convert rows to statement objects while the session is open
            for statement in statements:
                if isinstance(statement, Response):
                    if statement and statement.statement_table:
                        results.append(
                            statement.statement_table.get_statement()
                        )
                else:
                    if statement:
                        results.append(statement.get_statement())

            return results
        finally:
            # Runs on every exit path, including the early return above
            session.close()

    def update(self, statement):
        """
        Modifies an entry in the database.
        Creates an entry if one does not exist.

        Nothing is done when ``statement`` is falsy. The change is committed
        or rolled back by ``_session_finish`` depending on ``read_only``.
        """
        Statement = self.get_model('statement')
        Response = self.get_model('response')
        Tag = self.get_model('tag')

        if not statement:
            return

        session = self.Session()

        record = session.query(Statement).filter_by(
            text=statement.text
        ).first()

        # Create a new statement entry if one does not already exist
        if not record:
            record = Statement(text=statement.text)

        record.extra_data = dict(statement.extra_data)

        for _tag in statement.tags:
            tag = session.query(Tag).filter_by(name=_tag).first()

            if not tag:
                # Create the record
                tag = Tag(name=_tag)

            record.tags.append(tag)

        # Get or create the response records as needed
        for response in statement.in_response_to:
            _response = session.query(Response).filter_by(
                text=response.text,
                statement_text=statement.text
            ).first()

            if _response:
                _response.occurrence += 1
            else:
                # Create the record
                _response = Response(
                    text=response.text,
                    statement_text=statement.text,
                    occurrence=response.occurrence
                )

            record.in_response_to.append(_response)

        session.add(record)

        self._session_finish(session, statement.text)

    def create_conversation(self):
        """
        Create a new conversation and return its id.
        """
        Conversation = self.get_model('conversation')

        session = self.Session()
        try:
            conversation = Conversation()

            session.add(conversation)
            session.flush()

            session.refresh(conversation)
            conversation_id = conversation.id

            # NOTE(review): this commits regardless of self.read_only,
            # unlike the operations that go through _session_finish.
            session.commit()

            return conversation_id
        finally:
            session.close()

    def add_to_conversation(self, conversation_id, statement, response):
        """
        Add the statement and response to the conversation.

        Statements that are not stored yet are first saved through
        ``update`` and then looked up again.
        """
        Statement = self.get_model('statement')
        Conversation = self.get_model('conversation')

        session = self.Session()
        conversation = session.query(Conversation).get(conversation_id)

        statement_query = session.query(Statement).filter_by(
            text=statement.text
        ).first()
        response_query = session.query(Statement).filter_by(
            text=response.text
        ).first()

        # Make sure the statements exist
        if not statement_query:
            self.update(statement)
            statement_query = session.query(Statement).filter_by(
                text=statement.text
            ).first()

        if not response_query:
            self.update(response)
            response_query = session.query(Statement).filter_by(
                text=response.text
            ).first()

        conversation.statements.append(statement_query)
        conversation.statements.append(response_query)

        session.add(conversation)

        self._session_finish(session)

    def get_latest_response(self, conversation_id):
        """
        Returns the latest response in a conversation if it exists.
        Returns None if a matching conversation cannot be found.

        With two or more statements in the conversation, the one with the
        second-highest id is returned; with exactly one, that statement is
        returned.
        """
        Statement = self.get_model('statement')

        session = self.Session()
        try:
            # Fetch at most the two newest rows in a single query instead of
            # counting and then indexing the full result set from the end.
            newest_first = session.query(Statement).filter(
                Statement.conversations.any(id=conversation_id)
            ).order_by(Statement.id.desc()).limit(2).all()

            if not newest_first:
                return None

            # Two rows: last item is the second newest.
            # One row: last item is the only (first) statement.
            return newest_first[-1].get_statement()
        finally:
            session.close()

    def get_random(self):
        """
        Returns a random statement from the database

        Raises ``EmptyDatabaseException`` when there are no statements.
        """
        import random

        Statement = self.get_model('statement')

        # Check for an empty database before opening a session so nothing
        # is left open when the exception is raised.
        count = self.count()
        if count < 1:
            raise self.EmptyDatabaseException()

        session = self.Session()
        try:
            rand = random.randrange(0, count)
            stmt = session.query(Statement)[rand]

            return stmt.get_statement()
        finally:
            session.close()

    def drop(self):
        """
        Drop the database attached to a given adapter.
        """
        from ..ext.sqlalchemy_app.models import Base
        Base.metadata.drop_all(self.engine)

    def create(self):
        """
        Populate the database with the tables.
        """
        from ..ext.sqlalchemy_app.models import Base
        Base.metadata.create_all(self.engine)

    def _session_finish(self, session, statement_text=None):
        """
        Commit (or roll back, when read only) and always close the session.

        An ``InvalidRequestError`` raised while finishing is logged together
        with ``statement_text`` and is not re-raised.
        """
        from sqlalchemy.exc import InvalidRequestError
        try:
            if not self.read_only:
                session.commit()
            else:
                session.rollback()
        except InvalidRequestError:
            # Log the statement text and the exception
            self.logger.exception(statement_text)
        finally:
            session.close()