如何在 Telethon 中正确使用 iter_download 进行多连接下载

本文介绍了如何在 Telethon 中正确使用 iter_download 功能进行多连接下载的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!



我一直在尝试实现一个多线程的 Telegram 下载客户端。对于单个文件的下载,我们可以直接使用 download_media 功能。

I've been trying to implement a multi-threaded Telegram download client. For single downloads, we can simply use the download_media functionality.

但 Telethon 提供 iter_download 功能,根据文档,它用于流媒体,其中还包括暂停和恢复功能.我们可以使用它来下载具有多个连接的单个文件.

But Telethon offers the iter_download function; as per the documentation, it is used for streaming, which also includes pausing and resuming functionality. We can use this to download a single file with multiple connections.


This is what I've scripted so far. I couldn't find any solid examples of multi-connection downloads anywhere.

async def multi_downloader(file, total_size, part, offset, part_size):
    """Download one slice of a file into ``output.mkv.<part>``.

    Streams chunks from ``client.iter_download`` starting at ``offset``, writes each
    chunk to the part file, and stops once at least ``part_size`` bytes were written.

    Relies on the module-level names ``client``, ``chunk_size`` and ``obj``.
    NOTE(review): ``file`` is accepted but never used; ``obj`` is what gets
    downloaded -- confirm whether ``file`` was meant to be passed instead.

    :param file: unused (see the note above).
    :param total_size: forwarded to ``iter_download`` as ``file_size``.
    :param part: index used in the output file name and the completion message.
    :param offset: start offset forwarded to ``iter_download``.
    :param part_size: number of bytes after which this part counts as complete.
    """
    # Upper bound handed to iter_download as `limit`; the trailing comment is the
    # author's earlier attempt at deriving it from the part size.
    limit = 10485760  # closestInteger(part_size / chunk_size, 10485760)
    size = 0

    # The context manager guarantees the part file is flushed and closed even if the
    # download raises (e.g. LimitInvalidError).
    with open('output.mkv.' + str(part), 'wb') as f:
        async for chunk in client.iter_download(obj, offset=offset, limit=limit,
                                                chunk_size=chunk_size,
                                                request_size=chunk_size,
                                                file_size=total_size):
            f.write(chunk)
            size += len(chunk)
            if size >= part_size:
                # NOTE(review): the last chunk may overshoot part_size, so neighbouring
                # parts can overlap unless part_size is a multiple of chunk_size.
                print("Part "+str(part)+" completed. "+str(part_size))
                break


The thing is, it always throws the invalid limit error if I change the offset for seeking. If the offset is zero, everything is fine.

telethon.errors.rpcerrorlist.LimitInvalidError: 提供了无效的 limit。请参阅 https://core.telegram.org/api/files#downloading-files(由 GetFileRequest 引起)

telethon.errors.rpcerrorlist.LimitInvalidError: An invalid limit was provided. See https://core.telegram.org/api/files#downloading-files (caused by GetFileRequest)


我们已经做过类似的东西,你可以在这里找到: https://gist.github.com/painor/7e74de80ae0c819d3e9abcf9989a8dd6 。代码:

We have already made something similar, which you can find here: https://gist.github.com/painor/7e74de80ae0c819d3e9abcf9989a8dd6 . Code:

> Based on parallel_file_transfer.py from mautrix-telegram, with permission to distribute under the MIT license
> Copyright (C) 2019 Tulir Asokan - https://github.com/tulir/mautrix-telegram
import asyncio
import hashlib
import inspect
import logging
import math
import os
from collections import defaultdict
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, BinaryIO

from telethon import utils, helpers, TelegramClient
from telethon.crypto import AuthKey
from telethon.network import MTProtoSender
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
                                          SaveBigFilePartRequest)
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
                               InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
                               InputFileBig, InputFile)

# Module logger; shares the "telethon" logger name.
log: logging.Logger = logging.getLogger("telethon")
# Every file-location type accepted by the download helpers in this module.
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
                     InputFileLocation, InputPhotoFileLocation]

def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
    """Yield successive chunks read from an open file object until it is exhausted.

    :param file_to_stream: readable file object; reading starts at its current position.
    :param chunk_size: maximum number of bytes requested per ``read`` call.
    :return: generator of non-empty chunks; the last one may be shorter than
        ``chunk_size``.
    """
    while True:
        data_read = file_to_stream.read(chunk_size)
        if not data_read:
            # An empty read means end of file; without this exit the loop never ends.
            break
        yield data_read

