From 2e27e85ac5146ccf40a9b24d50775fdb1c529b6b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Nov 2020 18:57:22 +0200 Subject: [PATCH] Add support for multiple pins --- mautrix_telegram/abstract_user.py | 17 ++++++------ mautrix_telegram/db/message.py | 14 +++++++++- mautrix_telegram/matrix.py | 13 +++++---- mautrix_telegram/portal/base.py | 2 ++ mautrix_telegram/portal/matrix.py | 41 ++++++++++++++--------------- mautrix_telegram/portal/telegram.py | 21 ++++++++++----- 6 files changed, 64 insertions(+), 44 deletions(-) diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index dba567e..cd36304 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -25,8 +25,8 @@ Connection) from telethon.tl.patched import MessageService, Message from telethon.tl.types import ( - Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdateChatPinnedMessage, - UpdateChannelPinnedMessage, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat, + Channel, Chat, MessageActionChannelMigrateFrom, PeerUser, TypeUpdate, UpdatePinnedMessages, + UpdatePinnedChannelMessages, UpdateChatParticipantAdmin, UpdateChatParticipants, PeerChat, UpdateChatUserTyping, UpdateDeleteChannelMessages, UpdateNewMessage, UpdateDeleteMessages, UpdateEditChannelMessage, UpdateEditMessage, UpdateNewChannelMessage, UpdateReadHistoryOutbox, UpdateShortChatMessage, UpdateShortMessage, UpdateUserName, UpdateUserPhoto, UpdateUserStatus, @@ -252,7 +252,7 @@ async def _update(self, update: TypeUpdate) -> None: await self.update_admin(update) elif isinstance(update, UpdateChatParticipants): await self.update_participants(update) - elif isinstance(update, (UpdateChannelPinnedMessage, UpdateChatPinnedMessage)): + elif isinstance(update, (UpdatePinnedMessages, UpdatePinnedChannelMessages)): await self.update_pinned_messages(update) elif isinstance(update, (UpdateUserName, UpdateUserPhoto)): await self.update_others_info(update) @@ -263,14 +263,15 @@ async def _update(self, update: TypeUpdate) -> None: else: self.log.trace("Unhandled update: %s", update) - async def update_pinned_messages(self, update: Union[UpdateChannelPinnedMessage, - UpdateChatPinnedMessage]) -> None: - if isinstance(update, UpdateChatPinnedMessage): - portal = po.Portal.get_by_tgid(TelegramID(update.chat_id)) + async def update_pinned_messages(self, update: Union[UpdatePinnedMessages, + UpdatePinnedChannelMessages]) -> None: + if isinstance(update, UpdatePinnedMessages): + portal = po.Portal.get_by_entity(update.peer, receiver_id=self.tgid) else: portal = po.Portal.get_by_tgid(TelegramID(update.channel_id)) if portal and portal.mxid: - await portal.receive_telegram_pin_id(update.id, self.tgid) + await portal.receive_telegram_pin_ids(update.messages, self.tgid, + remove=not update.pinned) @staticmethod async def update_participants(update: UpdateChatParticipants) -> None: diff --git a/mautrix_telegram/db/message.py b/mautrix_telegram/db/message.py index 8d6f5bd..e39f2a7 100644 --- a/mautrix_telegram/db/message.py +++ b/mautrix_telegram/db/message.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Iterator +from typing import Optional, Iterator, List from sqlalchemy import Column, UniqueConstraint, Integer, String, and_, func, desc, select @@ -51,6 +51,12 @@ def get_one_by_tgid(cls, tgid: TelegramID, tg_space: TelegramID, edit_index: int return cls._select_one_or_none(cls.c.tgid == tgid, cls.c.tg_space == tg_space, cls.c.edit_index == edit_index) + @classmethod + def get_first_by_tgids(cls, tgids: List[TelegramID], tg_space: TelegramID + ) -> Iterator['Message']: + return cls._select_all(cls.c.tgid.in_(tgids), cls.c.tg_space == tg_space, + cls.c.edit_index == 0) + @classmethod def count_spaces_by_mxid(cls, mxid: EventID, mx_room: RoomID) -> int: rows = cls.db.execute(select([func.count(cls.c.tg_space)]) @@ -77,6 +83,12 @@ def get_by_mxid(cls, mxid: EventID, mx_room: RoomID, tg_space: TelegramID return cls._select_one_or_none(cls.c.mxid == mxid, cls.c.mx_room == mx_room, cls.c.tg_space == tg_space) + @classmethod + def get_by_mxids(cls, mxids: List[EventID], mx_room: RoomID, tg_space: TelegramID + ) -> Iterator['Message']: + return cls._select_all(cls.c.mxid.in_(mxids), cls.c.mx_room == mx_room, + cls.c.tg_space == tg_space) + @classmethod def update_by_tgid(cls, s_tgid: TelegramID, s_tg_space: TelegramID, s_edit_index: int, **values) -> None: diff --git a/mautrix_telegram/matrix.py b/mautrix_telegram/matrix.py index 062cbe0..379e559 100644 --- a/mautrix_telegram/matrix.py +++ b/mautrix_telegram/matrix.py @@ -283,13 +283,12 @@ async def handle_room_pin(room_id: RoomID, sender_mxid: UserID, portal = po.Portal.get_by_mxid(room_id) sender = await u.User.get_by_mxid(sender_mxid).ensure_started() if await sender.has_full_access(allow_bot=True) and portal: - events = new_events - old_events - if len(events) > 0: - # New event pinned, set that as pinned in Telegram. - await portal.handle_matrix_pin(sender, EventID(events.pop()), event_id) - elif len(new_events) == 0: - # All pinned events removed, remove pinned event in Telegram. - await portal.handle_matrix_pin(sender, None, event_id) + if not new_events: + await portal.handle_matrix_unpin_all(sender) + else: + changes = {event_id: event_id in new_events + for event_id in new_events ^ old_events} + await portal.handle_matrix_pin(sender, changes, event_id) @staticmethod async def handle_room_upgrade(room_id: RoomID, sender: UserID, new_room_id: RoomID, diff --git a/mautrix_telegram/portal/base.py b/mautrix_telegram/portal/base.py index 3f3d510..a0ba796 100644 --- a/mautrix_telegram/portal/base.py +++ b/mautrix_telegram/portal/base.py @@ -104,6 +104,7 @@ class BasePortal(MautrixBasePortal, ABC): dedup: PortalDedup send_lock: PortalSendLock + _pin_lock: asyncio.Lock _db_instance: DBPortal _main_intent: Optional[IntentAPI] @@ -138,6 +139,7 @@ def __init__(self, tgid: TelegramID, peer_type: str, tg_receiver: Optional[Teleg self.dedup = PortalDedup(self) self.send_lock = PortalSendLock() + self._pin_lock = asyncio.Lock() if tgid: self.by_tgid[self.tgid_full] = self diff --git a/mautrix_telegram/portal/matrix.py b/mautrix_telegram/portal/matrix.py index e6adac8..3d601f2 100644 --- a/mautrix_telegram/portal/matrix.py +++ b/mautrix_telegram/portal/matrix.py @@ -22,11 +22,10 @@ from telethon.tl.functions.messages import (EditChatPhotoRequest, EditChatTitleRequest, UpdatePinnedMessageRequest, SetTypingRequest, - EditChatAboutRequest) + EditChatAboutRequest, UnpinAllMessagesRequest) from telethon.tl.functions.channels import EditPhotoRequest, EditTitleRequest, JoinChannelRequest -from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError, - PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, - RPCError) +from telethon.errors import (ChatNotModifiedError, PhotoExtInvalidError, MessageIdInvalidError, + PhotoInvalidDimensionsError, PhotoSaveFileInvalidError, RPCError) from telethon.tl.patched import Message, MessageService from telethon.tl.types import ( DocumentAttributeFilename, DocumentAttributeImageSize, GeoPoint, @@ -432,23 +431,23 @@ async def _handle_matrix_message(self, sender: 'u.User', content: MessageEventCo else: self.log.trace("Unhandled Matrix event: %s", content) - async def handle_matrix_pin(self, sender: 'u.User', pinned_message: Optional[EventID], + async def handle_matrix_unpin_all(self, sender: 'u.User', pin_event_id: EventID) -> None: + await sender.client(UnpinAllMessagesRequest(peer=self.peer)) + await self._send_delivery_receipt(pin_event_id) + + async def handle_matrix_pin(self, sender: 'u.User', changes: Dict[EventID, bool], pin_event_id: EventID) -> None: - if self.peer_type != "chat" and self.peer_type != "channel": - return - try: - if not pinned_message: - await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=0)) - else: - tg_space = self.tgid if self.peer_type == "channel" else sender.tgid - message = DBMessage.get_by_mxid(pinned_message, self.mxid, tg_space) - if message is None: - self.log.warning(f"Could not find pinned {pinned_message} in {self.mxid}") - return - await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=message.tgid)) - await self._send_delivery_receipt(pin_event_id) - except ChatNotModifiedError: - pass + tg_space = self.tgid if self.peer_type == "channel" else sender.tgid + ids = {msg.mxid: msg.tgid + for msg in DBMessage.get_by_mxids(list(changes.keys()), + mx_room=self.mxid, tg_space=tg_space)} + for event_id, pinned in changes.items(): + try: + await sender.client(UpdatePinnedMessageRequest(peer=self.peer, id=ids[event_id], + unpin=not pinned)) + except (ChatNotModifiedError, MessageIdInvalidError, KeyError): + pass + await self._send_delivery_receipt(pin_event_id) async def handle_matrix_deletion(self, deleter: 'u.User', event_id: EventID, redaction_event_id: EventID) -> None: diff --git a/mautrix_telegram/portal/telegram.py b/mautrix_telegram/portal/telegram.py index f74eb52..0e3cfb0 100644 --- a/mautrix_telegram/portal/telegram.py +++ b/mautrix_telegram/portal/telegram.py @@ -694,13 +694,20 @@ async def set_telegram_admin(self, user_id: TelegramID) -> None: levels.users[puppet.mxid] = 50 await self.main_intent.set_power_levels(self.mxid, levels) - async def receive_telegram_pin_id(self, msg_id: TelegramID, receiver: TelegramID) -> None: - tg_space = receiver if self.peer_type != "channel" else self.tgid - message = DBMessage.get_one_by_tgid(msg_id, tg_space) if msg_id != 0 else None - if message: - await self.main_intent.set_pinned_messages(self.mxid, [message.mxid]) - else: - await self.main_intent.set_pinned_messages(self.mxid, []) + async def receive_telegram_pin_ids(self, msg_ids: List[TelegramID], receiver: TelegramID, + remove: bool) -> None: + async with self._pin_lock: + tg_space = receiver if self.peer_type != "channel" else self.tgid + previously_pinned = await self.main_intent.get_pinned_messages(self.mxid) + currently_pinned_dict = {event_id: True for event_id in previously_pinned} + for message in DBMessage.get_first_by_tgids(msg_ids, tg_space): + if remove: + currently_pinned_dict.pop(message.mxid, None) + else: + currently_pinned_dict[message.mxid] = True + currently_pinned = list(currently_pinned_dict.keys()) + if currently_pinned != previously_pinned: + await self.main_intent.set_pinned_messages(self.mxid, currently_pinned) async def set_telegram_admins_enabled(self, enabled: bool) -> None: level = 50 if enabled else 10