Skip to content

Commit 2c902a2

Browse files
EepyElvyraDamegoi0bsFayeDel
authored
chore: merge unstable into stable (#1117)
* feat: Implement helper methods for invites (#1098) * feat: mention spam trigger type * fix: Reimplement manual sharding/presence, fix forum tag implementation (#1115) * fix: Reimplement manual sharding/presence instantiation. (This was accidentally removed per gateway rework) * refactor: Reorganise tag creation/updating/deletion to non-deprecated endpoints and make it cache-reflective. * chore: bump version (#1116) * fix: properly initialise private attributes in iterators (#1114) * fix: set `message.member.user` as `message.author` again (#1118) Co-authored-by: Damego <[email protected]> Co-authored-by: i0 <[email protected]> Co-authored-by: DeltaX <[email protected]>
1 parent 838330d commit 2c902a2

File tree

11 files changed

+144
-17
lines changed

11 files changed

+144
-17
lines changed

interactions/api/gateway/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def __init__(
122122
intents: Intents,
123123
session_id: Optional[str] = MISSING,
124124
sequence: Optional[int] = MISSING,
125+
shards: Optional[List[Tuple[int]]] = MISSING,
126+
presence: Optional[ClientPresence] = MISSING,
125127
) -> None:
126128
"""
127129
:param token: The token of the application for connecting to the Gateway.
@@ -132,6 +134,10 @@ def __init__(
132134
:type session_id?: Optional[str]
133135
:param sequence?: The identifier sequence if trying to reconnect. Defaults to ``None``.
134136
:type sequence?: Optional[int]
137+
:param shards?: The list of shards for the application's initial connection, if provided. Defaults to ``None``.
138+
:type shards?: Optional[List[Tuple[int]]]
139+
:param presence?: The presence shown on an application once first connected. Defaults to ``None``.
140+
:type presence?: Optional[ClientPresence]
135141
"""
136142
try:
137143
self._loop = get_event_loop() if version_info < (3, 10) else get_running_loop()
@@ -161,8 +167,8 @@ def __init__(
161167
}
162168

163169
self._intents: Intents = intents
164-
self.__shard: Optional[List[Tuple[int]]] = None
165-
self.__presence: Optional[ClientPresence] = None
170+
self.__shard: Optional[List[Tuple[int]]] = None if shards is MISSING else shards
171+
self.__presence: Optional[ClientPresence] = None if presence is MISSING else presence
166172

167173
self._task: Optional[Task] = None
168174
self.__heartbeat_event = Event(loop=self._loop) if version_info < (3, 10) else Event()

interactions/api/http/channel.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..error import LibraryException
55
from ..models.channel import Channel
66
from ..models.message import Message
7+
from ..models.misc import Snowflake
78
from .request import _Request
89
from .route import Route
910

@@ -312,8 +313,10 @@ async def create_tag(
312313
self,
313314
channel_id: int,
314315
name: str,
316+
moderated: bool = False,
315317
emoji_id: Optional[int] = None,
316318
emoji_name: Optional[str] = None,
319+
reason: Optional[str] = None,
317320
) -> dict:
318321
"""
319322
Create a new tag.
@@ -324,25 +327,41 @@ async def create_tag(
324327
325328
:param channel_id: Channel ID snowflake.
326329
:param name: The name of the tag
330+
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
327331
:param emoji_id: The ID of the emoji to use for the tag
328332
:param emoji_name: The name of the emoji to use for the tag
333+
:param reason: The reason for the creating the tag, if any.
334+
:return: A Forum tag.
329335
"""
330336

331-
_dct = {"name": name}
337+
# This *assumes* cache is up-to-date.
338+
339+
_channel = self.cache[Channel].get(Snowflake(channel_id))
340+
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form
341+
342+
_dct = {"name": name, "moderated": moderated}
332343
if emoji_id:
333344
_dct["emoji_id"] = emoji_id
334345
if emoji_name:
335346
_dct["emoji_name"] = emoji_name
336347

337-
return await self._req.request(Route("POST", f"/channels/{channel_id}/tags"), json=_dct)
348+
_tags.append(_dct)
349+
350+
updated_channel = await self.modify_channel(
351+
channel_id, {"available_tags": _tags}, reason=reason
352+
)
353+
_channel_obj = Channel(**updated_channel, _client=self)
354+
return _channel_obj.available_tags[-1]._json
338355

339356
async def edit_tag(
340357
self,
341358
channel_id: int,
342359
tag_id: int,
343360
name: str,
361+
moderated: Optional[bool] = None,
344362
emoji_id: Optional[int] = None,
345363
emoji_name: Optional[str] = None,
364+
reason: Optional[str] = None,
346365
) -> dict:
347366
"""
348367
Update a tag.
@@ -351,28 +370,62 @@ async def edit_tag(
351370
Can either have an emoji_id or an emoji_name, but not both.
352371
emoji_id is meant for custom emojis, emoji_name is meant for unicode emojis.
353372
373+
The object returns *will* have a different tag ID.
374+
354375
:param channel_id: Channel ID snowflake.
355376
:param tag_id: The ID of the tag to update.
377+
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
356378
:param name: The new name of the tag
357379
:param emoji_id: The ID of the emoji to use for the tag
358380
:param emoji_name: The name of the emoji to use for the tag
381+
:param reason: The reason for deleting the tag, if any.
382+
383+
:return The updated tag object.
359384
"""
360385

361-
_dct = {"name": name}
386+
# This *assumes* cache is up-to-date.
387+
388+
_channel = self.cache[Channel].get(Snowflake(channel_id))
389+
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form
390+
391+
_old_tag = [tag for tag in _tags if tag["id"] == tag_id][0]
392+
393+
_tags.remove(_old_tag)
394+
395+
_dct = {"name": name, "tag_id": tag_id}
396+
if moderated:
397+
_dct["moderated"] = moderated
362398
if emoji_id:
363399
_dct["emoji_id"] = emoji_id
364400
if emoji_name:
365401
_dct["emoji_name"] = emoji_name
366402

367-
return await self._req.request(
368-
Route("PUT", f"/channels/{channel_id}/tags/{tag_id}"), json=_dct
403+
_tags.append(_dct)
404+
405+
updated_channel = await self.modify_channel(
406+
channel_id, {"available_tags": _tags}, reason=reason
369407
)
408+
_channel_obj = Channel(**updated_channel, _client=self)
409+
410+
self.cache[Channel].merge(_channel_obj)
411+
412+
return [tag for tag in _channel_obj.available_tags if tag.name == name][0]
370413

371-
async def delete_tag(self, channel_id: int, tag_id: int) -> None: # wha?
414+
async def delete_tag(self, channel_id: int, tag_id: int, reason: Optional[str] = None) -> None:
372415
"""
373416
Delete a forum tag.
374417
375418
:param channel_id: Channel ID snowflake.
376419
:param tag_id: The ID of the tag to delete
420+
:param reason: The reason for deleting the tag, if any.
377421
"""
378-
return await self._req.request(Route("DELETE", f"/channels/{channel_id}/tags/{tag_id}"))
422+
_channel = self.cache[Channel].get(Snowflake(channel_id))
423+
_tags = [_._json for _ in _channel.available_tags]
424+
425+
_old_tag = [tag for tag in _tags if tag["id"] == Snowflake(tag_id)][0]
426+
427+
_tags.remove(_old_tag)
428+
429+
request = await self.modify_channel(channel_id, {"available_tags": _tags}, reason=reason)
430+
431+
self.cache[Channel].merge(Channel(**request, _client=self))

interactions/api/http/invite.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ async def get_invite(
2424
"""
2525
Gets a Discord invite using its code.
2626
27-
.. note:: with_expiration is currently broken, the API will always return expiration_date.
28-
2927
:param invite_code: A string representing the invite code.
3028
:param with_counts: Whether approximate_member_count and approximate_presence_count are returned.
31-
:param with_expiration: Whether the invite's expiration is returned.
29+
:param with_expiration: Whether the invite's expiration date is returned.
3230
:param guild_scheduled_event_id: A guild scheduled event's ID.
3331
"""
3432
params_set = {

interactions/api/http/thread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def create_thread(
159159
reason: Optional[str] = None,
160160
) -> dict:
161161
"""
162-
From a given channel, create a Thread with an optional message to start with..
162+
From a given channel, create a Thread with an optional message to start with.
163163
164164
:param channel_id: The ID of the channel to create this thread in
165165
:param name: The name of the thread
@@ -212,7 +212,7 @@ async def create_thread_in_forum(
212212
:param name: The name of the thread
213213
:param auto_archive_duration: duration in minutes to automatically archive the thread after recent activity,
214214
can be set to: 60, 1440, 4320, 10080
215-
:param message_payload: The payload/dictionary contents of the first message in the forum thread.
215+
:param message: The payload/dictionary contents of the first message in the forum thread.
216216
:param applied_tags: List of tag ids that can be applied to the forum, if any.
217217
:param files: An optional list of files to send attached to the message.
218218
:param rate_limit_per_user: Seconds a user has to wait before sending another message (0 to 21600), if given.

interactions/api/models/channel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def __init__(
138138
):
139139
super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check)
140140

141+
self.__stop: bool = False
142+
141143
from .message import Message
142144

143145
if reverse and start_at is MISSING:

interactions/api/models/guild.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ def __init__(
241241
start_at: Optional[Union[int, str, Snowflake, Member]] = MISSING,
242242
check: Optional[Callable[[Member], bool]] = None,
243243
):
244+
245+
self.__stop: bool = False
246+
244247
super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check)
245248

246249
self.after = self.start_at
@@ -2830,6 +2833,66 @@ async def get_full_audit_logs(
28302833

28312834
return AuditLogs(**_audit_log_dict)
28322835

2836+
async def get_invite(
2837+
self,
2838+
invite_code: str,
2839+
with_counts: Optional[bool] = MISSING,
2840+
with_expiration: Optional[bool] = MISSING,
2841+
guild_scheduled_event_id: Optional[int] = MISSING,
2842+
) -> "Invite":
2843+
"""
2844+
Gets the invite using its code.
2845+
2846+
:param str invite_code: A string representing the invite code.
2847+
:param Optional[bool] with_counts: Whether approximate_member_count and approximate_presence_count are returned.
2848+
:param Optional[bool] with_expiration: Whether the invite's expiration date is returned.
2849+
:param Optional[int] guild_scheduled_event_id: A guild scheduled event's ID.
2850+
:return: An invite
2851+
:rtype: Invite
2852+
"""
2853+
if not self._client:
2854+
raise LibraryException(code=13)
2855+
2856+
_with_counts = with_counts if with_counts is not MISSING else None
2857+
_with_expiration = with_expiration if with_expiration is not MISSING else None
2858+
_guild_scheduled_event_id = (
2859+
guild_scheduled_event_id if guild_scheduled_event_id is not MISSING else None
2860+
)
2861+
2862+
res = await self._client.get_invite(
2863+
invite_code=invite_code,
2864+
with_counts=_with_counts,
2865+
with_expiration=_with_expiration,
2866+
guild_scheduled_event_id=_guild_scheduled_event_id,
2867+
)
2868+
2869+
return Invite(**res, _client=self._client)
2870+
2871+
async def delete_invite(self, invite_code: str, reason: Optional[str] = None) -> None:
2872+
"""
2873+
Deletes the invite using its code.
2874+
2875+
:param str invite_code: A string representing the invite code.
2876+
:param Optional[str] reason: The reason of the deletion
2877+
"""
2878+
if not self._client:
2879+
raise LibraryException(code=13)
2880+
2881+
await self._client.delete_invite(invite_code=invite_code, reason=reason)
2882+
2883+
async def get_invites(self) -> List["Invite"]:
2884+
"""
2885+
Gets invites of the guild.
2886+
2887+
:return: A list of guild invites
2888+
:rtype: List[Invite]
2889+
"""
2890+
if not self._client:
2891+
raise LibraryException(code=13)
2892+
2893+
res = await self._client.get_guild_invites(guild_id=int(self.id))
2894+
return [Invite(**_, _client=self._client) for _ in res]
2895+
28332896
@property
28342897
def icon_url(self) -> Optional[str]:
28352898
"""

interactions/api/models/message.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,9 @@ def __attrs_post_init__(self):
831831
if self.guild_id:
832832
self.member._extras["guild_id"] = self.guild_id
833833

834+
if self.author and self.member:
835+
self.member.user = self.author
836+
834837
async def get_channel(self) -> Channel:
835838
"""
836839
Gets the channel where the message was sent.

interactions/api/models/misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ class AutoModTriggerType(IntEnum):
208208
HARMFUL_LINK = 2
209209
SPAM = 3
210210
KEYWORD_PRESET = 4
211+
MENTION_SPAM = 5
211212

212213

213214
class AutoModKeywordPresetTypes(IntEnum):

interactions/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"__authors__",
77
)
88

9-
__version__ = "4.3.2"
9+
__version__ = "4.3.3"
1010

1111
__authors__ = {
1212
"current": [

interactions/client/bot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,13 @@ def __init__(
8080
self._loop: AbstractEventLoop = get_event_loop()
8181
self._http: HTTPClient = token
8282
self._intents: Intents = kwargs.get("intents", Intents.DEFAULT)
83-
self._websocket: WSClient = WSClient(token=token, intents=self._intents)
8483
self._shards: List[Tuple[int]] = kwargs.get("shards", [])
8584
self._commands: List[Command] = []
8685
self._default_scope = kwargs.get("default_scope")
8786
self._presence = kwargs.get("presence")
87+
self._websocket: WSClient = WSClient(
88+
token=token, intents=self._intents, shards=self._shards, presence=self._presence
89+
)
8890
self._token = token
8991
self._extensions = {}
9092
self._scopes = set([])

interactions/utils/abc/base_iterators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def __init__(
5656
if not hasattr(start_at, "id")
5757
else int(start_at.id)
5858
)
59-
self.__stop: bool = False
6059
self.objects: Optional[List[_O]] = None
6160

6261

0 commit comments

Comments
 (0)