""" 异步数据汇聚与并行计算框架 - 精简优化版 """
import asyncio
import time
import logging
import json
import os
import random
import sys
from typing import Dict, Any, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime
from multiprocessing import Process, Queue as MPQueue
import multiprocessing
@dataclass
class Config:
    """Runtime settings shared by the generator, aggregator and workers.

    ``SITE_IDS`` is derived from ``NUM_SITES`` after construction
    (``site_001`` .. ``site_NNN``); it is a plain attribute, not a dataclass
    field, so it does not appear in ``__init__``, ``repr`` or ``asdict``.
    """

    NUM_SITES: int = 5  # number of simulated sites
    NUM_PREDICT_WORKERS: int = 2  # number of prediction worker processes to start
    AGGREGATE_TIMEOUT: int = 30  # seconds before a half-filled pending entry is dropped
    RESULT_DIR: str = "results"  # directory holding one <site_id>.jsonl per site
    DATA_GEN_INTERVAL: int = 10  # seconds to sleep between generated batches

    def __post_init__(self):
        numbers = range(1, self.NUM_SITES + 1)
        self.SITE_IDS = list(map("site_{:03d}".format, numbers))


# Default configuration used by main().
config = Config()
def setup_main_logger():
    """Return the "Main" logger, configuring it only on the first call.

    The first call attaches a stdout handler and a UTF-8 ``main.log`` file
    handler and disables propagation to ancestor loggers. Later calls find
    handlers already present and return the same logger untouched.
    """
    main_log = logging.getLogger("Main")
    main_log.setLevel(logging.INFO)
    if main_log.handlers:
        return main_log

    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))

    logfile = logging.FileHandler("main.log", encoding='utf-8')
    logfile.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))

    for h in (console, logfile):
        main_log.addHandler(h)
    main_log.propagate = False
    return main_log
def setup_worker_logger(worker_id: int):
    """Return the ``Worker-<worker_id>`` logger, writing only to ``worker_<worker_id>.log``.

    A single UTF-8 file handler is attached on first use; every line is
    prefixed with ``[W<worker_id>]``. Propagation is disabled so records do
    not also reach ancestor loggers.
    """
    wlog = logging.getLogger(f"Worker-{worker_id}")
    wlog.setLevel(logging.INFO)
    if wlog.handlers:
        return wlog

    file_handler = logging.FileHandler(f"worker_{worker_id}.log", encoding='utf-8')
    file_handler.setFormatter(
        logging.Formatter(f"[W{worker_id}] %(asctime)s [%(levelname)s] %(message)s")
    )
    wlog.addHandler(file_handler)
    wlog.propagate = False
    return wlog
# Logger used by the main-process components in this file (data generator,
# aggregator, pipeline, main()).
# NOTE(review): this runs at import time, so any child process that re-imports
# this module (e.g. under the "spawn" start method) also opens main.log —
# confirm that is intended.
logger = setup_main_logger()
@dataclass
class WeatherData:
    """One simulated weather sample for a site."""

    site_id: str
    temperature: float  # simulated, in [20.0, 30.0], rounded to one decimal
    irradiance: int  # simulated, in [800, 1000]
    timestamp: float  # time.time() at generation

    @classmethod
    def generate(cls, site_id: str, rng: "random.Random | None" = None) -> 'WeatherData':
        """Create a random weather sample for *site_id*.

        Values are drawn from a private ``random.Random`` rather than by
        reseeding the global ``random`` module. Reseeding per call made every
        batch for a site identical, and it clobbered process-wide RNG state.

        Args:
            site_id: identifier copied into the result.
            rng: optional generator to draw from (e.g. a seeded one for
                reproducible tests). A fresh, unseeded one is used when omitted.
        """
        if rng is None:
            rng = random.Random()
        temperature = round(20 + rng.uniform(0, 10), 1)
        irradiance = 800 + rng.randint(0, 200)
        return cls(site_id, temperature, irradiance, time.time())
