#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <tensorflow/core/platform/env.h>
#include <tensorflow/core/public/session.h>
#include <opencv2/opencv.hpp>

// FFmpeg is a C library and its headers carry no C++ linkage guards, so they
// must be wrapped in extern "C" for the symbols to resolve when linking C++.
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
#include <libswscale/swscale.h>
}

using namespace tensorflow;
using namespace cv;
/// Reads a binary GraphDef from disk and creates a session holding that graph.
///
/// @param model_path Path to the serialized (binary protobuf) GraphDef file.
/// @param session    Receives ownership of the new session on success; it is
///                   left untouched when any step fails.
/// @return The status of the first failing step (also logged to stderr), or
///         an OK status when the session is ready to run.
Status LoadModel(const std::string& model_path, std::unique_ptr<Session>& session) {
    GraphDef graph_def;
    Status status = ReadBinaryProto(Env::Default(), model_path, &graph_def);
    if (!status.ok()) {
        std::cerr << "Failed to read model file: " << status.ToString() << std::endl;
        return status;
    }

    // NewSession() hands the session back through a raw Session**, so receive
    // it as a raw pointer and wrap it immediately to guarantee cleanup.
    Session* raw_session = nullptr;
    status = NewSession(SessionOptions(), &raw_session);
    std::unique_ptr<Session> new_session(raw_session);
    if (!status.ok()) {
        std::cerr << "Failed to create new session: " << status.ToString() << std::endl;
        return status;
    }

    status = new_session->Create(graph_def);
    if (!status.ok()) {
        std::cerr << "Failed to create graph in session: " << status.ToString() << std::endl;
        return status;
    }

    session = std::move(new_session);
    // `status` is OK here; returning it avoids depending on a specific
    // OK-status factory function.
    return status;
}
/// Wraps a text prompt in a scalar (rank-0) DT_STRING tensor.
///
/// @param prompt Text to store in the tensor.
/// @return A DT_STRING tensor with an empty shape whose single element is `prompt`.
///
/// NOTE(review): newer TensorFlow releases may store DT_STRING elements as
/// tensorflow::tstring rather than std::string — confirm against the
/// TensorFlow version this file is built with.
Tensor PreprocessInput(const std::string& prompt) {
    Tensor prompt_tensor(DT_STRING, TensorShape());
    auto scalar_view = prompt_tensor.scalar<std::string>();
    scalar_view() = prompt;
    return prompt_tensor;
}
/// Loads an image from disk and prepares it as model input.
///
/// The image is read with IMREAD_COLOR, resized to 256x256 and converted to
/// 32-bit float with every value multiplied by 1/255.
///
/// @param image_path Path of the image file to read.
/// @return The prepared image, or an empty Mat (after logging to stderr) when
///         the file could not be read.
Mat PreprocessImage(const std::string& image_path) {
    const Mat loaded = imread(image_path, IMREAD_COLOR);
    if (loaded.empty()) {
        std::cerr << "Failed to read image: " << image_path << std::endl;
        return Mat();
    }

    Mat resized;
    resize(loaded, resized, Size(256, 256));

    Mat normalized;
    resized.convertTo(normalized, CV_32F, 1.0 / 255.0);
    return normalized;
}
/// Copies an OpenCV image into a float tensor of shape {1, rows, cols, channels}.
///
/// Any channel count is supported: pixels are read through a raw float row
/// pointer rather than a fixed 3-channel vector type. Images whose depth is
/// not CV_32F are first converted to float (values are not rescaled).
///
/// @param image Source image; an empty image yields a tensor with zero rows
///              and columns.
/// @return A DT_FLOAT tensor holding the pixel values in (batch, y, x, channel)
///         index order.
Tensor ImageToTensor(const Mat& image) {
    // The copy loop reads raw floats, so make sure the storage really is float.
    Mat float_image;
    if (!image.empty() && image.depth() != CV_32F) {
        image.convertTo(float_image, CV_32F);
    } else {
        float_image = image;  // Header copy only; pixel data is shared.
    }

    const int height = float_image.rows;
    const int width = float_image.cols;
    const int channels = float_image.channels();

    Tensor input_tensor(DT_FLOAT, TensorShape({1, height, width, channels}));
    auto input_tensor_mapped = input_tensor.tensor<float, 4>();
    for (int y = 0; y < height; ++y) {
        // ptr<float>(y) honours the row stride, so non-continuous Mats
        // (e.g. ROIs) are handled correctly.
        const float* row = float_image.ptr<float>(y);
        for (int x = 0; x < width; ++x) {
            for (int c = 0; c < channels; ++c) {
                input_tensor_mapped(0, y, x, c) = row[x * channels + c];
            }
        }
    }
    return input_tensor;
}
/// Runs one inference step on the session.
///
/// Feeds `input_tensor` under the name "input" and fetches the tensor named
/// "output_video" into `outputs`.
///
/// @param session      Session to run; must not be null.
/// @param input_tensor Tensor supplied for the "input" feed.
/// @param outputs      Receives the fetched tensors.
/// @return The status reported by Session::Run; failures are also logged to stderr.
Status RunModel(const std::unique_ptr<Session>& session, const Tensor& input_tensor, std::vector<Tensor>& outputs) {
    const std::vector<std::pair<std::string, Tensor>> feeds = {{"input", input_tensor}};
    const std::vector<std::string> fetch_names = {"output_video"};
    const std::vector<std::string> target_nodes;

    Status run_status = session->Run(feeds, fetch_names, target_nodes, &outputs);
    if (run_status.ok()) {
        return run_status;
    }
    std::cerr << "Failed to run model: " << run_status.ToString() << std::endl;
    return run_status;
}
void ProcessOutput(const std::vector<Tensor>& outputs, const std::string& output_path) {
if (outputs.empty()) {
std::cerr << "No output from model." << std::endl;
return;
}
Tensor output_tensor = outputs[0];
int num_frames = output_tensor.dim_size(0);
int height = output_tensor.dim_size(1);
int width = output_tensor.dim_size(2);
int channels = output_tensor.dim_size(3);
av_register_all();
avformat_network_init();
AVFormatContext* format_context = nullptr;
avformat_alloc_output_context2(&format_context, nullptr, nullptr, output_path.c_str());
if (!format_context) {
std::cerr << "Failed to allocate output format context." << std::endl;
return;
}
AVStream* stream = avformat_new_stream(format_context, nullptr);
if (!stream) {
std::cerr << "Failed to create new stream." << std::endl;
return;
}
AVCodec* codec = avcodec_find_encoder(AV_CODEC_ID_H264);
if (!codec) {
std::cerr << "Failed to find H.264 encoder." << std::endl;
return;
}
AVCodecContext* codec_context = avcodec_alloc_context3(codec);
if (!codec_context) {
std::cerr << "Failed to allocate codec context." << std::endl;
return;
}
codec_context->codec_id = AV_CODEC_ID_H264;
codec_context->codec_type = AVMEDIA_TYPE_VIDEO;
codec_context->pix_fmt = AV_PIX_FMT_YUV420P;
codec_context->width = width;
codec_context->height = height;
codec_context->time_base = {1, 25};
codec_context->framerate = {25, 1};
if (avio_open(&format_context->pb, output_path.c_str(), AVIO_FLAG_WRITE) < 0) {
std::cerr << "Failed to open output file." << std::endl;
return;
}
if (avformat_write_header(format_context, nullptr) < 0) {
std::cerr << "Failed to write header." << std::endl;
return;
}
AVFrame* frame = av_frame_alloc();
if (!frame) {
std::cerr << "Failed to allocate frame." << std::endl;
return;
}
frame->format = codec_context->pix_fmt;
frame->width = codec_context->width;
frame->height = codec_context->height;
if (av_frame_get_buffer(frame, 0) < 0) {
std::cerr << "Failed to allocate frame buffer." << std::endl;
return;
}
SwsContext* sws_context = sws_getContext(width, height, AV_PIX_FMT_RGB24, width, height, AV_PIX_FMT_YUV420P, SWS_BILINEAR, nullptr, nullptr, nullptr);
if (!sws_context) {
std::cerr << "Failed to create SwsContext." << std::endl;
return;
}
for (int i = 0; i < num_frames; ++i) {
Tensor frame_tensor = output_tensor.Slice(i, i + 1);
auto frame_tensor_mapped = frame_tensor.tensor<float, 4>();
Mat frame_mat(height, width, CV_32FC3);
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
for (int c = 0; c < channels; ++c) {
frame_mat.at<Vec3f>(y, x)[c] = frame_tensor_mapped(0, y, x, c);
}
}
}
frame_mat.convertTo(frame_mat, CV_8UC3, 255.0);
const int stride[] = {static_cast<int>(frame_mat.step)};
sws_scale(sws_context, &frame_mat.data, stride, 0, height, frame->data, frame->linesize);
frame->pts = i;
AVPacket packet;
av_init_packet(&packet);
packet.data = nullptr;
packet.size = 0;
int ret = avcodec_send_frame(codec_context, frame);
if (ret < 0) {
std::cerr << "Error sending frame to encoder: " << av_err2str(ret) << std::endl;
continue;
}
while (ret >= 0) {
ret = avcodec_receive_packet(codec_context, &packet);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
break;
} else if (ret < 0) {
std::cerr << "Error receiving packet from encoder: " << av_err2str(ret) << std::endl;
break;
}
av_packet_rescale_ts(&packet, codec_context->time_base, stream->time_base);
packet.stream_index = stream->index;
ret = av_interleaved_write_frame(format_context, &packet);
if (ret < 0) {
std::cerr << "Error writing packet to output file: " << av_err2str(ret) << std::endl;
break;
}
av_packet_unref(&packet);
}
}
sws_freeContext(sws_context);
av_frame_free(&frame);
avcodec_free_context(&codec_context);
avio_closep(&format_context->pb);
avformat_free_context(format_context);
}
int main() {
std::string model_path = "path/to/model.pb";
std::unique_ptr<Session> session;
Status status = LoadModel(model_path, session);
if (!status.ok()) {
return 1;
}
std::string prompt = "美丽的森林中,阳光透过树叶洒在地面,小鸟在枝头歌唱";
Tensor input_tensor = PreprocessInput(prompt);
std::vector<Tensor> outputs;
status = RunModel(session, input_tensor, outputs);
if (!status.ok()) {
return 1;
}
std::string output_path = "text_to_video.mp4";
ProcessOutput(outputs, output_path);
std::string image_path = "path/to/image.jpg";
Mat image = PreprocessImage(image_path);
if (image.empty()) {
return 1;
}
Tensor image_tensor = ImageToTensor(image);
outputs.clear();
status = RunModel(session, image_tensor, outputs);
if (!status.ok()) {
return 1;
}
output_path = "image_to_video.mp4";
ProcessOutput(outputs, output_path);
return 0;
}