公司内部的一个图片网站有一个支持批量下载的小需求,由于不同的图片存放在不同的oss远程服务器上,前端直接请求图片链接可能引发跨域问题,因此需要前端先请求后端接口,由后端下载资源后再输出给前端。为了图省事,后端(用php搭建的)的下载接口一次只能下载一个文件,因此当运营需要批量下载多个图片时前端需要调用多次后端接口(而且不是同时调用而是串行调用),对于文件下载量少,单个文件大小较小的批量下载而言勉强能够使用。直到有一次,有个运营批量下载100张图片,每张图片平均10M的大小,最后花了1个小时还没有下载完。
经反馈,需要进行以下优化:前端只需要发起1次请求(下载任务)到PHP后端,传递一个json参数包含要下载的所有文件的文件名和链接。后端接收到请求后将下载任务信息存放到redis的总任务队列中。于此同时,开启一个常驻的python进程进行从队列中接收下载任务且并发执行下载(从oss服务器下载到后端所在的服务器)。每一个任务对应一个任务task_id(由php端生成),当该批量任务完成时会发送一个消息到以task_id为key的消息队列以通知php端下载完成,然后由php端对后端中下载好的文件打包再输出给前端。目测下载100个平均大小为10M的文件从1小时缩减到三到四分钟。
代码如下:
# D:\wamp64\www\SVR_ZK_Center\py_script\main\batch_download.py
# coding=utf-8
import main
import asyncio, sys
from pkg.batch_download.Task import Task
from pkg.batch_download.config import CFG
# 允许的批量下载的并发用户数
default_batch_downloader = CFG.get("default_batch_downloader", 5);
async def main():
max_batch_downloader = default_batch_downloader if len(sys.argv) < 3 else sys.argv[2]
task = Task(max_batch_downloader)
cor_get_task = asyncio.ensure_future(task.get_task()) # 协程1:从redis队列中获取下载任务并放入到本地下载队列
cor_batch_downloads = [asyncio.ensure_future(task.batch_download()) for i in range(max_batch_downloader)] # 并发多个协程:从本地队列中获取下载任务并下载
await asyncio.gather(cor_get_task, *cor_batch_downloads) # 这里会一直阻塞直到有下载任务过来,任务完成后继续阻塞
loop = asyncio.get_event_loop() # 创建事件循环
loop.run_until_complete(main())
最主要是这个文件,并发下载的逻辑都在里面
# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\batch_download\Task.py
# coding=utf-8
from pkg.common.config.redisCfg import redisCfg
import sys, aioredis, asyncio, json, aiohttp, os, functools
from asyncio import Queue, Semaphore
from pkg.common.utils.logs import Logs
from pkg.common.utils.file_tool import FileTool
from .config import CFG
from pkg.common.utils.common import add_callback_async
class Task:
redis_queue_key = "batch_download_tasks"
redis_resp_key_prefix = "batch_download_resp"
redis = None # 用于接收批量下载任务请求的redis连接
redis_notify = None # 用于通知客户端下载结果的redis连接,之所以要设定redis和redis_notify是因为redis基本上一直在brpop阻塞,如果不生成新的redis连接而直接用self.redis执行其他命令(如get,lpush)会一直阻塞直到brpop命令有响应
tcpConnNum = 20 # 最大20个并发连接
def __init__(self, max_queue_size):
self.loop = asyncio.get_event_loop()
self.cfg, self.redisCfg = self.__build_cfg()
self.tasks_queue = Queue(max_queue_size) # 下载任务池,结构为 {"task_id":"唯一的任务id", "task_info":[{"fn":"文件名","url":"文件url"}, {...}, ...]}
self.log_tool = Logs(self.cfg["log_dir"])
self.file_tool = FileTool()
self.sem = Semaphore(100) # 允许的最大并发下载个数
self.session = None # 连接标识
self.file_dir = ""
@classmethod
def __build_cfg(cls): # 获取redis配置和下载文件目录的相关配置
env = "local" if len(sys.argv) < 2 else sys.argv[1]
redis_cfg = redisCfg.get(env, redisCfg.get("local"))
cfg = CFG.get(env, CFG.get("local"))
format_redis_cfg = {
"address":(redis_cfg["host"], redis_cfg["port"]),
"db": redis_cfg["db"],
"password": redis_cfg["password"],
"encoding": "utf-8"
}
return cfg, format_redis_cfg
@classmethod
def __get_task_resp_key(cls, task_id):
return cls.redis_resp_key_prefix + "_" + task_id
async def notify(self, task_id, msg): # 单个文件下载完成后的回调函数,用于通知用户单个文件下载完毕
key = self.__get_task_resp_key(task_id)
res = await self.redis_notify.execute("lpush", key, msg) # 通知用户 fn 已经下载完毕
# 从redis队列中获取批量下载任务
async def get_task(self):
try:
self.redis = await aioredis.create_connection(**self.redisCfg) # 连接redis(用于接收批量下载任务)
self.redis_notify = await aioredis.create_connection(**self.redisCfg) # 连接redis(用于响应客户端,即php端)
while True:
# 阻塞,直到有下载任务过来
_, task_str = await self.redis.execute("brpop", self.redis_queue_key, 0)
task = json.loads(task_str) if isinstance(task_str, str) else task_str
task_id = task.get("task_id", "")
if task_id == "":
await self.log_tool.aio_error("empty task_id")
continue
await self.tasks_queue.put(task) # 将任务塞到本地任务队列(任务池)中
await self.redis.execute("lpush", self.__get_task_resp_key(task_id), "start") # 通知请求发起者下载任务开始
except BaseException as e:
# 记录错误日志(异步)
await self.log_tool.aio_error("method get_task | 错误信息:%s" % repr(e))
self.redis.close()
self.redis_notify.close()
await self.redis.wait_closed()
# 从任务池中取出任务并执行批量下载
async def batch_download(self):
while True:
try:
batch_task = await self.tasks_queue.get()
task_id, task_infos = batch_task["task_id"], batch_task["task_info"]
# 创建目录
self.file_dir = self.cfg["file_dir"].rstrip("/") + "/" + task_id + "/"
self.file_tool.touchDir(self.file_dir)
# 创建连接池
connector = aiohttp.TCPConnector(ssl=False, limit=self.tcpConnNum)
async with aiohttp.ClientSession(connector=connector) as self.session:
# 封装单个下载任务
futures = []
for t in task_infos:
cor = self.download(t["url"], t["fn"])
cb_args = {"task_id" : task_id, "msg" : t["fn"]}
# 将notify方法作为单个下载完毕后的回调函数
future = asyncio.ensure_future(add_callback_async(cor, self.notify, **cb_args))
futures.append(future)
await asyncio.gather(*futures) # 并发的下载,并等待里面所有任务的完成
# 发送通知给请求方,压缩的任务交给请求方,因为如果在这里进行压缩会阻塞所有下载任务
await self.notify(task_id, "end")
except BaseException as e:
await self.log_tool.aio_error("method batch_download | 任务id:%s | 错误信息:%s" % (task_id, repr(e)))
break
# 执行单个下载
async def download(self, url, fn):
try:
sem_acquire = False
# 判断文件名是否有后缀
if fn.find(".") == -1:
url_base_name = os.path.basename(url)
if url_base_name.find(".") != -1: # url有文件后缀名
ext = url_base_name.split(".")[-1]
fn = fn + "." + ext
fp = self.file_dir.rstrip("/") + "/" + fn
sem_acquire = await self.sem.acquire() # 限制单个下载任务的并发数
await self.log_tool.aio_info("文件:%s 开始下载(%s)" % (fn, url))
async with self.session.get(url) as resp:
if resp.status == 200:
content = await resp.read()
await self.file_tool.writeBlobAsync(fp, content) # 将下载的文件异步写入磁盘
else:
await self.log_tool.aio_error("链接 %s 下载失败" % url)
self.sem.release()
except BaseException as e:
if sem_acquire:
self.sem.release()
await self.log_tool.aio_error("method download | url:%s, 文件名:%s | 错误信息:%s" % (url, fn, repr(e)))
下面这些文件都是些通用功能,可以略过:
# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\common.py
# coding=utf-8
# 通用函数
# 该函数用于为future对象添加异步的回调,该函数是为了弥补 future.add_done_callback() 方法无法注册异步回调而只能注册普通回调的缺陷
async def add_callback_async(fut, cb, **kwargs): # cb是一个异步的回调函数
result = await fut # 执行future对象中的任务
await cb(**kwargs)
return result # 返回future执行的结果
# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\file_tool.py
# coding=utf-8
import aiofiles, os
# 文件操作类
class FileTool:
@classmethod
def writeBlob(cls, fp, cont): # 写入二进制文件
with open(fp, mode="wb") as f:
f.write(cont)
@classmethod
async def writeBlobAsync(cls, fp, cont): # 异步写入二进制文件
async with aiofiles.open(fp, mode="wb") as f:
await f.write(cont)
@classmethod
def touchDir(cls, dp): # 如果不存在目录则创建
if not os.path.exists(dp):
os.makedirs(dp)
# D:\wamp64\www\SVR_ZK_Center\py_script\pkg\common\utils\logs.py
# coding=utf-8
import time, aiofiles, os
# 日志记录
class Logs:
# LEVEL_INFO = 1 # 错误级别 信息级
# LEVEL_ERROR = 2 # 错误级别 错误级
LEVEL_INFO = "info" # 错误级别为信息级时的目录文件名
LEVEL_ERROR = "err" # 错误级别为信息级时的目录文件名
def __init__(self, log_dir):
self.log_dir = self.__get_log_dir(log_dir)
self.log_file_name = time.strftime("%Y-%m-%d", time.localtime()) + ".log"
self.log_path = self.log_dir + self.log_file_name
@classmethod
def __format_log(cls, cont, level):
return time.strftime("%Y-%m-%d", time.localtime()) + " | " + level + " | " + cont + "\n"
@classmethod
def __get_log_dir(cls, log_dir):
log_dir = log_dir.rstrip("/") + "/"
if not os.path.exists(log_dir):
os.makedirs(log_dir)
return log_dir
def info(self, cont):
line = self.__format_log(cont, self.LEVEL_INFO)
with open(self.log_path, "a", encoding="utf-8") as f:
f.write(line)
def error(self, cont):
line = self.__format_log(cont, self.LEVEL_ERROR)
with open(self.log_path, "a", encoding="utf-8") as f:
f.write(line)
async def aio_info(self, cont):
line = self.__format_log(cont, self.LEVEL_INFO)
async with aiofiles.open(self.log_path, mode="a") as f:
await f.write(line)
async def aio_error(self, cont):
line = self.__format_log(cont, self.LEVEL_INFO)
async with aiofiles.open(self.log_path, mode="a") as f:
await f.write(line)