@dataclass
class LoadData:
    """One simulated load sample for a site."""

    site_id: str
    power_kw: float  # simulated, in [100.0, 150.0], rounded to one decimal
    timestamp: float  # time.time() at generation

    @classmethod
    def generate(cls, site_id: str, rng: "random.Random | None" = None) -> 'LoadData':
        """Create a random load sample for *site_id*.

        Draws from a private ``random.Random`` instead of reseeding the global
        ``random`` module, so consecutive batches differ and no process-wide
        RNG state is modified.

        Args:
            site_id: identifier copied into the result.
            rng: optional generator to draw from (e.g. seeded, for tests).
                A fresh, unseeded one is used when omitted.
        """
        if rng is None:
            rng = random.Random()
        return cls(site_id, round(100 + rng.uniform(0, 50), 1), time.time())
class AsyncDataGenerator:
    """Produces simulated weather and load samples for every configured site."""

    def __init__(self, config: Config):
        self.config = config

    async def generate_weather(self, site_id: str) -> WeatherData:
        """Simulate a 1 s I/O wait, then return a weather sample for *site_id*."""
        await asyncio.sleep(1.0)
        return WeatherData.generate(site_id)

    async def generate_load(self, site_id: str) -> LoadData:
        """Simulate a 0.1 s I/O wait, then return a load sample for *site_id*."""
        await asyncio.sleep(0.1)
        return LoadData.generate(site_id)

    async def generate_batch(self) -> Tuple[int, Dict[str, WeatherData], Dict[str, LoadData]]:
        """Generate one batch of weather/load data for all sites concurrently.

        All weather and load calls run under a single outer ``asyncio.gather``,
        so the batch takes about as long as the slowest call rather than the
        sum of the two groups.

        Returns:
            ``(batch_ts, weather_by_site, load_by_site)``. Sites whose call
            failed (including a cancelled call) are logged and omitted.
        """
        batch_ts = int(time.time())
        logger.info(f"开始生成批次 {batch_ts}...")
        site_ids = list(self.config.SITE_IDS)
        try:
            weather_calls = [self.generate_weather(sid) for sid in site_ids]
            load_calls = [self.generate_load(sid) for sid in site_ids]
            weather_results, load_results = await asyncio.gather(
                asyncio.gather(*weather_calls, return_exceptions=True),
                asyncio.gather(*load_calls, return_exceptions=True),
            )
            weather_dict = self._collect(site_ids, weather_results, "weather")
            load_dict = self._collect(site_ids, load_results, "load")
            logger.info(f"批次 {batch_ts} 生成完成")
            return batch_ts, weather_dict, load_dict
        except Exception as e:
            logger.error(f"生成批次失败:{e}")
            return batch_ts, {}, {}

    @staticmethod
    def _collect(site_ids, results, kind):
        """Map site_id -> result, logging and skipping sites whose call failed."""
        collected = {}
        for sid, res in zip(site_ids, results):
            # BaseException, not Exception: a cancelled child comes back as
            # CancelledError, which is not an Exception subclass on 3.8+.
            if isinstance(res, BaseException):
                logger.warning(f"站点 {sid} 的 {kind} 数据生成失败:{res!r}")
                continue
            collected[sid] = res
        return collected
