-- Real-time per-device energy readings, declared in a time-series style:
-- identifying columns are marked TAG, measured values are marked FIELD, and
-- collect_time carries the sample timestamp.
--
-- Columns (per the COMMENT strings below):
--   device_id / user_id : masked identifiers, e.g. D2024****156 / U2024****156
--   device_type         : air conditioner / water heater / charger / lighting / sensor
--   area_code           : region code, e.g. 110105
--   power               : instantaneous power in W
--   energy              : cumulative energy in kWh
--   run_status          : true = running, false = off
--   collect_time        : sample time, second precision
--
-- NOTE(review): "STRING ... TAG", "... FIELD" and "ENGINE=InfluxDB" are not MySQL
-- syntax, and DEFAULT CHARSET=utf8mb4 is a MySQL table option. This statement looks
-- like a schematic description of a time-series measurement rather than DDL for a
-- specific engine -- confirm the target store before trying to execute it.
CREATE TABLE device_energy_consumption (
device_id STRING TAG COMMENT'设备 ID(脱敏,如 D2024****156)',
device_type STRING TAG COMMENT'设备类型(空调/热水器/充电桩/照明/传感器)',
user_id STRING TAG COMMENT'用户 ID(脱敏,如 U2024****156)',
area_code STRING TAG COMMENT'区域编码(如北京 110105)',
power DOUBLE FIELD COMMENT'实时功率(W)',
energy DOUBLE FIELD COMMENT'累计能耗(kWh)',
run_status BOOLEAN FIELD COMMENT'运行状态(true=运行,false=关闭)',
collect_time TIMESTAMP COMMENT'采集时间(精度到秒)'
) ENGINE=InfluxDB DEFAULT CHARSET=utf8mb4 COMMENT'设备实时能耗数据表';
-- Weather forecast rows keyed by area and forecast time (MySQL / InnoDB).
--
-- MySQL has no STRING column type, so the text columns are declared as VARCHAR.
-- A bounded VARCHAR is also what lets area_code take part in the composite index
-- without a prefix length.
-- NOTE(review): the VARCHAR lengths are assumptions based on the sample values in
-- the column comments (6-digit area code, single-word weather type) -- confirm.
CREATE TABLE weather_data (
    id            BIGINT AUTO_INCREMENT PRIMARY KEY,
    area_code     VARCHAR(16) NOT NULL COMMENT '区域编码(如北京 110105)',
    temperature   DOUBLE NOT NULL COMMENT '温度(℃)',
    humidity      DOUBLE NOT NULL COMMENT '湿度(%)',
    weather_type  VARCHAR(16) NOT NULL COMMENT '天气类型(晴/雨/阴/雪)',
    forecast_time TIMESTAMP NOT NULL COMMENT '预报时间',
    create_time   TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    update_time   TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    -- Serves "forecast for one area over a time range" lookups.
    INDEX idx_area_forecast (area_code, forecast_time) COMMENT '区域 + 预报时间索引,优化查询'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '天气预报表';
-- Time-of-use electricity tariff: one row per area, hour of day and effective date
-- (MySQL / InnoDB).
--
-- Changes relative to the earlier definition:
--   * area_code is VARCHAR because MySQL has no STRING column type.
--     NOTE(review): length 16 is an assumption from the 6-digit sample -- confirm.
--   * price is DECIMAL(10, 4) rather than DOUBLE: it is a monetary rate, and binary
--     floating point cannot represent most decimal fractions exactly.
--   * The value ranges stated in the column comments are declared as named CHECK
--     constraints. MySQL enforces CHECK from 8.0.16; older versions parse and
--     ignore it.
CREATE TABLE electricity_price (
    id             BIGINT AUTO_INCREMENT PRIMARY KEY,
    area_code      VARCHAR(16) NOT NULL COMMENT '区域编码(如北京 110105)',
    hour           INT NOT NULL COMMENT '小时(0-23)',
    price_type     TINYINT NOT NULL COMMENT '电价类型(0=谷电,1=平电,2=峰电)',
    price          DECIMAL(10, 4) NOT NULL COMMENT '电价(元/kWh)',
    effective_date DATE NOT NULL COMMENT '生效日期',
    -- NULL means the tariff has no end date.
    expire_date    DATE COMMENT '失效日期(NULL 表示永久有效)',
    create_time    TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    -- One tariff per area, hour and effective date.
    UNIQUE KEY uk_area_hour_date (area_code, hour, effective_date) COMMENT '唯一索引,避免重复数据',
    CONSTRAINT electricity_price_hour_check CHECK (`hour` BETWEEN 0 AND 23),
    CONSTRAINT electricity_price_price_type_check CHECK (price_type IN (0, 1, 2))
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '峰谷电价表';
-- Stored hourly energy forecasts per user, with a per-device-category breakdown
-- (MySQL / InnoDB). All energy columns are in kWh; accuracy is a percentage.
--
-- Changes relative to the earlier definition:
--   * user_id is VARCHAR because MySQL has no STRING column type; a bounded VARCHAR
--     is also required for it to lead the composite index.
--     NOTE(review): length 64 is an assumption (sample IDs are 12 chars) -- confirm.
--   * forecast_hour's documented 0-23 range is declared as a named CHECK constraint.
--     MySQL enforces CHECK from 8.0.16; older versions parse and ignore it.
CREATE TABLE energy_forecast_result (
    id                  BIGINT AUTO_INCREMENT PRIMARY KEY,
    user_id             VARCHAR(64) NOT NULL COMMENT '用户 ID(脱敏)',
    forecast_date       DATE NOT NULL COMMENT '预测日期',
    forecast_hour       INT NOT NULL COMMENT '预测小时(0-23)',
    total_energy        DOUBLE NOT NULL COMMENT '预测总能耗(kWh)',
    aircon_energy       DOUBLE NOT NULL COMMENT '空调预测能耗(kWh)',
    water_heater_energy DOUBLE NOT NULL COMMENT '热水器预测能耗(kWh)',
    charger_energy      DOUBLE NOT NULL COMMENT '充电桩预测能耗(kWh)',
    other_energy        DOUBLE NOT NULL COMMENT '其他设备预测能耗(kWh)',
    accuracy            DOUBLE NOT NULL COMMENT '预测精度(%)',
    create_time         TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    -- Serves "all forecast hours for one user on one date" lookups.
    INDEX idx_user_date (user_id, forecast_date) COMMENT '用户 + 预测日期索引,优化查询',
    CONSTRAINT energy_forecast_result_forecast_hour_check CHECK (forecast_hour BETWEEN 0 AND 23)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT '能耗预测结果表';
package com.qingyunjiao.smarthome.energy.forecast;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LSTMRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Forecasts a user's hourly energy consumption for the next day.
 *
 * <p>Flow: load historical energy, weather and tariff features through Spark SQL,
 * run them through a pre-trained {@link PipelineModel}, blend the linear and LSTM
 * predictions with configurable weights, and aggregate the result per hour.
 *
 * <p>NOTE(review): this file imports
 * {@code org.apache.spark.ml.regression.LSTMRegressionModel}; that class does not
 * appear to be part of Spark MLlib -- confirm it resolves on the project classpath.
 */
@Service
public class EnergyForecastService {
    private static final Logger log = LoggerFactory.getLogger(EnergyForecastService.class);

    /**
     * Characters accepted in a user ID before it is spliced into Spark SQL text.
     * Letters, digits, '_', '-' and '*' cover the ID shape shown in the schema
     * comments (e.g. U2024****156) while excluding quotes, whitespace and comment
     * markers.
     * NOTE(review): widen only if real IDs need it, and never to include quotes.
     */
    private static final String SAFE_USER_ID_REGEX = "[A-Za-z0-9_*-]+";

    @Autowired
    private SparkSession sparkSession;

    /** Location the serialized pipeline model is loaded from. */
    @Value("${smarthome.model.energy-forecast-path}")
    private String modelPath;

    /** Weight applied to the linear model's output when blending (default 0.4). */
    @Value("${smarthome.model.linear-weight:0.4}")
    private double linearWeight;

    /** Weight applied to the LSTM model's output when blending (default 0.6). */
    @Value("${smarthome.model.lstm-weight:0.6}")
    private double lstmWeight;

    private PipelineModel forecastModel;

    /**
     * Loads the forecast pipeline once at startup.
     *
     * @throws RuntimeException if the model cannot be loaded, which aborts bean
     *                          initialization
     */
    @PostConstruct
    public void initModel() {
        long startTime = System.currentTimeMillis();
        try {
            forecastModel = PipelineModel.load(modelPath);
            log.info("能耗预测模型加载完成,模型路径:{},耗时:{}ms", modelPath, System.currentTimeMillis() - startTime);
        } catch (Exception e) {
            log.error("能耗预测模型加载失败,模型路径:{}", modelPath, e);
            throw new RuntimeException("能耗预测服务初始化失败,请检查模型路径或联系管理员", e);
        }
    }

    /**
     * Produces the hourly energy forecast for the given user.
     *
     * @param userId user identifier; must match {@link #SAFE_USER_ID_REGEX}
     * @return one entry per forecast hour, ordered by hour
     * @throws RuntimeException if any step fails, including an invalid user ID or a
     *                          forecast that produced no rows; the cause is attached
     */
    public List<EnergyForecastVO> forecast24HourEnergy(String userId) {
        log.info("开始预测用户{}未来 24 小时能耗", maskUserId(userId));
        long startTime = System.currentTimeMillis();
        try {
            Dataset<Row> featureData = loadFeatureData(userId);
            Dataset<Row> predictResult = forecastModel.transform(featureData);
            Dataset<Row> fusedResult = fusePredictResult(predictResult);
            List<EnergyForecastVO> result = processPredictResult(fusedResult, userId);
            if (result.isEmpty()) {
                // Fail with an explicit cause instead of an IndexOutOfBoundsException
                // from result.get(0) below; callers still see the same wrapped error.
                throw new IllegalStateException("Forecast produced no rows (missing energy, weather or price data)");
            }
            double totalEnergy = result.stream().mapToDouble(EnergyForecastVO::getHourlyEnergy).sum();
            log.info("用户{}未来 24 小时能耗预测完成,总能耗:{}kWh,耗时:{}ms,预测精度:{}%", maskUserId(userId), totalEnergy, System.currentTimeMillis() - startTime, result.get(0).getAccuracy());
            cacheForecastResult(userId, result);
            return result;
        } catch (Exception e) {
            log.error("用户{}未来 24 小时能耗预测失败", maskUserId(userId), e);
            throw new RuntimeException("能耗预测失败,请稍后重试或联系管理员", e);
        }
    }

    /**
     * Builds the model input: historical energy statistics joined with tomorrow's
     * weather and the currently effective tariff on {@code hour}, plus derived
     * features, assembled into a {@code features} vector column.
     *
     * @param userId validated before being interpolated into the SQL text
     * @return lazily evaluated feature dataset
     * @throws IllegalArgumentException if {@code userId} is null or contains
     *                                  characters outside the allowed set
     */
    private Dataset<Row> loadFeatureData(String userId) {
        // The ID ends up inside quoted SQL literals, so reject anything that could
        // terminate the literal or start a comment.
        String safeUserId = requireSafeUserId(userId);

        // Historical usage over the last 90 days, grouped by hour, weekday and
        // device type.
        String energySql = String.format("""
            SELECT hour(collect_time) AS hour, -- 小时(0-23)
            dayofweek(collect_time) AS weekday, -- 星期(1-7)
            device_type, -- 设备类型
            AVG(power) AS avg_power, -- 平均功率(W)
            SUM(energy) AS daily_energy, -- 日能耗(kWh)
            DATEDIFF(current_date(), MAX(device_install_time)) AS device_age_days -- 设备使用天数
            FROM hive_db.device_energy_consumption
            WHERE user_id = '%s' AND collect_time >= date_sub(current_date(), 90) -- 近 90 天数据
            GROUP BY hour(collect_time), dayofweek(collect_time), device_type
            """, safeUserId);
        Dataset<Row> energyData = sparkSession.sql(energySql).withColumnRenamed("device_age_days", "device_age");

        // Tomorrow's weather for the user's area. The percent sign in the humidity
        // comment is written as %% because this text is a String.format pattern; a
        // bare '%' followed by ')' raises UnknownFormatConversionException.
        String weatherSql = String.format("""
            SELECT hour(forecast_time) AS hour, -- 小时(0-23)
            temperature, -- 温度(℃)
            humidity, -- 湿度(%%)
            CASE weather_type WHEN '晴' THEN 1 WHEN '阴' THEN 2 WHEN '雨' THEN 3 WHEN '雪' THEN 4 ELSE 0 END AS weather_type_code -- 天气类型编码(便于模型处理)
            FROM mysql_db.weather_data
            WHERE area_code = (SELECT area_code FROM mysql_db.user_info WHERE user_id = '%s') AND date(forecast_time) = current_date() + 1 -- 未来 1 天(24 小时)
            """, safeUserId);
        Dataset<Row> weatherData = sparkSession.sql(weatherSql);

        // Tariff rows effective today for the user's area.
        String priceSql = String.format("""
            SELECT hour, -- 小时(0-23)
            price_type, -- 电价类型(0=谷电,1=平电,2=峰电)
            price -- 电价(元/kWh)
            FROM mysql_db.electricity_price
            WHERE area_code = (SELECT area_code FROM mysql_db.user_info WHERE user_id = '%s') AND effective_date <= current_date() AND (expire_date IS NULL OR expire_date >= current_date())
            """, safeUserId);
        Dataset<Row> priceData = sparkSession.sql(priceSql);

        // No cache()/unpersist() on the three inputs: nothing here triggers a Spark
        // action, so a cache released before returning would never be populated.
        // NOTE(review): the two ratio columns become null when humidity or price is
        // 0 under Spark's default (non-ANSI) division, and VectorAssembler rejects
        // nulls by default -- confirm whether such rows can occur.
        Dataset<Row> mergedData = energyData.join(weatherData, "hour", "inner")
            .join(priceData, "hour", "inner")
            .dropDuplicates("hour", "device_type")
            .withColumn("is_peak_hour", functions.when(functions.col("price_type").equalTo(2), 1).otherwise(0))
            .withColumn("is_weekend", functions.when(functions.col("weekday").isin(1, 7), 1).otherwise(0))
            .withColumn("temp_hum_ratio", functions.col("temperature").divide(functions.col("humidity")))
            .withColumn("power_price_ratio", functions.col("avg_power").divide(functions.col("price")));

        VectorAssembler assembler = new VectorAssembler()
            .setInputCols(new String[]{"hour", "weekday", "avg_power", "device_age", "temperature", "humidity", "weather_type_code", "price_type", "is_peak_hour", "is_weekend", "temp_hum_ratio", "power_price_ratio", "daily_energy"})
            .setOutputCol("features");
        return assembler.transform(mergedData);
    }

    /**
     * Returns {@code userId} unchanged if it is safe to embed in a quoted SQL literal.
     *
     * @throws IllegalArgumentException if the ID is null, empty, or contains a
     *                                  character outside {@link #SAFE_USER_ID_REGEX}
     */
    private static String requireSafeUserId(String userId) {
        if (userId == null || !userId.matches(SAFE_USER_ID_REGEX)) {
            // The raw value is deliberately left out of the message.
            throw new IllegalArgumentException("userId is null or contains characters not allowed in a SQL literal");
        }
        return userId;
    }

    /**
     * Blends the two models' outputs into {@code prediction} and {@code accuracy}
     * columns using the configured weights.
     *
     * <p>Reads columns {@code linear_prediction}, {@code lstm_prediction},
     * {@code linear_accuracy} and {@code lstm_accuracy}; these are presumably
     * emitted by the loaded pipeline -- confirm against the trained model.
     */
    private Dataset<Row> fusePredictResult(Dataset<Row> predictResult) {
        return predictResult.withColumn("prediction",
                functions.col("linear_prediction").multiply(linearWeight).plus(functions.col("lstm_prediction").multiply(lstmWeight)))
            .withColumn("accuracy",
                functions.col("linear_accuracy").multiply(linearWeight).plus(functions.col("lstm_accuracy").multiply(lstmWeight)));
    }

    /**
     * Aggregates blended predictions per hour (sum of energy, mean accuracy) and
     * converts the rows into value objects on the driver.
     *
     * @param fusedResult dataset with {@code hour}, {@code prediction} and
     *                    {@code accuracy} columns
     * @param userId      user the forecast belongs to
     * @return value objects ordered by hour; empty if the dataset has no rows
     */
    private List<EnergyForecastVO> processPredictResult(Dataset<Row> fusedResult, String userId) {
        // Resolve the forecast date once, on the driver, with the same Spark clock
        // the weather filter uses. The cast is needed because a DATE column cannot
        // be read with getString.
        String forecastDate = sparkSession
            .sql("SELECT CAST(date_add(current_date(), 1) AS STRING)")
            .first()
            .getString(0);

        // Collect the hourly aggregate to the driver before mapping: the mapping
        // uses instance state, which must not be captured in an executor-side
        // closure.
        List<Row> hourlyRows = fusedResult.groupBy("hour")
            .agg(functions.sum("prediction").alias("hourly_energy"),
                functions.avg("accuracy").alias("accuracy"))
            .orderBy("hour")
            .collectAsList();

        return hourlyRows.stream().map(row -> {
            EnergyForecastVO vo = new EnergyForecastVO();
            vo.setUserId(userId);
            vo.setForecastDate(forecastDate);
            vo.setForecastHour(row.getInt(row.fieldIndex("hour")));
            vo.setHourlyEnergy(roundToTwoDecimal(row.getDouble(row.fieldIndex("hourly_energy"))));
            // The caller logs getAccuracy(), so the aggregated value has to be
            // copied onto the VO.
            // NOTE(review): assumes EnergyForecastVO has a setAccuracy(double)
            // matching its getAccuracy() -- confirm.
            vo.setAccuracy(roundToTwoDecimal(row.getDouble(row.fieldIndex("accuracy"))));
            return vo;
        }).collect(Collectors.toList());
    }

    /**
     * Masks a user ID for logging: first three and last three characters are kept,
     * the middle is replaced by {@code ****}. IDs that are null or shorter than six
     * characters are returned unchanged.
     */
    private String maskUserId(String userId) {
        if (userId == null || userId.length() < 6) return userId;
        return userId.substring(0, 3) + "****" + userId.substring(userId.length() - 3);
    }

    /** Rounds to two decimal places using {@link Math#round(double)}. */
    private double roundToTwoDecimal(double value) {
        return Math.round(value * 100.0) / 100.0;
    }

    /** Hook for caching a finished forecast; currently does nothing. */
    private void cacheForecastResult(String userId, List<EnergyForecastVO> result) {
    }
}