张柏沛的个人IT技术博客-专注和分享PHP建站和Python技术的学习博客

正文内容

从IO模型到协程(七) asyncio协程+redis队列高并发批量下载文件

栏目:Python 系列:从IO模型到协程系列 发布时间:2021-06-10 19:14:52 浏览量:114

公司内部的一个图片网站有一个支持批量下载的小需求,由于不同的图片存放在不同的oss远程服务器上,前端直接请求图片链接可能引发跨域问题,因此需要前端先请求后端接口,由后端下载资源后再输出给前端。为了图省事,后端(用php搭建的)的下载接口一次只能下载一个文件,因此当运营需要批量下载多个图片时前端需要调用多次后端接口(而且不是同时调用而是串行调用),对于文件下载量少,单个文件大小较小的批量下载而言勉强能够使用。直到有一次,有个运营批量下载100张图片,每张图片平均10M的大小,最后花了1个小时还没有下载完。

经反馈,需要进行以下优化:前端只需要发起1次请求(下载任务)到PHP后端,传递一个json参数包含要下载的所有文件的文件名和链接。后端接收到请求后将下载任务信息存放到redis的总任务队列中。于此同时,开启一个常驻的python进程进行从队列中接收下载任务且并发执行下载(从oss服务器下载到后端所在的服务器)。每一个任务对应一个任务task_id(由php端生成),当该批量任务完成时会发送一个消息到以task_id为key的消息队列以通知php端下载完成,然后由php端对后端中下载好的文件打包再输出给前端。目测下载100个平均大小为10M的文件从1小时缩减到三到四分钟。

代码如下:

 

# D:\wamp64\www\SVR_ZK_Center\py_script\main\batch_download.py

# coding=utf-8

import main
import asyncio, sys
from pkg.batch_download.Task import Task
from pkg.batch_download.config import CFG

# 允许的批量下载的并发用户数
default_batch_downloader = CFG.get("default_batch_downloader", 5);
async def main():
    max_batch_downloader = default_batch_downloader if len(sys.argv) < 3 else sys.argv[2]
    task = Task(max_batch_downloader)
    cor_get_task = asyncio.ensure_future(task.get_task())      # 协程1:从redis队列中获取下载任务并放入到本地下载队列
    cor_batch_downloads = [asyncio.ensure_future(task.batch_download()) for i in range(max_batch_downloader)]  # 并发多个协程:从本地队列中获取下载任务并下载
    await asyncio.gather(cor_get_task, *cor_batch_downloads)        # 这里会一直阻塞直到有下载任务过来,任务完成后继续阻塞

loop = asyncio.get_event_loop()     # 创建事件循环
loop.run_until_complete(main())

 

最主要是这个文件,并发下载的逻辑都在里面

# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\batch_download\Task.py

# coding=utf-8

from pkg.common.config.redisCfg import redisCfg
import sys, aioredis, asyncio, json, aiohttp, os, functools
from asyncio import Queue, Semaphore
from pkg.common.utils.logs import Logs
from pkg.common.utils.file_tool import FileTool
from .config import CFG
from pkg.common.utils.common import add_callback_async

