Almost to adding triggers

pull/132/head
bobloy 4 years ago
parent 1a5aaff268
commit c6a9116a92

@ -3,5 +3,4 @@ from .fifo import FIFO
async def setup(bot):
cog = FIFO(bot)
await cog.load_tasks()
bot.add_cog(cog)

@ -0,0 +1,16 @@
from datetime import datetime
from typing import TYPE_CHECKING
from discord.ext.commands import BadArgument, Converter
from dateutil import parser
if TYPE_CHECKING:
DatetimeConverter = datetime
else:
class DatetimeConverter(Converter):
async def convert(self, ctx, argument) -> datetime:
dt = parser.parse(argument)
if dt is not None:
return dt
raise BadArgument()

@ -1,34 +1,34 @@
from datetime import datetime, timedelta
from typing import Dict, Union
from apscheduler.executors.asyncio import AsyncIOExecutor
from apscheduler.jobstores.memory import MemoryJobStore
import discord
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.base import BaseTrigger
from apscheduler.triggers.combining import AndTrigger, OrTrigger
from apscheduler.triggers.combining import OrTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from dateutil import parser
from redbot.core import Config, checks, commands
from redbot.core.bot import Red
from apscheduler.schedulers.asyncio import AsyncIOScheduler
import discord
import asyncio
import datetime
from redbot.core.commands import DictConverter, TimedeltaConverter, parse_timedelta
from redbot.core.utils import AsyncIter
from .datetimeconverter import DatetimeConverter
from .redconfigjobstore import RedConfigJobStore
def get_trigger(data):
if data["type"] == "interval":
parsed_time = parse_timedelta(data["timedelta_str"])
parsed_time = data["time_data"]
return IntervalTrigger(days=parsed_time.days, seconds=parsed_time.seconds)
if data["type"] == "date":
return DateTrigger(parser.parse(data["strtime"]))
return DateTrigger(data["time_data"])
if data["type"] == "cron":
return None # TODO: Cron parsing
return False
def parse_triggers(data: Union[Dict, None]):
if data is None or not data.get("triggers", False): # No triggers
@ -40,33 +40,95 @@ def parse_triggers(data: Union[Dict, None]):
return get_trigger(data[0])
class Task:
class FakeMessage:
_state = None
# class FakeMessage(discord.Message):
# def __init__(self):
# super().__init__(state=None, channel=None, data=None)
class Task:
default_task_data = {"triggers": [], "command_str": ""}
default_trigger = {
"type": "",
"timedelta_str": "",
"time_data": None, # Used for Interval and Date Triggers
}
def __init__(self, name: str, guild_id, config: Config):
def __init__(self, name: str, guild_id, config: Config, author_id=None, bot: Red = None):
self.name = name
self.guild_id = guild_id
self.config = config
self.bot = bot
self.author_id = author_id
self.data = None
async def load_from_data(self, data: Dict):
self.data = data.copy()
async def _encode_time_data(self):
if not self.data or not self.data.get("triggers", None):
return None
triggers = []
for t in self.data["triggers"]:
if t["type"] == "interval": # Convert into timedelta
td: timedelta = t["time_data"]
triggers.append({"type": t["type"], "time_data": {"days": td.days, "seconds": td.seconds} })
if t["type"] == "date": # Convert into datetime
dt: datetime = t["time_data"]
triggers.append({"type": t["type"], "time_data": {
"year": dt.year,
"month": dt.month,
"day": dt.day,
"hour": dt.hour,
"minute": dt.minute,
"second": dt.second,
}})
if t["type"] == "cron":
raise NotImplemented
raise NotImplemented
return triggers
async def _decode_time_data(self):
if not self.data or not self.data.get("triggers", None):
return
for t in self.data["triggers"]:
if t["type"] == "interval": # Convert into timedelta
t["time_data"] = timedelta(**t["time_data"])
if t["type"] == "date": # Convert into datetime
t["time_data"] = datetime(**t["time_data"])
if t["type"] == "cron":
raise NotImplemented
raise NotImplemented
# async def load_from_data(self, data: Dict):
# self.data = data.copy()
async def load_from_config(self):
self.data = await self.config.guild_from_id(self.guild_id).tasks.get_raw(
data = await self.config.guild_from_id(self.guild_id).tasks.get_raw(
self.name, default=None
)
if not data:
return
self.author_id = data["author_id"]
self.guild_id = data["guild_id"]
self.data = data["data"]
await self._decode_time_data()
return self.data
async def get_trigger(self) -> Union[BaseTrigger, None]:
if self.data is None:
if not self.data:
await self.load_from_config()
return parse_triggers(self.data)
@ -77,14 +139,69 @@ class Task:
#
# self.data["job_id"] = job_id
async def save_all(self):
"""To be used when creating an new task"""
data_to_save = self.default_task_data.copy()
if self.data:
data_to_save["command_str"] = self.data.get("command_str", "")
data_to_save["triggers"] = await self._encode_time_data()
to_save = {
"guild_id": self.guild_id,
"author_id": self.author_id,
"data": data_to_save,
}
await self.config.guild_from_id(self.guild_id).tasks.set_raw(self.name, value=to_save)
async def save_data(self):
await self.config.guild_from_id(self.guild_id).tasks.set_raw(self.name, value=self.data)
"""To be used when updating triggers"""
if not self.data:
return
await self.config.guild_from_id(self.guild_id).tasks.set_raw(
self.name, "data", value=await self._encode_time_data()
)
async def execute(self):
pass # TODO: something something invoke command
if not self.data or self.data["command_str"]:
return False
message = FakeMessage()
message.guild = self.bot.get_guild(self.guild_id) # used for get_prefix
message.author = message.guild.get_member(self.author_id)
message.content = await self.bot.get_prefix(message) + self.data["command_str"]
async def add_trigger(self, param, parsed_time):
pass
if not message.guild or not message.author or not message.content:
return False
new_ctx: commands.Context = await self.bot.get_context(message)
if not new_ctx.valid:
return False
await self.bot.invoke(new_ctx)
return True
async def set_bot(self, bot: Red):
self.bot = bot
async def set_author(self, author: Union[discord.User, str]):
self.author_id = getattr(author, "id", None) or author
async def set_commmand_str(self, command_str):
if not self.data:
self.data = self.default_task_data.copy()
self.data["command_str"] = command_str
return True
async def add_trigger(self, param, parsed_time: Union[timedelta, datetime]):
trigger_data = {"type": param, "time_data": parsed_time}
if not get_trigger(trigger_data):
return False
if not self.data:
self.data = self.default_task_data.copy()
self.data["triggers"].append(trigger_data)
return True
class FIFO(commands.Cog):
@ -105,23 +222,50 @@ class FIFO(commands.Cog):
self.config.register_global(**default_global)
self.config.register_guild(**default_guild)
jobstores = {"default": MemoryJobStore()}
jobstores = {"default": RedConfigJobStore(self.config, self.bot)}
job_defaults = {"coalesce": False, "max_instances": 1}
# executors = {"default": AsyncIOExecutor()}
# Default executor is already AsyncIOExecutor
self.scheduler = AsyncIOScheduler(
jobstores=jobstores, job_defaults=job_defaults
)
self.scheduler = AsyncIOScheduler(jobstores=jobstores, job_defaults=job_defaults)
self.scheduler.start()
async def red_delete_data_for_user(self, **kwargs):
"""Nothing to delete"""
return
async def _parse_command(self, command_to_parse: str):
return False # TODO: parse commands somehow
def _assemble_job_id(self, task_name, guild_id):
return task_name + "_" + guild_id
async def _check_parsable_command(self, ctx: commands.Context, command_to_parse: str):
message = FakeMessage()
message.content = ctx.prefix + command_to_parse
message.author = ctx.author
message.guild = ctx.guild
new_ctx: commands.Context = await self.bot.get_context(message)
return new_ctx.valid
async def _get_job(self, task_name, guild_id):
return self.scheduler.get_job(self._assemble_job_id(task_name, guild_id))
async def _add_job(self, task):
return self.scheduler.add_job(
task.execute,
id=self._assemble_job_id(task.name, task.guild_id),
trigger=await task.get_trigger(),
)
@checks.is_owner()
@commands.command()
async def fifoclear(self, ctx: commands.Context):
"""Debug command to clear fifo config"""
await self.config.guild(ctx.guild).tasks.clear()
await ctx.tick()
@checks.is_owner() # Will be reduced when I figure out permissions later
@commands.group()
@ -149,10 +293,21 @@ class FIFO(commands.Cog):
"""
Add a new task to this guild's task list
"""
pass
if (await self.config.guild(ctx.guild).tasks.get_raw(task_name, default=None)) is not None:
await ctx.maybe_send_embed(f"Task already exists with {task_name=}")
return
if not await self._check_parsable_command(ctx, command_to_execute):
await ctx.maybe_send_embed("Failed to parse command. Make sure to include the prefix")
return
task = Task(task_name, ctx.guild.id, self.config, ctx.author.id)
await task.set_commmand_str(command_to_execute)
await task.save_all()
await ctx.tick()
@fifo.command(name="delete")
async def fifo_delete(self, ctx: commands.Context, task_name: str, *, command_to_execute: str):
async def fifo_delete(self, ctx: commands.Context, task_name: str):
"""
Deletes a task from this guild's task list
"""
@ -189,11 +344,12 @@ class FIFO(commands.Cog):
"Failed to add an interval trigger to this task, see console for logs"
)
return
await task.save_data()
await ctx.tick()
@fifo_trigger.command(name="date")
async def fifo_trigger_date(
self, ctx: commands.Context, task_name: str, datetime_str: TimedeltaConverter
self, ctx: commands.Context, task_name: str, datetime_str: DatetimeConverter
):
"""
Add a "run once" datetime trigger to the specified task
@ -214,6 +370,8 @@ class FIFO(commands.Cog):
"Failed to add a date trigger to this task, see console for logs"
)
return
await task.save_data()
await ctx.tick()
@fifo_trigger.command(name="cron")
@ -225,21 +383,21 @@ class FIFO(commands.Cog):
"""
await ctx.maybe_send_embed("Not yet implemented")
async def load_tasks(self):
"""
Run once on cog load.
"""
all_guilds = await self.config.all_guilds()
async for guild_id, guild_data in AsyncIter(all_guilds["tasks"].items(), steps=100):
for task_name, task_data in guild_data["tasks"].items():
task = Task(task_name, guild_id, self.config)
await task.load_from_data(task_data)
job = self.scheduler.add_job(
task.execute, id=task_name + "_" + guild_id, trigger=await task.get_trigger(),
)
self.scheduler.start()
# async def load_tasks(self):
# """
# Run once on cog load.
# """
# all_guilds = await self.config.all_guilds()
# async for guild_id, guild_data in AsyncIter(all_guilds["tasks"].items(), steps=100):
# for task_name, task_data in guild_data["tasks"].items():
# task = Task(task_name, guild_id, self.config)
# await task.load_from_data(task_data)
#
# job = self.scheduler.add_job(
# task.execute, id=task_name + "_" + guild_id, trigger=await task.get_trigger(),
# )
#
# self.scheduler.start()
# async def parent_loop(self):
# await asyncio.sleep(60)