class AsyncAggregator:
    """Pairs weather and load samples per (batch_ts, site_id) and forwards complete pairs.

    Items arrive on two ``asyncio.Queue`` objects as ``(batch_ts, site_id, data)``
    tuples. Once both halves for a key are present, a plain-dict task is put on
    the multiprocessing ``predict_queue`` for the prediction workers.
    Half-filled entries older than ``config.AGGREGATE_TIMEOUT`` seconds are
    dropped by a background cleanup loop.
    """

    def __init__(self, config: Config, predict_queue: MPQueue):
        self.config = config
        self.predict_queue = predict_queue
        # (batch_ts, site_id) -> {'_ts': first-arrival time, 'weather': ..., 'load': ...}
        self.pending = {}
        self.running = True
        self._cleanup_task = None

    async def run(self, weather_queue, load_queue):
        """Poll both queues until stop() is called (or this coroutine is cancelled)."""
        logger.info("汇聚器启动")
        # Keep a strong reference: the event loop holds only weak refs to tasks.
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        try:
            while self.running:
                got_weather = await self._process_queue_once(weather_queue, 'weather')
                got_load = await self._process_queue_once(load_queue, 'load')
                # Yield without delay while there is backlog; back off briefly when idle.
                await asyncio.sleep(0 if (got_weather or got_load) else 0.01)
        finally:
            self._cleanup_task.cancel()
            logger.info("汇聚器停止")

    async def _process_queue_once(self, queue, data_type: str) -> bool:
        """Take at most one item from *queue* (an asyncio.Queue) and merge it.

        Returns True if an item was consumed, False if the queue was empty.
        Errors while merging/forwarding are logged, not raised.
        """
        try:
            # asyncio.Queue is not thread-safe, so read it on the loop thread.
            # Reading via run_in_executor + wait_for could pop an item after
            # the timeout fired, losing it.
            item = queue.get_nowait()
        except asyncio.QueueEmpty:
            return False
        try:
            await self._merge(item, data_type)
        except Exception:
            logger.exception(f"处理 {data_type} 数据失败:{item!r}")
        return True

    async def _merge(self, item, data_type: str):
        """Store *item* under its key; forward the pair once both halves exist."""
        batch_ts, site_id, data = item
        key = (batch_ts, site_id)
        entry = self.pending.setdefault(key, {'_ts': time.time(), 'weather': None, 'load': None})
        entry[data_type] = data
        if entry['weather'] is None or entry['load'] is None:
            return
        task = {
            "batch_ts": batch_ts,
            "site_id": site_id,
            "weather": asdict(entry['weather']),
            "load": asdict(entry['load']),
        }
        # Remove before awaiting so the cleanup loop cannot delete it mid-put.
        del self.pending[key]
        loop = asyncio.get_running_loop()
        # multiprocessing.Queue.put may block when the queue is full; keep it off the loop.
        await loop.run_in_executor(None, self.predict_queue.put, task)
        logger.info(f"聚合完成:{batch_ts}/{site_id}")

    async def _cleanup_loop(self):
        """Every 10 s, drop pending entries older than AGGREGATE_TIMEOUT seconds."""
        while self.running:
            await asyncio.sleep(10)
            now = time.time()
            stale = [k for k, v in self.pending.items() if now - v['_ts'] > self.config.AGGREGATE_TIMEOUT]
            for k in stale:
                del self.pending[k]
            if stale:
                logger.warning(f"清理 {len(stale)} 个超时项")

    def stop(self):
        """Ask run() and the cleanup loop to exit at their next check."""
        self.running = False
