Add start of AsyncSQLStorageAdapterpull/175/head
parent
eac7aee82c
commit
04ccb435f8
@ -0,0 +1,73 @@
|
||||
from chatterbot.storage import StorageAdapter, SQLStorageAdapter
|
||||
|
||||
|
||||
class MyDumbSQLStorageAdapter(SQLStorageAdapter):
|
||||
def __init__(self, **kwargs):
|
||||
super(SQLStorageAdapter, self).__init__(**kwargs)
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
self.database_uri = kwargs.get("database_uri", False)
|
||||
|
||||
# None results in a sqlite in-memory database as the default
|
||||
if self.database_uri is None:
|
||||
self.database_uri = "sqlite://"
|
||||
|
||||
# Create a file database if the database is not a connection string
|
||||
if not self.database_uri:
|
||||
self.database_uri = "sqlite:///db.sqlite3"
|
||||
|
||||
self.engine = create_engine(
|
||||
self.database_uri, convert_unicode=True, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
if self.database_uri.startswith("sqlite://"):
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import event
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
dbapi_connection.execute("PRAGMA journal_mode=WAL")
|
||||
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
|
||||
|
||||
if not self.engine.dialect.has_table(self.engine, "Statement"):
|
||||
self.create_database()
|
||||
|
||||
self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
|
||||
|
||||
|
||||
class AsyncSQLStorageAdapter(SQLStorageAdapter):
|
||||
def __init__(self, **kwargs):
|
||||
super(SQLStorageAdapter, self).__init__(**kwargs)
|
||||
|
||||
self.database_uri = kwargs.get("database_uri", False)
|
||||
|
||||
# None results in a sqlite in-memory database as the default
|
||||
if self.database_uri is None:
|
||||
self.database_uri = "sqlite://"
|
||||
|
||||
# Create a file database if the database is not a connection string
|
||||
if not self.database_uri:
|
||||
self.database_uri = "sqlite:///db.sqlite3"
|
||||
|
||||
async def initialize(self):
|
||||
# from sqlalchemy import create_engine
|
||||
from aiomysql.sa import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
self.engine = await create_engine(self.database_uri, convert_unicode=True)
|
||||
|
||||
if self.database_uri.startswith("sqlite://"):
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import event
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
dbapi_connection.execute("PRAGMA journal_mode=WAL")
|
||||
dbapi_connection.execute("PRAGMA synchronous=NORMAL")
|
||||
|
||||
if not self.engine.dialect.has_table(self.engine, "Statement"):
|
||||
self.create_database()
|
||||
|
||||
self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)
|
Loading…
Reference in new issue