-- Per-device energy readings: one row per device per collection instant.
-- Columns marked TAG carry identifying dimensions; columns marked FIELD carry
-- the measured values.
-- NOTE(review): the TAG/FIELD markers and ENGINE=InfluxDB are not standard SQL;
-- this looks like a schematic description of a time-series measurement rather
-- than executable DDL -- confirm the target engine before running it.
-- NOTE(review): the forecasting service in this file reads a
-- device_install_time column from this table, which is not declared here --
-- confirm where that column comes from.
CREATE TABLE device_energy_consumption (
    device_id    STRING    TAG   COMMENT '设备 ID',
    device_type  STRING    TAG   COMMENT '设备类型',
    user_id      STRING    TAG   COMMENT '用户 ID',
    area_code    STRING    TAG   COMMENT '区域编码',
    power        DOUBLE    FIELD COMMENT '实时功率(W)',
    energy       DOUBLE    FIELD COMMENT '累计能耗(kWh)',
    run_status   BOOLEAN   FIELD COMMENT '运行状态',
    collect_time TIMESTAMP       COMMENT '采集时间'
) ENGINE=InfluxDB DEFAULT CHARSET=utf8mb4;
-- Weather forecast rows per area, looked up by (area_code, forecast_time).
-- Text columns use bounded VARCHAR: STRING is not a MySQL data type, and an
-- indexed text column needs a declared length.
CREATE TABLE weather_data (
    id            BIGINT      AUTO_INCREMENT PRIMARY KEY,
    area_code     VARCHAR(32) NOT NULL COMMENT '区域编码',
    temperature   DOUBLE      NOT NULL COMMENT '温度(℃)',
    humidity      DOUBLE      NOT NULL COMMENT '湿度(%)',
    weather_type  VARCHAR(16) NOT NULL COMMENT '天气类型',
    forecast_time TIMESTAMP   NOT NULL COMMENT '预报时间',
    create_time   TIMESTAMP   DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    update_time   TIMESTAMP   DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    INDEX idx_area_forecast (area_code, forecast_time)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Electricity tariff per area and hour of day. A row applies from
-- effective_date until expire_date; a NULL expire_date means open-ended.
-- area_code is VARCHAR because STRING is not a MySQL data type and the column
-- is part of a unique key. price is DECIMAL so tariff amounts are stored
-- exactly instead of as binary floating point.
CREATE TABLE electricity_price (
    id             BIGINT         AUTO_INCREMENT PRIMARY KEY,
    area_code      VARCHAR(32)    NOT NULL COMMENT '区域编码',
    hour           INT            NOT NULL COMMENT '小时(0-23)',
    price_type     TINYINT        NOT NULL COMMENT '电价类型',
    price          DECIMAL(10, 4) NOT NULL COMMENT '电价(元/kWh)',
    effective_date DATE           NOT NULL COMMENT '生效日期',
    expire_date    DATE           COMMENT '失效日期',
    create_time    TIMESTAMP      DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    UNIQUE KEY uk_area_hour_date (area_code, hour, effective_date),
    -- Enforces the 0-23 range stated in the column comment
    -- (parsed but not enforced before MySQL 8.0.16).
    CONSTRAINT chk_electricity_price_hour CHECK (hour BETWEEN 0 AND 23)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- Stored hourly energy forecasts per user, looked up by (user_id, forecast_date).
-- user_id is VARCHAR because STRING is not a MySQL data type and the column
-- is indexed.
CREATE TABLE energy_forecast_result (
    id            BIGINT      AUTO_INCREMENT PRIMARY KEY,
    user_id       VARCHAR(64) NOT NULL COMMENT '用户 ID',
    forecast_date DATE        NOT NULL COMMENT '预测日期',
    forecast_hour INT         NOT NULL COMMENT '预测小时',
    total_energy  DOUBLE      NOT NULL COMMENT '预测总能耗(kWh)',
    accuracy      DOUBLE      NOT NULL COMMENT '预测精度(%)',
    create_time   TIMESTAMP   DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    INDEX idx_user_date (user_id, forecast_date),
    -- Hour of day; assumes the same 0-23 range documented for
    -- electricity_price.hour (parsed but not enforced before MySQL 8.0.16).
    CONSTRAINT chk_energy_forecast_result_hour CHECK (forecast_hour BETWEEN 0 AND 23)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
package com.qingyunjiao.smarthome.energy.forecast;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.PostConstruct;
import java.util.ArrayList;
import java.util.List;
@Service
public class EnergyForecastService {
private static final Logger log = LoggerFactory.getLogger(EnergyForecastService.class);
@Autowired
private SparkSession sparkSession;
@Value("${smarthome.model.energy-forecast-path}")
private String modelPath;
@Value("${smarthome.model.linear-weight:0.4}")
private double linearWeight;
@Value("${smarthome.model.lstm-weight:0.6}")
private double lstmWeight;
private PipelineModel forecastModel;
@PostConstruct
public void initModel() {
long startTime = System.currentTimeMillis();
try {
forecastModel = PipelineModel.load(modelPath);
log.info("能耗预测模型加载完成,耗时:{}ms", System.currentTimeMillis() - startTime);
} catch (Exception e) {
log.error("能耗预测模型加载失败", e);
throw new RuntimeException("能耗预测服务初始化失败", e);
}
}
public List<EnergyForecastVO> forecast24HourEnergy(String userId) {
log.info("开始预测用户{}未来 24 小时能耗", maskUserId(userId));
long startTime = System.currentTimeMillis();
try {
Dataset<Row> featureData = loadFeatureData(userId);
Dataset<Row> predictResult = forecastModel.transform(featureData);
Dataset<Row> fusedResult = fusePredictResult(predictResult);
List<EnergyForecastVO> result = processPredictResult(fusedResult, userId);
double totalEnergy = result.stream().mapToDouble(EnergyForecastVO::getHourlyEnergy).sum();
log.info("预测完成,总能耗:{}kWh,耗时:{}ms", totalEnergy, System.currentTimeMillis() - startTime);
return result;
} catch (Exception e) {
log.error("预测失败", e);
throw new RuntimeException("能耗预测失败", e);
}
}
private Dataset<Row> loadFeatureData(String userId) {
String energySql = String.format(
"SELECT hour(collect_time) AS hour, dayofweek(collect_time) AS weekday, " +
"AVG(power) AS avg_power, SUM(energy) AS daily_energy, " +
"DATEDIFF(current_date(), MAX(device_install_time)) AS device_age_days " +
"FROM hive_db.device_energy_consumption WHERE user_id = '%s' AND collect_time >= date_sub(current_date(), 90) " +
"GROUP BY hour(collect_time), dayofweek(collect_time), device_type",
userId
);
Dataset<Row> energyData = sparkSession.sql(energySql).withColumnRenamed("device_age_days", "device_age").cache();
String weatherSql = String.format(
"SELECT hour(forecast_time) AS hour, temperature, humidity, " +
"CASE weather_type WHEN '晴' THEN 1 WHEN '阴' THEN 2 WHEN '雨' THEN 3 WHEN '雪' THEN 4 ELSE 0 END AS weather_type_code " +
"FROM mysql_db.weather_data WHERE area_code = (SELECT area_code FROM mysql_db.user_info WHERE user_id = '%s') " +
"AND date(forecast_time) = current_date() + 1",
userId
);
Dataset<Row> weatherData = sparkSession.sql(weatherSql).cache();
String priceSql = String.format(
"SELECT hour, price_type, price FROM mysql_db.electricity_price " +
"WHERE area_code = (SELECT area_code FROM mysql_db.user_info WHERE user_id = '%s') " +
"AND effective_date <= current_date() AND (expire_date IS NULL OR expire_date >= current_date())",
userId
);
Dataset<Row> priceData = sparkSession.sql(priceSql).cache();
Dataset<Row> mergedData = energyData.join(weatherData, "hour", "inner")
.join(priceData, "hour", "inner")
.dropDuplicates("hour", "device_type")
.withColumn("is_peak_hour", functions.when(functions.col("price_type").equalTo(2), 1).otherwise(0))
.withColumn("is_weekend", functions.when(functions.col("weekday").isin(1, 7), 1).otherwise(0))
.withColumn("temp_hum_ratio", functions.col("temperature").divide(functions.col("humidity")))
.withColumn("power_price_ratio", functions.col("avg_power").divide(functions.col("price")));
org.apache.spark.ml.feature.VectorAssembler assembler = new org.apache.spark.ml.feature.VectorAssembler()
.setInputCols(new String[]{"hour", "weekday", "avg_power", "device_age", "temperature", "humidity",
"weather_type_code", "price_type", "is_peak_hour", "is_weekend", "temp_hum_ratio", "power_price_ratio", "daily_energy"})
.setOutputCol("features");
Dataset<Row> featureData = assembler.transform(mergedData);
energyData.unpersist();
weatherData.unpersist();
priceData.unpersist();
return featureData;
}
private Dataset<Row> fusePredictResult(Dataset<Row> predictResult) {
return predictResult.withColumn("prediction",
functions.col("linear_prediction").multiply(linearWeight).plus(functions.col("lstm_prediction").multiply(lstmWeight)))
.withColumn("accuracy",
functions.col("linear_accuracy").multiply(linearWeight).plus(functions.col("lstm_accuracy").multiply(lstmWeight)));
}
private List<EnergyForecastVO> processPredictResult(Dataset<Row> fusedResult, String userId) {
Dataset<Row> hourlyResult = fusedResult.groupBy("hour").agg(
functions.sum("prediction").alias("hourly_energy"),
functions.avg("accuracy").alias("accuracy")
).orderBy("hour").cache();
List<EnergyForecastVO> result = hourlyResult.toJavaRDD().map(row -> {
EnergyForecastVO vo = new EnergyForecastVO();
vo.setUserId(userId);
vo.setForecastDate(sparkSession.sql("SELECT current_date() + 1").first().getString(0));
vo.setForecastHour(row.getInt(row.fieldIndex("hour")));
vo.setHourlyEnergy(roundToTwoDecimal(row.getDouble(row.fieldIndex("hourly_energy"))));
return vo;
}).collectAsList();
return result;
}
private String maskUserId(String userId) {
if (userId == null || userId.length() < 5) return userId;
return userId.substring(0, 3) + "****" + userId.substring(userId.length() - 3);
}
private double roundToTwoDecimal(double value) {
return Math.round(value * 100.0) / 100.0;
}
private void cacheForecastResult(String userId, List<EnergyForecastVO> result) {
}
}