@ -1,35 +0,0 @@
import asyncio
from apscheduler.jobstores.base import BaseJobStore
from redbot.core import Config
class RedConfigJobStore(BaseJobStore):
def __init__(self, config: Config, loop):
super().__init__()
self.config = config
self.loop: asyncio.BaseEventLoop = loop
def lookup_job(self, job_id):
task = self.loop.create_task(self.config.jobs_index.get_raw(job_id))
def get_due_jobs(self, now):
pass
def get_next_run_time(self):
pass
def get_all_jobs(self):
pass
def add_job(self, job):
pass
def update_job(self, job):
pass
def remove_job(self, job_id):
pass
def remove_all_jobs(self):
pass

@ -0,0 +1,189 @@
import asyncio
from apscheduler.jobstores.base import ConflictingIdError, JobLookupError
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.util import datetime_to_utc_timestamp
from redbot.core import Config
# TODO: use get_lock on config
from redbot.core.bot import Red
class RedConfigJobStore(MemoryJobStore):
def __init__(self, config: Config, bot: Red):
super().__init__()
self.config = config
# nest_asyncio.apply()
self.bot = bot
asyncio.ensure_future(self._load_from_config(), loop=self.bot.loop)
async def _load_from_config(self):
self._jobs = await self.config.jobs()
self._jobs_index = await self.config.jobs_index.all()
def add_job(self, job):
if job.id in self._jobs_index:
raise ConflictingIdError(job.id)
timestamp = datetime_to_utc_timestamp(job.next_run_time)
index = self._get_job_index(timestamp, job.id) # This is fine
self._jobs.insert(index, (job, timestamp))
self._jobs_index[job.id] = (job, timestamp)
asyncio.create_task(self._async_add_job(job, index, timestamp))
async def _async_add_job(self, job, index, timestamp):
async with self.config.jobs() as jobs:
jobs.insert(index, (job, timestamp))
await self.config.jobs_index.set_raw(job.id, value=(job, timestamp))
return True
def update_job(self, job):
old_job, old_timestamp = self._jobs_index.get(job.id, (None, None))
if old_job is None:
raise JobLookupError(job.id)
# If the next run time has not changed, simply replace the job in its present index.
# Otherwise, reinsert the job to the list to preserve the ordering.
old_index = self._get_job_index(old_timestamp, old_job.id)
new_timestamp = datetime_to_utc_timestamp(job.next_run_time)
asyncio.ensure_future(self._async_update_job(job, new_timestamp, old_index, old_job, old_timestamp), loop=self.bot.loop)
async def _async_update_job(self, job, new_timestamp, old_index, old_job, old_timestamp):
if old_timestamp == new_timestamp:
self._jobs[old_index] = (job, new_timestamp)
async with self.config.jobs() as jobs:
jobs[old_index] = (job, new_timestamp)
else:
del self._jobs[old_index]
new_index = self._get_job_index(new_timestamp, job.id) # This is fine
self._jobs.insert(new_index, (job, new_timestamp))
async with self.config.jobs() as jobs:
del jobs[old_index]
jobs.insert(new_index, (job, new_timestamp))
self._jobs_index[old_job.id] = (job, new_timestamp)
await self.config.jobs_index.set_raw(old_job.id, value=(job, new_timestamp))
def remove_job(self, job_id):
job, timestamp = self._jobs_index.get(job_id, (None, None))
if job is None:
raise JobLookupError(job_id)
index = self._get_job_index(timestamp, job_id)
del self._jobs[index]
del self._jobs_index[job.id]
asyncio.create_task(self._async_remove_job(index, job))
async def _async_remove_job(self, index, job):
async with self.config.jobs() as jobs:
del jobs[index]
await self.config.jobs_index.clear_raw(job.id)
def remove_all_jobs(self):
super().remove_all_jobs()
asyncio.create_task(self._async_remove_all_jobs())
async def _async_remove_all_jobs(self):
await self.config.jobs.clear()
await self.config.jobs_index.clear()
# import asyncio
#
# from apscheduler.jobstores.base import BaseJobStore, ConflictingIdError
# from apscheduler.util import datetime_to_utc_timestamp
# from redbot.core import Config
# from redbot.core.utils import AsyncIter
#
#
# class RedConfigJobStore(BaseJobStore):
# def __init__(self, config: Config, loop):
# super().__init__()
# self.config = config
# self.loop: asyncio.BaseEventLoop = loop
#
# self._jobs = []
# self._jobs_index = {} # id -> (job, timestamp) lookup table
#
# def lookup_job(self, job_id):
# return asyncio.run(self._async_lookup_job(job_id))
#
# async def _async_lookup_job(self, job_id):
# return (await self.config.jobs_index.get_raw(job_id, default=(None, None)))[0]
#
# def get_due_jobs(self, now):
# return asyncio.run(self._async_get_due_jobs(now))
#
# async def _async_get_due_jobs(self, now):
# now_timestamp = datetime_to_utc_timestamp(now)
# pending = []
# all_jobs = await self.config.jobs()
# async for job, timestamp in AsyncIter(all_jobs, steps=100):
# if timestamp is None or timestamp > now_timestamp:
# break
# pending.append(job)
#
# return pending
#
# def get_next_run_time(self):
# return asyncio.run(self._async_get_next_run_time())
#
# async def _async_get_next_run_time(self):
# _jobs = await self.config.jobs()
# return _jobs[0][0].next_run_time if _jobs else None
#
# def get_all_jobs(self):
# return asyncio.run(self._async_get_all_jobs())
#
# async def _async_get_all_jobs(self):
# return [j[0] for j in (await self.config.jobs())]
#
# def add_job(self, job):
# return asyncio.run(self._async_add_job(job))
#
# async def _async_add_job(self, job):
# if await self.config.jobs_index.get_raw(job.id, default=None) is not None:
# raise ConflictingIdError(job.id)
#
# timestamp = datetime_to_utc_timestamp(job.next_run_time)
# index = self._get_job_index(timestamp, job.id)
# self._jobs.insert(index, (job, timestamp))
# self._jobs_index[job.id] = (job, timestamp)
#
# def update_job(self, job):
# pass
#
# def remove_job(self, job_id):
# pass
#
# def remove_all_jobs(self):
# pass
#
# def _get_job_index(self, timestamp, job_id):
# """
# Returns the index of the given job, or if it's not found, the index where the job should be
# inserted based on the given timestamp.
#
# :type timestamp: int
# :type job_id: str
#
# """
# lo, hi = 0, len(self._jobs)
# timestamp = float('inf') if timestamp is None else timestamp
# while lo < hi:
# mid = (lo + hi) // 2
# mid_job, mid_timestamp = self._jobs[mid]
# mid_timestamp = float('inf') if mid_timestamp is None else mid_timestamp
# if mid_timestamp > timestamp:
# hi = mid
# elif mid_timestamp < timestamp:
# lo = mid + 1
# elif mid_job.id > job_id:
# hi = mid
# elif mid_job.id < job_id:
# lo = mid + 1
# else:
# return mid
#
# return lo
Loading…
Cancel
Save