手写 C++ TCP 服务器实现自定义协议及解决粘包问题
介绍基于 C++ 实现 TCP 服务器的过程。由于 TCP 是字节流协议,存在粘包问题,需设计应用层协议(如长度前缀)来界定消息边界。文章详细讲解了序列化与反序列化的原理,提供了完整的请求响应结构定义、编码解码函数以及服务端和客户端的核心代码示例,展示了如何通过自定义协议确保数据完整传输。

介绍基于 C++ 实现 TCP 服务器的过程。由于 TCP 是字节流协议,存在粘包问题,需设计应用层协议(如长度前缀)来界定消息边界。文章详细讲解了序列化与反序列化的原理,提供了完整的请求响应结构定义、编码解码函数以及服务端和客户端的核心代码示例,展示了如何通过自定义协议确保数据完整传输。

在之前的网络编程学习中,我们了解了 UDP 和 TCP 的使用。UDP 基于数据报传输,而 TCP 基于字节流传输。TCP 无法保证读取上来的数据是一个完整的报文,因此需要制定应用层协议来确保数据的结构化。
TCP 是传输控制协议,主要负责发送时机、数据量及错误处理。read 和 write 系统调用将用户空间数据拷贝到内核空间的 TCP 缓冲区。由于 TCP 是面向字节流的,接收端可能遇到以下情况:
例如,客户端发送两个请求:
10 + 20
5 * 6
服务器可能收到:
10 + 20\n5 * 6\n
或者:
10 +
为了避免数据解析错误,必须设计应用层协议来界定消息边界。
主要模块如下:
| 模块 | 作用 |
|---|---|
| Socket | 封装 socket API |
| TcpServer | TCP 服务器框架 |
| Protocol | 协议封装 |
| CalculatorServer | 计算逻辑 |
采用长度前缀加换行符的格式:
len\n
content\n
例如:
7\n10 + 20\n
std::string Encode(std::string &content) {
std::string s;
size_t len = content.size();
s += std::to_string(len);
s += "\n";
s += content;
s += "\n";
return s;
}
bool Decode(std::string &s, std::string *content) {
size_t left_pos = s.find("\n");
if (left_pos == std::string::npos) return false;
std::string content_len = s.substr(0, left_pos);
int len = std::stoi(content_len);
if (s.size() < content_len.size() + len + 2) return false;
*content = s.substr(left_pos + 1, len);
s.erase(0, content_len.size() + len + 2);
return true;
}
Decode 做了三件事:判断是否有完整头部、判断数据是否完整、解析并移除已处理数据。
客户端发送:10 + 5
结构:
class request {
public:
int x_;
int y_;
char op_;
};
序列化:"10 + 5"
反序列化:string -> request
服务器返回:"15 0"
结构:
class response {
public:
int result_;
int code_;
};
code 含义:
| code | 含义 |
|---|---|
| 0 | 成功 |
| 1 | 除 0 |
| 2 | 取模 0 |
| 3 | 非法操作 |
bool Decode(std::string &s, std::string *content) {
size_t pos = s.find("\n");
if (pos == std::string::npos) return false;
int len = std::stoi(s.substr(0, pos));
if (s.size() < pos + 1 + len + 1) return false;
*content = s.substr(pos + 1, len);
s.erase(0, pos + 1 + len + 1);
return true;
}
std::string Calculator(std::string& s) {
std::string content;
if (Decode(s, &content) == false) { return ""; }
request req;
bool r = req.Deserialization(content);
if (!r) { return ""; }
response res = CalculatorHandler(req);
std::string ret = res.serialization();
ret = Encode(ret);
return ret;
}
如果收到的报文不能分解为完整格式,返回空字符串,服务器继续接收直到完整。
while (true) {
int sockfd = listenfd_.Accept(&client_port, &client_ip);
if (fork() == 0) {
while (1) {
char buffer[1280];
ssize_t s = read(sockfd, buffer, sizeof buffer - 1);
if (s > 0) {
inbuffer_stream += buffer;
while (true) {
std::string info = callback_(inbuffer_stream);
if (info.empty()) break;
write(sockfd, info.c_str(), info.size());
}
}
}
}
}
#include <string>
#define blank_sep " "
#define protocol_sep "\n"
class request {
public:
request(int x, int y, char op) : x_(x), y_(y), op_(op) { }
request() { }
~request() { }
std::string serialization() {
std::string str;
str += std::to_string(x_);
str += blank_sep;
str += op_;
str += blank_sep;
str += std::to_string(y_);
return str;
}
bool Deserialization(std::string &in) {
size_t leftpos = in.find(blank_sep);
if (leftpos == std::string::npos) { return false; }
std::string str_x = in.substr(0, leftpos);
x_ = std::stoi(str_x);
op_ = in[leftpos + 1];
size_t rightpos = in.rfind(blank_sep);
if (rightpos == std::string::npos) { return false; }
std::string str_y = in.substr(rightpos + 1);
y_ = std::stoi(str_y);
return true;
}
void DebugPrint() {
std::cout << "新请求构建完成:" << x_ << op_ << y_ << "=?" << std::endl;
}
public:
int x_;
int y_;
char op_;
};
class response {
public:
response() { }
response(int result, int code) : result_(result), code_(code) { }
~response() { }
std::string serialization() {
std::string str;
str += std::to_string(result_);
str += blank_sep;
str += std::to_string(code_);
return str;
}
bool Deserialization(std::string &in) {
size_t pos = in.find(blank_sep);
if (pos == std::string::npos) { return false; }
std::string str_result = in.substr(0, pos);
result_ = std::stoi(str_result);
std::string str_code = in.substr(pos + 1);
code_ = std::stoi(str_code);
return true;
}
void DebugPrint() {
std::cout << "结果响应完成,result: " << result_ << ", code: " << code_ << std::endl;
}
public:
int result_;
int code_;
};
std::string Encode(std::string &content) {
std::string s;
size_t len = content.size();
s += std::to_string(len);
s += protocol_sep;
s += content;
s += protocol_sep;
return s;
}
bool Decode(std::string &s, std::string *content) {
size_t left_pos = s.find(protocol_sep);
if (left_pos == std::string::npos) { return false; }
std::string content_len = s.substr(0, left_pos);
int len = std::stoi(content_len);
if (s.size() < content_len.size() + len + 2) { return false; }
*content = s.substr(left_pos + 1, len);
s.erase(0, content_len.size() + len + 2);
return true;
}
#include <iostream>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <string>
enum error { SocketErr = 2, BindErr, ListenErr, ConnectErr, };
class Socket {
public:
Socket() { sockfd_ = socket(AF_INET, SOCK_STREAM, 0); if (sockfd_ < 0) { std::cout << "socket fail" << std::endl; exit(SocketErr); } }
void Bind(uint16_t &port, std::string &ip) {
struct sockaddr_in server;
server.sin_family = AF_INET;
server.sin_port = htons(port);
inet_pton(AF_INET, ip.c_str(), &server.sin_addr);
if (bind(sockfd_, (struct sockaddr *)&server, sizeof(server)) < 0) {
std::cout << "server bind fail!" << std::endl; exit(BindErr);
}
std::cout << "server bind successful" << std::endl;
}
void Listen() {
if (listen(sockfd_, 10) < 0) {
std::cout << "server listen fail!" << std::endl; exit(ListenErr);
}
std::cout << "server listen successful" << std::endl;
}
int Accept(uint16_t *client_port, std::string *client_ip) {
struct sockaddr_in client;
socklen_t len = sizeof(client);
int sockfd = accept(sockfd_, (struct sockaddr *)&client, &len);
if (sockfd < 0) { std::cout << "accept fail!" << std::endl; return -1; }
std::cout << "accept successful" << std::endl;
*client_port = ntohs(client.sin_port);
char ip[64];
inet_ntop(AF_INET, &client.sin_addr, ip, sizeof ip);
*client_ip = ip;
return sockfd;
}
void Connect(uint16_t &server_port, std::string &server_ip) {
struct sockaddr_in server;
server.sin_family = AF_INET;
server.sin_port = htons(server_port);
inet_pton(AF_INET, server_ip.c_str(), &server.sin_addr);
if (connect(sockfd_, (struct sockaddr *)&server, sizeof server) < 0) {
std::cout << "connect fail!" << std::endl; exit(ConnectErr);
}
std::cout << "connect successful!" << std::endl;
}
void Close() { close(sockfd_); }
int fd() { return sockfd_; }
~Socket() { close(sockfd_); }
private:
int sockfd_;
};
class CalculatorServer {
public:
CalculatorServer() { }
response CalculatorHandler(const request &req) {
response res(0, 0);
switch (req.op_) {
case '+': res.result_ = req.x_ + req.y_; break;
case '-': res.result_ = req.x_ - req.y_; break;
case '*': res.result_ = req.x_ * req.y_; break;
case '/': if (req.y_ == 0) { res.code_ = 1; break; } res.result_ = req.x_ / req.y_; break;
case '%': if (req.y_ == 0) { res.code_ = 2; break; } res.result_ = req.x_ % req.y_; break;
default: res.code_ = 3; break;
}
return res;
}
std::string Calculator(std::string& s) {
std::string content;
if (Decode(s, &content) == false) { return ""; }
request req;
bool r = req.Deserialization(content);
if (!r) { return ""; }
response res = CalculatorHandler(req);
std::string ret = res.serialization();
ret = Encode(ret);
return ret;
}
~CalculatorServer() { }
};
using func_t = std::function<std::string(std::string &)>;
class TcpServer {
public:
TcpServer(uint16_t port, std::string ip, func_t callback) : port_(port), ip_(ip), callback_(callback) { }
void InitServer() {
listenfd_.Bind(port_, ip_);
listenfd_.Listen();
std::cout << "init server successful!" << std::endl;
}
void start() {
signal(SIGCHLD, SIG_IGN);
signal(SIGPIPE, SIG_IGN);
while (true) {
uint16_t client_port;
std::string client_ip;
int sockfd = listenfd_.Accept(&client_port, &client_ip);
if (sockfd < 0) { continue; }
if (fork() == 0) {
listenfd_.Close();
std::string inbuffer_stream;
while (1) {
char buffer[1280];
ssize_t s = read(sockfd, buffer, sizeof buffer - 1);
if (s > 0) {
buffer[s] = 0;
inbuffer_stream += buffer;
while (true) {
std::string info = callback_(inbuffer_stream);
std::cout << info << std::endl;
if (info.empty()) { break; }
write(sockfd, info.c_str(), info.size());
}
} else if (s == 0) { break; }
else { break; }
}
close(sockfd);
exit(0);
}
close(sockfd);
}
}
~TcpServer() { }
private:
Socket listenfd_;
uint16_t port_;
std::string ip_;
func_t callback_;
};
int main(int argc,char* argv[]) {
if(argc != 3) { exit(0); }
uint16_t server_port = std::atoi(argv[2]);
std::string server_ip = argv[1];
CalculatorServer cal;
TcpServer* ser = new TcpServer(server_port,server_ip,std::bind(&CalculatorServer::Calculator, &cal, std::placeholders::_1));
ser->InitServer();
ser->start();
return 0;
}
#include <iostream>
#include <cassert>
#include <unistd.h>
#include "Protocol.hpp"
#include "Socket.hpp"
static void Usage(const std::string &proc) {
std::cout << "\nUsage: " << proc << " serverip serverport\n" << std::endl;
}
// ./clientcal ip port
int main(int argc, char *argv[]) {
if (argc != 3) { Usage(argv[0]); exit(0); }
std::string serverip = argv[1];
uint16_t serverport = std::stoi(argv[2]);
Socket sockfd;
sockfd.Connect(serverport, serverip);
srand(time(nullptr) ^ getpid());
int cnt = 1;
const std::string opers = "+-*/%=-=&^";
std::string inbuffer_stream;
while (cnt <= 10) {
std::cout << "===============第" << cnt << "次测试....., " << "===============" << std::endl;
int x = rand() % 100 + 1;
usleep(1234);
int y = rand() % 100;
usleep(4321);
char oper = opers[rand() % opers.size()];
request req(x, y, oper);
req.DebugPrint();
std::string package;
package = req.serialization();
package = Encode(package);
std::cout << package << std::endl;
write(sockfd.fd(), package.c_str(), package.size());
char buffer[128];
ssize_t n = read(sockfd.fd(), buffer, sizeof(buffer) - 1);
if (n > 0) {
buffer[n] = 0;
inbuffer_stream += buffer;
std::string content;
bool r = Decode(inbuffer_stream, &content);
assert(r);
response resp;
r = resp.Deserialization(content);
assert(r);
resp.DebugPrint();
}
std::cout << "=================================================" << std::endl;
sleep(1);
cnt++;
}
sockfd.Close();
return 0;
}
到这里,我们已经完整实现了一个基于 C++ 的 TCP 计算器服务器。TCP 编程的本质不是收发数据,而是如何正确解析数据的边界。TCP 没有消息边界,所以我们必须设计应用层协议。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML 转 Markdown 互为补充。 在线工具,Markdown 转 HTML在线工具,online
将 HTML 片段转为 GitHub Flavored Markdown,支持标题、列表、链接、代码块与表格等;浏览器内处理,可链接预填。 在线工具,HTML 转 Markdown在线工具,online
通过删除不必要的空白来缩小和压缩JSON。 在线工具,JSON 压缩在线工具,online