def prediction_worker(worker_id: int, config: Config, predict_queue: MPQueue):
    """Worker-process entry point: consume tasks until a ``None`` sentinel arrives.

    Each task is a dict with ``batch_ts``, ``site_id``, ``weather`` and
    ``load`` keys. It is "computed" by sleeping 3 s, after which one JSON
    line tagged with *worker_id* is appended to
    ``<config.RESULT_DIR>/<site_id>.jsonl``. A failing task is logged and
    skipped; it does not stop the worker.
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    # Reseed per worker so workers do not share one RNG sequence (e.g. when forked).
    random.seed(time.time() + worker_id)
    log = setup_worker_logger(worker_id)
    log.info(f"计算进程{worker_id}启动")
    count = 0
    try:
        # Ensure the output directory exists before the first append.
        os.makedirs(config.RESULT_DIR, exist_ok=True)
        while True:
            try:
                task = predict_queue.get(timeout=1)
            except Empty:
                # Idle; poll again. Empty carries no message, so catch it by type.
                continue
            if task is None:  # shutdown sentinel
                break
            try:
                time.sleep(3)  # simulated compute time
                result = {
                    "batch_ts": task["batch_ts"],
                    "site_id": task["site_id"],
                    "forecast_power_kw": round(task["load"]["power_kw"] * (1 + random.uniform(-0.1, 0.1)), 2),
                    "weather_temp": task["weather"]["temperature"],
                    "compute_seconds": 3.0,
                    "worker_id": worker_id,
                    "created_at": datetime.now().isoformat()
                }
                out_path = os.path.join(config.RESULT_DIR, f"{task['site_id']}.jsonl")
                with open(out_path, "a", encoding='utf-8') as f:
                    f.write(json.dumps(result, ensure_ascii=False) + "\n")
                log.info(f"完成 {task['site_id']}")
                count += 1
            except Exception as e:
                log.exception(f"异常:{e}")
    except KeyboardInterrupt:
        # A terminal Ctrl+C typically reaches worker processes too; exit quietly
        # and let the parent's cleanup handle shutdown.
        log.info("收到中断信号")
    finally:
        log.info(f"退出,共处理 {count} 项")
class DataPipeline:
    """Wires generator -> aggregator -> prediction worker processes and owns their lifecycle."""

    def __init__(self, config: Config):
        self.config = config
        self.workers = []
        # Created before asyncio.run() starts a loop; fine on Python 3.10+,
        # where asyncio.Queue binds to the running loop lazily.
        self.weather_queue = asyncio.Queue(maxsize=1000)
        self.load_queue = asyncio.Queue(maxsize=1000)
        # Crosses the process boundary to the prediction workers.
        self.predict_queue = MPQueue(maxsize=1000)
        self.aggregator = None
        self.running = True

    def start_workers(self):
        """Create RESULT_DIR, then start NUM_PREDICT_WORKERS daemon worker processes."""
        os.makedirs(self.config.RESULT_DIR, exist_ok=True)
        for i in range(self.config.NUM_PREDICT_WORKERS):
            p = Process(target=prediction_worker, args=(i + 1, self.config, self.predict_queue), daemon=True)
            p.start()
            self.workers.append(p)
            logger.info(f"启动 Worker-{i+1}, PID={p.pid}")

    async def data_generator(self):
        """Every DATA_GEN_INTERVAL seconds, generate a batch and enqueue it per site."""
        gen = AsyncDataGenerator(self.config)
        while self.running:
            batch_ts, weathers, loads = await gen.generate_batch()
            for sid in self.config.SITE_IDS:
                if sid in weathers:
                    await self.weather_queue.put((batch_ts, sid, weathers[sid]))
                if sid in loads:
                    await self.load_queue.put((batch_ts, sid, loads[sid]))
            await asyncio.sleep(self.config.DATA_GEN_INTERVAL)

    async def run(self):
        """Start workers, then run the generator and aggregator until cancelled or one fails.

        CancelledError is deliberately not swallowed: on Python 3.11+,
        asyncio.run() needs it to propagate in order to turn Ctrl+C into
        KeyboardInterrupt for the caller.
        """
        self.start_workers()
        self.aggregator = AsyncAggregator(self.config, self.predict_queue)
        tasks = [
            asyncio.create_task(self.data_generator()),
            asyncio.create_task(self.aggregator.run(self.weather_queue, self.load_queue)),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            self.running = False
            if self.aggregator:
                self.aggregator.stop()
            # If one task raised, gather() re-raised but left the sibling running.
            for t in tasks:
                t.cancel()

    def cleanup(self):
        """Stop workers: send one None sentinel each, join, and terminate stragglers."""
        from queue import Full  # raised by multiprocessing.Queue.put_nowait

        logger.info("清理资源...")
        for _ in self.workers:
            try:
                self.predict_queue.put_nowait(None)
            except Full:
                # Backlog filled the queue; join timeout + terminate below still stop the worker.
                logger.warning("预测队列已满,无法发送停止信号")
        for p in self.workers:
            p.join(timeout=3)
            if p.is_alive():
                p.terminate()
                p.join(2)
        # No reader is left; don't let interpreter exit block flushing
        # undelivered items into the pipe.
        self.predict_queue.close()
        self.predict_queue.cancel_join_thread()
        logger.info("清理完成")
def main():
    """Build the pipeline from the module-level ``config``, run it, and always clean up.

    A KeyboardInterrupt escaping the event loop is logged instead of re-raised;
    worker cleanup runs on every exit path.
    """
    runner = DataPipeline(config)
    try:
        coro = runner.run()
        asyncio.run(coro)
    except KeyboardInterrupt:
        logger.info("收到中断信号")
    finally:
        runner.cleanup()
if __name__ == "__main__":
    # freeze_support() must be the first call after the __main__ guard so a
    # frozen Windows executable can bootstrap its child processes; here it is
    # only invoked on Windows.
    if sys.platform == "win32":
        multiprocessing.freeze_support()
    main()