commit 837cff7a26
fifo/date_trigger.py  (new file, +7)
@@ -0,0 +1,7 @@
+from apscheduler.triggers.date import DateTrigger
+
+
+class CustomDateTrigger(DateTrigger):
+    def get_next_fire_time(self, previous_fire_time, now):
+        next_run = super().get_next_fire_time(previous_fire_time, now)
+        return next_run if next_run is not None and next_run >= now else None
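Review note: the override turns a stale one-shot trigger into a no-op, where the stock `DateTrigger` would hand APScheduler a fire time in the past (the source of the negative "Next run time" mentioned below). A minimal sketch of the intended behaviour, assuming the cog package is importable as `fifo`; the dates are illustrative:

```python
from datetime import datetime, timedelta

import pytz

from fifo.date_trigger import CustomDateTrigger

now = datetime.now(pytz.utc)

# Dated in the past: the stock DateTrigger would still report its (stale)
# run date here; CustomDateTrigger reports None, i.e. "never fires again".
stale = CustomDateTrigger(now - timedelta(days=1), timezone=pytz.utc)
assert stale.get_next_fire_time(None, now) is None

# Dated in the future: behaves exactly like the stock DateTrigger.
upcoming = CustomDateTrigger(now + timedelta(days=1), timezone=pytz.utc)
assert upcoming.get_next_fire_time(None, now) == now + timedelta(days=1)
```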
fifo/fifo.py  (48 changes)
@@ -4,6 +4,7 @@ from datetime import MAXYEAR, datetime, timedelta, tzinfo
 from typing import Optional, Union
 
 import discord
+import pytz
 from apscheduler.job import Job
 from apscheduler.jobstores.base import JobLookupError
 from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -51,7 +52,7 @@ def _get_run_times(job: Job, now: datetime = None):
 
     if now is None:
         now = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=job.next_run_time.tzinfo)
-        yield from _get_run_times(job, now)
+        yield from _get_run_times(job, now)  # Recursion
         raise StopIteration()
 
     next_run_time = job.next_run_time
@@ -145,28 +146,38 @@ class FIFO(commands.Cog):
         await task.delete_self()
 
     async def _process_task(self, task: Task):
-        job: Union[Job, None] = await self._get_job(task)
-        if job is not None:
-            job.reschedule(await task.get_combined_trigger())
-            return job
+        # None of this is necessary, we have `replace_existing` already
+        # job: Union[Job, None] = await self._get_job(task)
+        # if job is not None:
+        #     combined_trigger_ = await task.get_combined_trigger()
+        #     if combined_trigger_ is None:
+        #         job.remove()
+        #     else:
+        #         job.reschedule(combined_trigger_)
+        #     return job
         return await self._add_job(task)
 
     async def _get_job(self, task: Task) -> Job:
         return self.scheduler.get_job(_assemble_job_id(task.name, task.guild_id))
 
     async def _add_job(self, task: Task):
+        combined_trigger_ = await task.get_combined_trigger()
+        if combined_trigger_ is None:
+            return None
+
         return self.scheduler.add_job(
             _execute_task,
             kwargs=task.__getstate__(),
             id=_assemble_job_id(task.name, task.guild_id),
-            trigger=await task.get_combined_trigger(),
+            trigger=combined_trigger_,
             name=task.name,
         )
 
     async def _resume_job(self, task: Task):
         try:
             job = self.scheduler.resume_job(job_id=_assemble_job_id(task.name, task.guild_id))
         except JobLookupError:
-            job = await self._process_task(task)
+            job: Union[Job, None] = await self._get_job(task)
+            if job is not None:
+                job.resume()
+            else:
+                job = await self._process_task(task)
         return job
 
@@ -221,6 +232,17 @@ class FIFO(commands.Cog):
         if ctx.invoked_subcommand is None:
             pass
 
+    @fifo.command(name="wakeup")
+    async def fifo_wakeup(self, ctx: commands.Context):
+        """Debug command to fix missed executions.
+
+        If you see a negative "Next run time" when adding a trigger, this may help resolve it.
+        Check the logs when using this command.
+        """
+
+        self.scheduler.wakeup()
+        await ctx.tick()
+
     @fifo.command(name="checktask", aliases=["checkjob", "check"])
     async def fifo_checktask(self, ctx: commands.Context, task_name: str):
         """Returns the next 10 scheduled executions of the task"""
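Review note: `wakeup()` is the stock APScheduler `BaseScheduler.wakeup()`, which makes the scheduler re-examine its job list immediately instead of sleeping until its next computed wakeup; that is what lets this debug command recover executions whose run time has already slipped into the past. A minimal standalone sketch (the scheduler setup here is illustrative, not the cog's own):

```python
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def main():
    scheduler = AsyncIOScheduler()
    scheduler.start()

    # wakeup() forces an immediate pass over scheduled jobs, picking up any
    # whose next run time has already passed.
    scheduler.wakeup()

    await asyncio.sleep(1)  # give the loop a moment to process jobs
    scheduler.shutdown()


asyncio.run(main())
```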
@@ -372,10 +394,14 @@ class FIFO(commands.Cog):
 
         else:
             embed.add_field(name="Server", value="Server not found", inline=False)
 
+        triggers, expired_triggers = await task.get_triggers()
 
-        trigger_str = "\n".join(str(t) for t in await task.get_triggers())
+        trigger_str = "\n".join(str(t) for t in triggers)
+        expired_str = "\n".join(str(t) for t in expired_triggers)
         if trigger_str:
             embed.add_field(name="Triggers", value=trigger_str, inline=False)
+        if expired_str:
+            embed.add_field(name="Expired Triggers", value=expired_str, inline=False)
 
         job = await self._get_job(task)
         if job and job.next_run_time:
@@ -546,7 +572,7 @@ class FIFO(commands.Cog):
             )
             return
 
-        time_to_run = datetime.now() + time_from_now
+        time_to_run = datetime.now(pytz.utc) + time_from_now
 
         result = await task.add_trigger("date", time_to_run, time_to_run.tzinfo)
         if not result:
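Review note: the switch to `datetime.now(pytz.utc)` matters because the triggers now compare run dates against timezone-aware datetimes (see `check_expired_trigger` in `fifo/task.py` below); mixing in a naive `datetime.now()` raises at comparison time. A small illustration:

```python
from datetime import datetime, timedelta

import pytz

aware_now = datetime.now(pytz.utc)  # timezone-aware
naive_now = datetime.now()          # naive: no tzinfo attached

# Naive vs. aware comparison is a TypeError, not a wrong answer:
try:
    naive_now < aware_now
except TypeError as e:
    print(e)  # can't compare offset-naive and offset-aware datetimes

# Aware arithmetic stays aware, so downstream comparisons are safe.
time_to_run = aware_now + timedelta(minutes=5)
assert time_to_run.tzinfo is not None
```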

fifo/redconfigjobstore.py
@@ -39,7 +39,7 @@ class RedConfigJobStore(MemoryJobStore):
         # self._jobs = [
         #     (await self._decode_job(job), timestamp) async for (job, timestamp) in AsyncIter(_jobs)
         # ]
-        async for job, timestamp in AsyncIter(_jobs):
+        async for job, timestamp in AsyncIter(_jobs, steps=5):
             job = await self._decode_job(job)
             index = self._get_job_index(timestamp, job.id)
             self._jobs.insert(index, (job, timestamp))
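Review note: `AsyncIter` is Red's helper from `redbot.core.utils`; `steps=5` makes it yield control back to the event loop every five items instead of every item, trimming overhead when re-inserting a large stored job list. Roughly the behaviour, hand-rolled for illustration (this helper name is mine, not Red's):

```python
import asyncio


async def iter_with_steps(iterable, steps=5, delay=0):
    """Async-iterate `iterable`, yielding to the event loop every `steps` items."""
    for count, item in enumerate(iterable, start=1):
        yield item
        if count % steps == 0:
            await asyncio.sleep(delay)  # hand control back to the loop

# Usage mirrors the loop above:
#     async for job, timestamp in iter_with_steps(_jobs, steps=5):
#         ...
```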
@@ -109,83 +109,6 @@ class RedConfigJobStore(MemoryJobStore):
 
         return job
 
-    # @run_in_event_loop
-    # def add_job(self, job: Job):
-    #     if job.id in self._jobs_index:
-    #         raise ConflictingIdError(job.id)
-    #     # log.debug(f"Check job args: {job.args=}")
-    #     timestamp = datetime_to_utc_timestamp(job.next_run_time)
-    #     index = self._get_job_index(timestamp, job.id)  # This is fine
-    #     self._jobs.insert(index, (job, timestamp))
-    #     self._jobs_index[job.id] = (job, timestamp)
-    #     task = asyncio.create_task(self._async_add_job(job, index, timestamp))
-    #     self._eventloop.run_until_complete(task)
-    #     # log.debug(f"Added job: {self._jobs[index][0].args}")
-    #
-    # async def _async_add_job(self, job, index, timestamp):
-    #     encoded_job = self._encode_job(job)
-    #     job_tuple = tuple([encoded_job, timestamp])
-    #     async with self.config.jobs() as jobs:
-    #         jobs.insert(index, job_tuple)
-    #     # await self.config.jobs_index.set_raw(job.id, value=job_tuple)
-    #     return True
-
-    # @run_in_event_loop
-    # def update_job(self, job):
-    #     old_tuple: Tuple[Union[Job, None], Union[datetime, None]] = self._jobs_index.get(
-    #         job.id, (None, None)
-    #     )
-    #     old_job = old_tuple[0]
-    #     old_timestamp = old_tuple[1]
-    #     if old_job is None:
-    #         raise JobLookupError(job.id)
-    #
-    #     # If the next run time has not changed, simply replace the job in its present index.
-    #     # Otherwise, reinsert the job to the list to preserve the ordering.
-    #     old_index = self._get_job_index(old_timestamp, old_job.id)
-    #     new_timestamp = datetime_to_utc_timestamp(job.next_run_time)
-    #     task = asyncio.create_task(
-    #         self._async_update_job(job, new_timestamp, old_index, old_job, old_timestamp)
-    #     )
-    #     self._eventloop.run_until_complete(task)
-    #
-    # async def _async_update_job(self, job, new_timestamp, old_index, old_job, old_timestamp):
-    #     encoded_job = self._encode_job(job)
-    #     if old_timestamp == new_timestamp:
-    #         self._jobs[old_index] = (job, new_timestamp)
-    #         async with self.config.jobs() as jobs:
-    #             jobs[old_index] = (encoded_job, new_timestamp)
-    #     else:
-    #         del self._jobs[old_index]
-    #         new_index = self._get_job_index(new_timestamp, job.id)  # This is fine
-    #         self._jobs.insert(new_index, (job, new_timestamp))
-    #         async with self.config.jobs() as jobs:
-    #             del jobs[old_index]
-    #             jobs.insert(new_index, (encoded_job, new_timestamp))
-    #     self._jobs_index[old_job.id] = (job, new_timestamp)
-    #     # await self.config.jobs_index.set_raw(old_job.id, value=(encoded_job, new_timestamp))
-    #
-    #     log.debug(f"Async Updated {job.id=}")
-    #     # log.debug(f"Check job args: {job.kwargs=}")
-
-    # @run_in_event_loop
-    # def remove_job(self, job_id):
-    #     """Copied instead of super for the asyncio args"""
-    #     job, timestamp = self._jobs_index.get(job_id, (None, None))
-    #     if job is None:
-    #         raise JobLookupError(job_id)
-    #
-    #     index = self._get_job_index(timestamp, job_id)
-    #     del self._jobs[index]
-    #     del self._jobs_index[job.id]
-    #     task = asyncio.create_task(self._async_remove_job(index, job))
-    #     self._eventloop.run_until_complete(task)
-    #
-    # async def _async_remove_job(self, index, job):
-    #     async with self.config.jobs() as jobs:
-    #         del jobs[index]
-    #     # await self.config.jobs_index.clear_raw(job.id)
 
     @run_in_event_loop
     def remove_all_jobs(self):
         super().remove_all_jobs()

@@ -201,4 +124,5 @@ class RedConfigJobStore(MemoryJobStore):
 
     async def async_shutdown(self):
         await self.save_to_config()
-        super().remove_all_jobs()
+        self._jobs = []
+        self._jobs_index = {}

fifo/task.py  (74 changes)
@@ -1,18 +1,19 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import discord
 from apscheduler.triggers.base import BaseTrigger
 from apscheduler.triggers.combining import OrTrigger
 from apscheduler.triggers.cron import CronTrigger
 from apscheduler.triggers.date import DateTrigger
 from apscheduler.triggers.interval import IntervalTrigger
 from discord.utils import time_snowflake
-from pytz import timezone
+import pytz
 from redbot.core import Config, commands
 from redbot.core.bot import Red
 
+from fifo.date_trigger import CustomDateTrigger
+
 log = logging.getLogger("red.fox_v3.fifo.task")
 
@@ -26,7 +27,7 @@ def get_trigger(data):
         return IntervalTrigger(days=parsed_time.days, seconds=parsed_time.seconds)
 
     if data["type"] == "date":
-        return DateTrigger(data["time_data"], timezone=data["tzinfo"])
+        return CustomDateTrigger(data["time_data"], timezone=data["tzinfo"])
 
     if data["type"] == "cron":
         return CronTrigger.from_crontab(data["time_data"], timezone=data["tzinfo"])
@@ -34,14 +35,25 @@ def get_trigger(data):
     return False
 
 
+def check_expired_trigger(trigger: BaseTrigger):
+    return trigger.get_next_fire_time(None, datetime.now(pytz.utc)) is None
+
+
 def parse_triggers(data: Union[Dict, None]):
     if data is None or not data.get("triggers", False):  # No triggers
         return None
 
     if len(data["triggers"]) > 1:  # Multiple triggers
-        return OrTrigger([get_trigger(t_data) for t_data in data["triggers"]])
-
-    return get_trigger(data["triggers"][0])
+        triggers_list = [get_trigger(t_data) for t_data in data["triggers"]]
+        triggers_list = [t for t in triggers_list if not check_expired_trigger(t)]
+        if not triggers_list:
+            return None
+        return OrTrigger(triggers_list)
+    else:
+        trigger = get_trigger(data["triggers"][0])
+        if check_expired_trigger(trigger):
+            return None
+        return trigger
 
 
 class FakeMessage:
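Review note: `parse_triggers` can now return `None` even when trigger data exists (everything expired), which is why `_add_job` in `fifo/fifo.py` above learned to skip `None` combined triggers. A sketch of the new filtering, assuming the cog package is importable as `fifo`; the trigger dicts are illustrative:

```python
from datetime import datetime, timedelta

import pytz

from fifo.task import parse_triggers

now = datetime.now(pytz.utc)
past = {"type": "date", "time_data": now - timedelta(days=1), "tzinfo": pytz.utc}
future = {"type": "date", "time_data": now + timedelta(days=1), "tzinfo": pytz.utc}

# Expired entries are filtered out before the OrTrigger is built.
combined = parse_triggers({"triggers": [past, future]})  # OrTrigger over the survivor

# If every trigger has already expired, there is nothing left to schedule.
assert parse_triggers({"triggers": [past]}) is None
```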
@@ -66,11 +78,11 @@ def neuter_message(message: FakeMessage):
 
 
 class Task:
-    default_task_data = {"triggers": [], "command_str": ""}
+    default_task_data = {"triggers": [], "command_str": "", "expired_triggers": []}
 
     default_trigger = {
         "type": "",
-        "time_data": None,  # Used for Interval and Date Triggers
+        "time_data": None,
         "tzinfo": None,
     }
@@ -87,9 +99,10 @@ class Task:
 
     async def _encode_time_triggers(self):
         if not self.data or not self.data.get("triggers", None):
-            return []
+            return [], []
 
         triggers = []
+        expired_triggers = []
         for t in self.data["triggers"]:
             if t["type"] == "interval":  # Convert into timedelta
                 td: timedelta = t["time_data"]
@@ -101,13 +114,15 @@ class Task:
 
             if t["type"] == "date":  # Convert into datetime
                 dt: datetime = t["time_data"]
-                triggers.append(
-                    {
-                        "type": t["type"],
-                        "time_data": dt.isoformat(),
-                        "tzinfo": getattr(t["tzinfo"], "zone", None),
-                    }
-                )
+                data_to_append = {
+                    "type": t["type"],
+                    "time_data": dt.isoformat(),
+                    "tzinfo": getattr(t["tzinfo"], "zone", None),
+                }
+                if dt < datetime.now(pytz.utc):
+                    expired_triggers.append(data_to_append)
+                else:
+                    triggers.append(data_to_append)
                 continue
 
             if t["type"] == "cron":
@@ -125,7 +140,7 @@ class Task:
 
             raise NotImplemented
 
-        return triggers
+        return triggers, expired_triggers
 
     async def _decode_time_triggers(self):
         if not self.data or not self.data.get("triggers", None):
@@ -138,7 +153,7 @@ class Task:
 
             # First decode timezone if there is one
             if t["tzinfo"] is not None:
-                t["tzinfo"] = timezone(t["tzinfo"])
+                t["tzinfo"] = pytz.timezone(t["tzinfo"])
 
             if t["type"] == "interval":  # Convert into timedelta
                 t["time_data"] = timedelta(**t["time_data"])
@@ -174,14 +189,23 @@ class Task:
         await self._decode_time_triggers()
         return self.data
 
-    async def get_triggers(self) -> List[Union[IntervalTrigger, DateTrigger]]:
+    async def get_triggers(self) -> Tuple[List[BaseTrigger], List[BaseTrigger]]:
         if not self.data:
             await self.load_from_config()
 
         if self.data is None or "triggers" not in self.data:  # No triggers
-            return []
+            return [], []
 
-        return [get_trigger(t) for t in self.data["triggers"]]
+        trigs = []
+        expired_trigs = []
+        for t in self.data["triggers"]:
+            trig = get_trigger(t)
+            if check_expired_trigger(trig):
+                expired_trigs.append(trig)  # append the trigger object, matching the annotation
+            else:
+                trigs.append(trig)
+
+        return trigs, expired_trigs
 
     async def get_combined_trigger(self) -> Union[BaseTrigger, None]:
         if not self.data:
@@ -201,7 +225,10 @@ class Task:
         data_to_save = self.default_task_data.copy()
         if self.data:
             data_to_save["command_str"] = self.get_command_str()
-            data_to_save["triggers"] = await self._encode_time_triggers()
+            (
+                data_to_save["triggers"],
+                data_to_save["expired_triggers"],
+            ) = await self._encode_time_triggers()
 
         to_save = {
             "guild_id": self.guild_id,
@@ -217,7 +244,10 @@ class Task:
             return
 
         data_to_save = self.data.copy()
-        data_to_save["triggers"] = await self._encode_time_triggers()
+        (
+            data_to_save["triggers"],
+            data_to_save["expired_triggers"],
+        ) = await self._encode_time_triggers()
 
         await self.config.guild_from_id(self.guild_id).tasks.set_raw(
             self.name, "data", value=data_to_save
@@ -247,12 +277,16 @@ class Task:
             )
             return False
 
-        actual_message: discord.Message = channel.last_message
+        actual_message: Optional[discord.Message] = channel.last_message
+        # I'd like to present you my chain of increasingly desperate message fetching attempts
         if actual_message is None:
             # log.warning("No message found in channel cache yet, skipping execution")
             # return
-            actual_message = await channel.history(limit=1).flatten()
-            if not actual_message:  # Basically only happens if the channel has no messages
+            if channel.last_message_id is not None:
+                try:
+                    actual_message = await channel.fetch_message(channel.last_message_id)
+                except discord.NotFound:
+                    actual_message = None
+            if actual_message is None:  # last_message_id was an invalid message I guess
+                actual_message = await channel.history(limit=1).flatten()
+                if not actual_message:  # Basically only happens if the channel has no messages
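Review note: the cache → `fetch_message` → `history` fallback chain could be factored into a helper; a sketch against the discord.py 1.x API this cog targets (the helper name is mine, not the cog's):

```python
from typing import Optional

import discord


async def fetch_last_message(channel: discord.TextChannel) -> Optional[discord.Message]:
    """Best-effort lookup of a channel's most recent message."""
    # 1. Fast path: discord.py's message cache.
    if channel.last_message is not None:
        return channel.last_message

    # 2. The cached ID may still resolve through the API.
    if channel.last_message_id is not None:
        try:
            return await channel.fetch_message(channel.last_message_id)
        except discord.NotFound:
            pass  # the message behind last_message_id was deleted

    # 3. Last resort: pull the newest message from channel history.
    messages = await channel.history(limit=1).flatten()
    return messages[0] if messages else None
```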