diff --git a/nudity/info..json b/nudity/info..json index 1d66cf7..34c4804 100644 --- a/nudity/info..json +++ b/nudity/info..json @@ -2,19 +2,25 @@ "author": [ "Bobloy" ], - "bot_version": [ + "min_bot_version": [ 3, - 0, - 0 + 3, + 11 ], - "description": "Keep track of when users were last seen online", - "hidden": true, + "description": "Monitor images for NSFW content and moves them to a nsfw channel if possible", + "hidden": false, "install_msg": "Thank you for installing Nudity. Get started with `[p]load nudity`, then `[p]help Nudity`", - "requirements": ["nudepy"], - "short": "Last seen tracker", + "requirements": [ + "nudenet", + "tensorflow>=1.14,<2.0", + "keras>=2.4" + ], + "short": "NSFW image tracker and mover", "tags": [ "bobloy", "utils", - "tools" + "tools", + "nude", + "nsfw" ] } diff --git a/nudity/nudity.py b/nudity/nudity.py index a7290a5..6eb4221 100644 --- a/nudity/nudity.py +++ b/nudity/nudity.py @@ -1,19 +1,19 @@ -from io import BytesIO +import pathlib import discord -from PIL import Image -from nude import is_nude -from redbot.core import Config -from redbot.core import commands +from nudenet import NudeClassifier +from redbot.core import Config, commands from redbot.core.bot import Red +from redbot.core.data_manager import cog_data_path -class Nudity: +class Nudity(commands.Cog): """ V3 Cog Template """ def __init__(self, bot: Red): + super().__init__() self.bot = bot self.config = Config.get_conf(self, identifier=9811198108111121, force_registration=True) @@ -21,6 +21,17 @@ class Nudity: self.config.register_guild(**default_guild) + # self.detector = NudeDetector() + self.classifier = NudeClassifier() + + self.data_path: pathlib.Path = cog_data_path(self) + + self.current_processes = 0 + + async def red_delete_data_for_user(self, **kwargs): + """Nothing to delete""" + return + @commands.command(aliases=["togglenudity"], name="nudity") async def nudity(self, ctx: commands.Context): """Toggle nude-checking on or off""" @@ -42,14 +53,14 @@ class Nudity: await ctx.send("NSFW channel has been set to {}".format(channel.mention)) async def get_nsfw_channel(self, guild: discord.Guild): - channel_id = self.config.guild(guild).channel_id() + channel_id = await self.config.guild(guild).channel_id() if channel_id is None: return None else: - return await guild.get_channel(channel_id=channel_id) + return guild.get_channel(channel_id=channel_id) - async def nsfw(self, message: discord.Message, image: BytesIO): + async def nsfw(self, message: discord.Message, images: dict): content = message.content guild: discord.Guild = message.guild if not content: @@ -62,7 +73,7 @@ class Nudity: embed = discord.Embed(title="NSFW Image Detected") embed.add_field(name="Original Message", value=content) - + embed.set_author(name=message.author.name, icon_url=message.author.avatar_url) await message.channel.send(embed=embed) nsfw_channel = await self.get_nsfw_channel(guild) @@ -70,15 +81,19 @@ class Nudity: if nsfw_channel is None: return else: - await nsfw_channel.send( - "NSFW Image from {}".format(message.channel.mention), file=image - ) - + for image, r in images.items(): + if r["unsafe"] > 0.7: + await nsfw_channel.send( + "NSFW Image from {}".format(message.channel.mention), + file=discord.File(image,), + ) + + @commands.Cog.listener() async def on_message(self, message: discord.Message): - if not message.attachments: - return + is_private = isinstance(message.channel, discord.abc.PrivateChannel) - if message.guild is None: + if not message.attachments or is_private or message.author.bot: + # print("did not qualify") return try: @@ -87,31 +102,41 @@ class Nudity: return if not is_on: + print("Not on") return channel: discord.TextChannel = message.channel if channel.is_nsfw(): + print("nsfw channel is okay") return - attachment = message.attachments[0] + check_list = [] + for attachment in message.attachments: + # async with aiohttp.ClientSession() as session: + # img = await fetch_img(session, attachment.url) + + ext = attachment.filename - # async with aiohttp.ClientSession() as session: - # img = await fetch_img(session, attachment.url) + temp_name = self.data_path / f"nudecheck{self.current_processes}_{ext}" - temp = BytesIO() - print("Pre attachment save") - await attachment.save(temp) - print("Pre Image open") - temp = Image.open(temp) + self.current_processes += 1 + + print("Pre attachment save") + await attachment.save(temp_name) + check_list.append(temp_name) print("Pre nude check") - if is_nude(temp): - print("Is nude") + # nude_results = self.detector.detect(temp_name) + nude_results = self.classifier.classify([str(n) for n in check_list]) + # print(nude_results) + + if True in [r["unsafe"] > 0.7 for r in nude_results.values()]: + # print("Is nude") await message.add_reaction("❌") - await self.nsfw(message, temp) + await self.nsfw(message, nude_results) else: - print("Is not nude") + # print("Is not nude") await message.add_reaction("✅")