class Task:
    redis_queue_key = "batch_download_tasks"
    redis_resp_key_prefix = "batch_download_resp"
    redis = None            # 用于接收批量下载任务请求的redis连接
    redis_notify = None       # 用于通知客户端下载结果的redis连接,之所以要设定redis和redis_notify是因为redis基本上一直在brpop阻塞,如果不生成新的redis连接而直接用self.redis执行其他命令(如get,lpush)会一直阻塞直到brpop命令有响应
    tcpConnNum = 20         # 最大20个并发连接

    def __init__(self, max_queue_size):
        self.loop = asyncio.get_event_loop()
        self.cfg, self.redisCfg = self.__build_cfg()
        self.tasks_queue = Queue(max_queue_size)     # 下载任务池,结构为 {"task_id":"唯一的任务id", "task_info":[{"fn":"文件名","url":"文件url"}, {...}, ...]}
        self.log_tool = Logs(self.cfg["log_dir"])
        self.file_tool = FileTool()
        self.sem = Semaphore(100)   # 允许的最大并发下载个数
        self.session = None     # 连接标识
        self.file_dir = ""

    @classmethod
    def __build_cfg(cls):     # 获取redis配置和下载文件目录的相关配置
        env = "local" if len(sys.argv) < 2 else sys.argv[1]
        redis_cfg = redisCfg.get(env, redisCfg.get("local"))
        cfg = CFG.get(env, CFG.get("local"))

        format_redis_cfg = {
            "address":(redis_cfg["host"], redis_cfg["port"]),
            "db": redis_cfg["db"],
            "password": redis_cfg["password"],
            "encoding": "utf-8"
        }

        return cfg, format_redis_cfg

    @classmethod
    def __get_task_resp_key(cls, task_id):
        return cls.redis_resp_key_prefix + "_" + task_id

    async def notify(self, task_id, msg):    # 单个文件下载完成后的回调函数,用于通知用户单个文件下载完毕
        key = self.__get_task_resp_key(task_id)
        res = await self.redis_notify.execute("lpush", key, msg)  # 通知用户 fn 已经下载完毕

    # 从redis队列中获取批量下载任务
    async def get_task(self):
        try:
            self.redis = await aioredis.create_connection(**self.redisCfg)  # 连接redis(用于接收批量下载任务)
            self.redis_notify = await aioredis.create_connection(**self.redisCfg)  # 连接redis(用于响应客户端,即php端)
            while True:
                # 阻塞,直到有下载任务过来
                _, task_str = await self.redis.execute("brpop", self.redis_queue_key, 0)

                task = json.loads(task_str) if isinstance(task_str, str) else task_str
                task_id = task.get("task_id", "")
                if task_id == "":
                    await self.log_tool.aio_error("empty task_id")
                    continue

                await self.tasks_queue.put(task)    # 将任务塞到本地任务队列(任务池)中
                await self.redis.execute("lpush", self.__get_task_resp_key(task_id), "start")   # 通知请求发起者下载任务开始
        except BaseException as e:
            # 记录错误日志(异步)
            await self.log_tool.aio_error("method get_task | 错误信息:%s" % repr(e))
            self.redis.close()
            self.redis_notify.close()
            await self.redis.wait_closed()

    # 从任务池中取出任务并执行批量下载
    async def batch_download(self):
        while True:
            try:
                batch_task = await self.tasks_queue.get()
                task_id, task_infos = batch_task["task_id"], batch_task["task_info"]

                # 创建目录
                self.file_dir = self.cfg["file_dir"].rstrip("/") + "/" + task_id + "/"
                self.file_tool.touchDir(self.file_dir)

                # 创建连接池
                connector = aiohttp.TCPConnector(ssl=False, limit=self.tcpConnNum)
                async with aiohttp.ClientSession(connector=connector) as self.session:
                    # 封装单个下载任务
                    futures = []
                    for t in task_infos:
                        cor = self.download(t["url"], t["fn"])
                        cb_args = {"task_id" : task_id, "msg" : t["fn"]}

                        # 将notify方法作为单个下载完毕后的回调函数
                        future = asyncio.ensure_future(add_callback_async(cor, self.notify, **cb_args))
                        futures.append(future)

                    await asyncio.gather(*futures)      # 并发的下载,并等待里面所有任务的完成

                    # 发送通知给请求方,压缩的任务交给请求方,因为如果在这里进行压缩会阻塞所有下载任务
                    await self.notify(task_id, "end")
            except BaseException as e:
                await self.log_tool.aio_error("method batch_download | 任务id:%s | 错误信息:%s" % (task_id, repr(e)))
                break

    # 执行单个下载
    async def download(self, url, fn):
        try:
            sem_acquire = False
            # 判断文件名是否有后缀
            if fn.find(".") == -1:
                url_base_name = os.path.basename(url)
                if url_base_name.find(".") != -1:   # url有文件后缀名
                    ext = url_base_name.split(".")[-1]
                    fn = fn + "." + ext
            fp = self.file_dir.rstrip("/") + "/" + fn

            sem_acquire = await self.sem.acquire()        # 限制单个下载任务的并发数
            await self.log_tool.aio_info("文件:%s 开始下载(%s)" % (fn, url))
            async with self.session.get(url) as resp:
                if resp.status == 200:
                    content = await resp.read()
                    await self.file_tool.writeBlobAsync(fp, content)   # 将下载的文件异步写入磁盘
                else:
                    await self.log_tool.aio_error("链接 %s 下载失败" % url)
            self.sem.release()
        except BaseException as e:
            if sem_acquire:
                self.sem.release()
            await self.log_tool.aio_error("method download | url:%s, 文件名:%s | 错误信息:%s" % (url, fn, repr(e)))

 