class DownloadSender:
    """One connection's share of a parallel download.

    Holds a single reusable ``GetFileRequest`` and re-sends it with an advancing
    offset: each call to :meth:`next` fetches one part and then moves the offset
    forward by ``stride`` bytes, until ``count`` parts have been fetched.
    """
    sender: MTProtoSender
    request: GetFileRequest
    remaining: int
    stride: int

    def __init__(self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
                 stride: int, count: int) -> None:
        """Store the connection and prepare the first request.

        :param sender: connection used to send the requests.
        :param file: location of the file to download.
        :param offset: offset of the first part this sender fetches.
        :param limit: ``limit`` value of every request.
        :param stride: amount added to the request offset after each part.
        :param count: number of parts this sender will fetch.
        """
        self.remaining = count
        self.stride = stride
        self.request = GetFileRequest(file, offset=offset, limit=limit)
        self.sender = sender

    async def next(self) -> Optional[bytes]:
        """Fetch this sender's next part, or return ``None`` when none are left."""
        if self.remaining:
            response = await self.sender.send(self.request)
            self.remaining -= 1
            # Skip over the parts handled by the other senders.
            self.request.offset += self.stride
            return response.bytes
        return None

    def disconnect(self) -> Awaitable[None]:
        """Return the awaitable that disconnects the underlying sender."""
        return self.sender.disconnect()