下面这些文件都是些通用功能,可以略过:

# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\common.py

# coding=utf-8

# 通用函数

# 该函数用于为future对象添加异步的回调,该函数是为了弥补 future.add_done_callback() 方法无法注册异步回调而只能注册普通回调的缺陷
async def add_callback_async(fut, cb, **kwargs):    # cb是一个异步的回调函数
    result = await fut  # 执行future对象中的任务
    await cb(**kwargs)
    return result       # 返回future执行的结果

 

# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\file_tool.py

# coding=utf-8

import aiofiles, os

# 文件操作类
class FileTool:
    @classmethod
    def writeBlob(cls, fp, cont):   # 写入二进制文件
        with open(fp, mode="wb") as f:
            f.write(cont)

    @classmethod
    async def writeBlobAsync(cls, fp, cont):   # 异步写入二进制文件
        async with aiofiles.open(fp, mode="wb") as f:
            await f.write(cont)

    @classmethod
    def touchDir(cls, dp):  # 如果不存在目录则创建
        if not os.path.exists(dp):
            os.makedirs(dp)

 

# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\logs.py

# coding=utf-8

import time, aiofiles, os

# 日志记录
class Logs:
    # LEVEL_INFO = 1      # 错误级别 信息级
    # LEVEL_ERROR = 2     # 错误级别 错误级
    LEVEL_INFO = "info"    # 错误级别为信息级时的目录文件名
    LEVEL_ERROR = "err"  # 错误级别为信息级时的目录文件名

    def __init__(self, log_dir):
        self.log_dir = self.__get_log_dir(log_dir)
        self.log_file_name = time.strftime("%Y-%m-%d", time.localtime()) + ".log"
        self.log_path = self.log_dir + self.log_file_name

    @classmethod
    def __format_log(cls, cont, level):
        return time.strftime("%Y-%m-%d", time.localtime()) + " | " + level + " | " + cont + "\n"

    @classmethod
    def __get_log_dir(cls, log_dir):
        log_dir = log_dir.rstrip("/") + "/"
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        return log_dir

    def info(self, cont):
        line = self.__format_log(cont, self.LEVEL_INFO)
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(line)

    def error(self, cont):
        line = self.__format_log(cont, self.LEVEL_ERROR)
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(line)

    async def aio_info(self, cont):
        line = self.__format_log(cont, self.LEVEL_INFO)
        async with aiofiles.open(self.log_path, mode="a") as f:
            await f.write(line)

    async def aio_error(self, cont):
        line = self.__format_log(cont, self.LEVEL_INFO)
        async with aiofiles.open(self.log_path, mode="a") as f:
            await f.write(line)

 

如果您需要转载,可以点击下方按钮可以进行复制粘贴;本站博客文章为原创,请转载时注明以下信息

张柏沛IT技术博客 > 从IO模型到协程(七) asyncio协程+redis队列高并发批量下载文件

热门推荐