class UploadSender:
    """One connection's share of a parallel upload.

    Holds a single reusable save-part request. Every call to :meth:`next` sends one
    part and then advances the request's ``file_part`` by ``stride``.
    """
    sender: MTProtoSender
    request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
    part_count: int
    stride: int
    previous: Optional[asyncio.Task]
    loop: asyncio.AbstractEventLoop

    def __init__(self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int,
                 stride: int, loop: asyncio.AbstractEventLoop) -> None:
        """Store the connection and prepare the request template.

        :param sender: connection used to send the parts.
        :param file_id: identifier of the file being uploaded.
        :param part_count: total number of parts of the file.
        :param big: whether to use ``SaveBigFilePartRequest`` instead of
            ``SaveFilePartRequest``.
        :param index: index of the first part this sender uploads.
        :param stride: amount added to ``file_part`` after each part.
        :param loop: event loop used to schedule the background send task.
        """
        self.sender = sender
        self.part_count = part_count
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            # Without this branch the small-file request would always overwrite the
            # big-file one chosen above.
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Queue ``data`` as the next part, waiting first for the previous send.

        The send itself runs as a background task, so this returns as soon as the
        previous part has finished and the new one has been scheduled.
        """
        if self.previous:
            await self.previous
        self.previous = self.loop.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send ``data`` as the current part and advance to this sender's next part."""
        self.request.bytes = data
        log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
                  f" with {len(data)} bytes")
        await self.sender.send(self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the in-flight part, if any, then disconnect the sender."""
        if self.previous:
            await self.previous
        return await self.sender.disconnect()

class ParallelTransferrer:
    """Transfer one file over several MTProto connections to a single data center.

    Upload flow: ``init_upload`` -> repeated ``upload`` -> ``finish_upload``.
    Download flow: iterate the async generator returned by ``download``.
    """
    client: TelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    senders: Optional[List[Union[DownloadSender, UploadSender]]]
    auth_key: AuthKey
    upload_ticker: int

    def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
        """Bind to ``client`` and pick the target DC.

        :param client: client whose loop, session and connection settings are reused.
        :param dc_id: data center to connect to; defaults to the session's DC.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # The session's auth key is reused only for the session's own DC. For any other
        # DC the key starts as None and is obtained in _create_sender().
        self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
                         else self.client.session.auth_key)
        self.senders = None
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect every sender concurrently and forget them."""
        await asyncio.gather(*[sender.disconnect() for sender in self.senders])
        self.senders = None

    @staticmethod
    def _get_connection_count(file_size: int, max_count: int = 20,
                              full_size: int = 100 * 1024 * 1024) -> int:
        """Scale the number of connections with the file size.

        Must be a static method: callers invoke it as
        ``self._get_connection_count(file_size)``, which would otherwise bind ``self``
        to ``file_size``.

        :param file_size: size of the file in bytes.
        :param max_count: number of connections used for files above ``full_size``.
        :param full_size: size in bytes above which ``max_count`` is used.
        :return: ``max_count`` for large files, otherwise the proportional share of
            ``max_count`` rounded up.
        """
        if file_size > full_size:
            return max_count
        return math.ceil((file_size / full_size) * max_count)

    async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
                             part_size: int) -> None:
        """Create one ``DownloadSender`` per connection and split the parts among them.

        Sender ``i`` starts at part ``i`` and advances by ``connections`` parts each time.
        """
        minimum, remainder = divmod(part_count, connections)

        def get_part_count() -> int:
            # The first `remainder` senders take one extra part so that all parts
            # are covered when part_count is not a multiple of connections.
            nonlocal remainder
            if remainder > 0:
                remainder -= 1
                return minimum + 1
            return minimum

        # The first cross-DC sender will export+import the authorization, so we always create it
        # before creating any other senders.
        self.senders = [
            await self._create_download_sender(file, 0, part_size, connections * part_size,
                                               get_part_count()),
            *await asyncio.gather(
                *[self._create_download_sender(file, i, part_size, connections * part_size,
                                               get_part_count())
                  for i in range(1, connections)])
        ]

    async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
                                      stride: int,
                                      part_count: int) -> DownloadSender:
        """Open a connection and wrap it in a ``DownloadSender`` starting at part ``index``."""
        return DownloadSender(await self._create_sender(), file, index * part_size, part_size,
                              stride, part_count)

    async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
                           ) -> None:
        """Create one ``UploadSender`` per connection; sender ``i`` starts at part ``i``."""
        # As for downloads, the first sender is created alone so that it can set up
        # the authorization before the remaining ones connect concurrently.
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *[self._create_upload_sender(file_id, part_count, big, i, connections)
                  for i in range(1, connections)])
        ]

    async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
                                    stride: int) -> UploadSender:
        """Open a connection and wrap it in an ``UploadSender`` starting at part ``index``."""
        return UploadSender(await self._create_sender(), file_id, part_count, big, index, stride,
                            loop=self.loop)

    async def _create_sender(self) -> MTProtoSender:
        """Connect a new ``MTProtoSender`` to the target DC and return it.

        When no auth key is known yet, an authorization is exported through the main
        client and imported on the new connection; the resulting key is cached for the
        senders created afterwards.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, self.loop, loggers=self.client._log)
        # NOTE(review): relies on private TelegramClient attributes (_connection, _log,
        # _proxy); confirm they exist in the installed Telethon version.
        await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
                                                     loop=self.loop, loggers=self.client._log,
                                                     proxy=self.client._proxy))
        if not self.auth_key:
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            req = self.client._init_with(ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            ))
            await sender.send(req)
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
                          connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
        """Open the upload connections and compute the part layout.

        :param file_id: identifier of the file being uploaded.
        :param file_size: size of the file in bytes.
        :param part_size_kb: part size in KiB; chosen from ``file_size`` when omitted.
        :param connection_count: number of connections; derived from ``file_size``
            when omitted.
        :return: ``(part_size, part_count, is_large)`` with ``part_size`` in bytes.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        print("init_upload count is ", connection_count)
        part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Hand ``part`` to the next sender in round-robin order."""
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Wait for pending parts and close all upload connections."""
        await self._cleanup()

    async def download(self, file: TypeLocation, file_size: int,
                       part_size_kb: Optional[float] = None,
                       connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
        """Download ``file`` in parallel, yielding its parts in file order.

        :param file: location of the file to download.
        :param file_size: size of the file in bytes.
        :param part_size_kb: part size in KiB; chosen from ``file_size`` when omitted.
        :param connection_count: number of connections; derived from ``file_size``
            when omitted.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        print("download count is ", connection_count)

        part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
        part_count = math.ceil(file_size / part_size)
        log.debug("Starting parallel download: "
                  f"{connection_count} {part_size} {part_count} {file!s}")
        await self._init_download(connection_count, file, part_count, part_size)

        part = 0
        while part < part_count:
            # Start one request per sender, then consume the results in sender order.
            # Sender i holds parts i, i + connections, ..., so that order is file order.
            tasks = []
            for sender in self.senders:
                tasks.append(self.loop.create_task(sender.next()))
            for task in tasks:
                data = await task
                if not data:
                    # This sender has nothing left; stop consuming this round.
                    break
                yield data
                part += 1
                log.debug(f"Part {part} downloaded")

        log.debug("Parallel download finished, cleaning up connections")
        await self._cleanup()

# Lazily created asyncio locks, one per integer key.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)

async def _internal_transfer_to_telegram(client: TelegramClient,
                                         response: BinaryIO,
                                         progress_callback: callable
                                         ) -> Tuple[TypeInputFile, int]:
    """Upload an open file in parallel and return its input-file handle and size.

    The file is read in small chunks and regrouped into parts of exactly
    ``part_size`` bytes (only the final part may be shorter) before being handed to
    the uploader.

    :param client: client used to open the upload connections.
    :param response: open binary file; its size is taken from ``response.name`` on disk.
    :param progress_callback: optional callable invoked as ``(position, file_size)``
        after each chunk is read; may return an awaitable.
    :return: ``(InputFileBig or InputFile, file_size)``.
    """
    file_id = helpers.generate_random_long()
    file_size = os.path.getsize(response.name)

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    for data in stream_file(response):
        if progress_callback:
            r = progress_callback(response.tell(), file_size)
            if inspect.isawaitable(r):
                await r
        if not is_large:
            # Only InputFile (the small-file variant) carries an md5 checksum.
            hash_md5.update(data)
        if len(buffer) == 0 and len(data) == part_size:
            # The chunk is already exactly one part: send it without copying.
            await uploader.upload(data)
            continue
        new_len = len(buffer) + len(data)
        if new_len >= part_size:
            # Top the buffer up to one full part, send it, and keep the rest of the
            # chunk as the beginning of the next part.
            cutoff = part_size - len(buffer)
            buffer.extend(data[:cutoff])
            await uploader.upload(bytes(buffer))
            buffer.clear()
            buffer.extend(data[cutoff:])
        else:
            buffer.extend(data)
    if len(buffer) > 0:
        # Final, possibly short, part.
        await uploader.upload(bytes(buffer))
    await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, "upload"), file_size
    else:
        return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size

async def download_file(client: TelegramClient,
                        location: TypeLocation,
                        out: BinaryIO,
                        progress_callback: callable = None
                        ) -> BinaryIO:
    """Download ``location`` over parallel connections and write it to ``out``.

    :param client: client used to open the download connections.
    :param location: object to download; must expose ``size`` and be accepted by
        ``utils.get_input_location``.
    :param out: writable binary file object that receives the parts in order.
    :param progress_callback: optional callable invoked as ``(position, size)`` after
        each part is written; may return an awaitable.
    :return: ``out``, for convenience.
    """
    size = location.size
    dc_id, location = utils.get_input_location(location)
    # NOTE(review): Telegram limits the number of connections, but no lock from
    # parallel_transfer_locks is acquired here -- confirm whether that is intended.
    downloader = ParallelTransferrer(client, dc_id)
    downloaded = downloader.download(location, size)
    async for x in downloaded:
        # Persist the part before reporting progress, so out.tell() reflects it.
        out.write(x)
        if progress_callback:
            r = progress_callback(out.tell(), size)
            if inspect.isawaitable(r):
                await r

    return out

async def upload_file(client: TelegramClient,
                      file: BinaryIO,
                      progress_callback: callable = None,

                      ) -> TypeInputFile:
    """Upload ``file`` in parallel and return only the resulting input-file handle.

    Thin wrapper over ``_internal_transfer_to_telegram`` that drops the file size
    from its result.

    :param client: client used for the upload.
    :param file: open binary file to upload.
    :param progress_callback: optional progress callable, forwarded unchanged.
    """
    transfer_result = await _internal_transfer_to_telegram(client, file, progress_callback)
    return transfer_result[0]



await download_file(client, msg.document, file, progress_callback=prog)


result = await parallel_transfer_to_telegram(client, file, progress_callback=prog)
await client.send_file(event.chat_id, file=result)


如果您使用机器人帐户,DC ID 可能会搞砸,因此您需要在调用 .start() 后立即执行此操作:

If you are using a bot account, the DC ID might be messed up, so you would need to do this right after calling .start():

config = await client(functions.help.GetConfigRequest())
for option in config.dc_options:
    if option.ip_address == client.session.server_address:
        if client.session.dc_id != option.id:
            log.warning(f"Fixed DC ID in session from {client.session.dc_id} to {option.id}")
        client.session.set_dc(option.id, option.ip_address, option.port)

