赞
踩
作者:爱写代码的刚子
时间:2024.4.24
前言:基于高并发服务器的搜索引擎,引用了第三方库cpp-httplib,cppjieba,项目的要点在代码注释中了
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <script src="https://cdn.jsdelivr.net/npm/jquery@3.5.1/dist/jquery.min.js"></script>
    <title>本地 boost 搜索引擎</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        html, body { height: 100%; font-family: Arial, sans-serif; }
        .container { width: 100%; display: flex; flex-direction: column; align-items: center; }
        .title { width: 100%; background-color: #4e6ef2; color: #fff; text-align: center; padding: 10px 0; font-size: 24px; font-weight: bold; }
        .search-container { width: 100%; background-color: #f2f2f2; display: flex; justify-content: center; align-items: center; padding: 20px 0; position: relative; }
        .search-input {
            width: calc(100% - 130px); /* leave room for the search button */
            max-width: 300px;          /* cap the width on wide screens */
            height: 40px; padding: 10px; border: 1px solid #ccc; border-radius: 20px; font-size: 16px; outline: none;
        }
        .search-btn {
            width: 100px; /* button width */
            height: 40px; background-color: #4e6ef2; color: #fff; border: none; border-radius: 20px; font-size: 16px; cursor: pointer; position: absolute; right: 10px;
        }
        .result-container { width: 100%; padding: 20px 0; display: flex; flex-direction: column; align-items: center; }
        .result-item {
            width: 90%;        /* percentage width adapts to mobile screens */
            max-width: 800px;  /* cap the width on wide screens */
            border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-top: 10px;
        }
        .result-title { font-size: 18px; color: #4e6ef2; text-decoration: none; }
        .result-desc { font-size: 14px; color: #333; margin-top: 5px; }
        .result-url { font-size: 12px; color: #666; margin-top: 5px; }
    </style>
</head>
<body>
    <div class="container">
        <div class="title">boost 搜索引擎</div>
        <div class="search-container">
            <!-- Use a real placeholder instead of a fake default value: an untouched box is then empty
                 and the hint text can never be submitted as a search query. -->
            <input type="text" class="search-input" placeholder="输入搜索关键字...">
            <button class="search-btn" onclick="Search()">搜索一下</button>
        </div>
        <div class="result-container">
            <!-- search results are generated dynamically -->
        </div>
    </div>
    <script>
        // Sends the keyword in the search box to GET /s and renders the JSON result list.
        function Search() {
            let query = $(".search-input").val().trim();
            if (query === '') {
                return;
            }
            $.ajax({
                type: "GET",
                // The keyword must be percent-encoded: raw '&', '#', '%' or '+' (e.g. "C++")
                // would otherwise break the query string or be decoded differently by the server.
                url: "/s?word=" + encodeURIComponent(query),
                dataType: "json",
                success: function (data) {
                    BuildHtml(data);
                    $(".search-input").css("margin-top", "20px");
                },
                error: function () {
                    let result_container = $(".result-container");
                    result_container.empty();
                    result_container.append($("<div>", {text: "搜索失败,请稍后重试"}));
                }
            });
        }

        // Pressing Enter in the search box behaves like clicking the button.
        $(function () {
            $(".search-input").on("keydown", function (e) {
                if (e.key === "Enter") {
                    Search();
                }
            });
        });

        // Replaces the result area with one card per {title, desc, url} element.
        // Text is inserted through jQuery's `text` option, so server data is never parsed as HTML.
        function BuildHtml(data) {
            let result_container = $(".result-container");
            result_container.empty();
            if (!data || data.length === 0) {
                result_container.append("<div>未找到相关结果</div>");
                return;
            }
            for (let elem of data) {
                let item = $("<div>", {class: "result-item"});
                let title = $("<a>", {class: "result-title", href: elem.url, text: elem.title, target: "_blank"});
                let desc = $("<div>", {class: "result-desc", text: elem.desc});
                let url = $("<div>", {class: "result-url", text: elem.url});
                title.appendTo(item);
                desc.appendTo(item);
                url.appendTo(item);
                item.appendTo(result_container);
            }
        }
    </script>
</body>
</html>
#pragma once #include <iostream> #include <vector> #include <string> #include <fstream> #include <unordered_map> #include <mutex> #include "util.hpp" #include "log.hpp" namespace ns_index{ struct DocInfo{ std::string title;//文档标题 std::string content;//文档对应的去标签之后的内容 std::string url;//官网文档url uint64_t doc_id; //文档的ID }; struct InvertedElem{//倒排的元素 uint64_t doc_id; std::string word; int weight; }; //倒排拉链 typedef std::vector<InvertedElem> InvertedList; class Index{ private: //正排索引的数据结构用数组,数组的下标天然是文档的ID std::vector<DocInfo> forward_index;//正排索引 //倒排索引一定是一个关键字和一组(个)InvertedElem对应(关键字和倒排拉链的对应关系) std::unordered_map<std::string , InvertedList>inverted_index; private: Index(){}//单例,但是不能delete Index(const Index&) = delete; Index& operator = (const Index&) = delete; static Index *instance; static std::mutex mtx; public: ~Index(){} public: static Index* GetInstance()//多线程环境会存在线程安全 { if(nullptr==instance) { mtx.lock(); if(nullptr==instance) { instance = new Index(); } mtx.unlock(); } return instance; } //根据doc_id找到文档内容 DocInfo* GetForwardIndex(uint64_t doc_id) { if(doc_id >= forward_index.size()) { //std::cerr<<"doc_id out range,error!"<<std::endl; LOG2(DEBUG,"doc_id out range,error!"); return nullptr; } return &forward_index[doc_id]; } //根据关键字string,获得倒排拉链 InvertedList *GetInvertedList(const std::string &word) { auto iter = inverted_index.find(word); if(iter==inverted_index.end()) { //std::cerr<<word<<"have no InvertedList"<<std::endl; LOG2(WARNING,"用户没搜到"); return nullptr; } return &(iter->second); } //根据去标签,格式化之后的文档,构建正排和倒排索引 //data/raw_html/raw.txt bool BuildIndex(const std::string &input)//parse处理完毕的数据交给我(文件的路径) { std::ifstream in(input,std::ios::in | std::ios::binary); if(!in.is_open()){ //std::cerr<<"sorry,"<<input<<"open error"<<std::endl; LOG2(FATAL,"open error"); return false; } //读取文件 std::string line;//每一行是一个文件 int count = 0; while(std::getline(in,line)) { //建立正排索引 DocInfo* doc=BuildForwardIndex(line); if(doc==nullptr) { 
//std::cerr<<"build"<<line<<"error"<<std::endl;//for debug LOG2(DEBUG,"建立正排索引错误"); continue; } BuildInvertedIndex(*doc); count++; if(count % 50==0) { //std::cout<< "当前已经建立的索引文档:"<<count <<std::endl; LOG2(NORMAL,"当前已经建立的索引文档: " + std::to_string(count)); } } return true; } private: DocInfo *BuildForwardIndex(const std::string &line) { //1. 解析line,字符串切分 line -> 3个string,(title、content、url) std::vector<std::string> results; const std::string sep ="\3";//行内分隔符 ns_util::StringUtil::Split(line,&results,sep); if(results.size()!=3){ return nullptr; } //2. 字符串进行填充到DoInfo DocInfo doc; doc.title = results[0]; doc.content = results[1]; doc.url = results[2]; doc.doc_id = forward_index.size();//先进行保存,再插入,对应的id就是当前doc在vector下的下标 //3. 插入到正排索引的vector forward_index.push_back(std::move(doc));//doc.html文件内容会比较大,避免拷贝应使用move return &forward_index.back(); } bool BuildInvertedIndex(const DocInfo &doc) { //DocInfo(title,content,url,doc_id) //world -> 倒排拉链 struct word_cnt{ int title_cnt; int content_cnt; word_cnt():title_cnt(0),content_cnt(0){} }; std::unordered_map<std::string,word_cnt> word_map;//用来暂存词频的映射表 //对标题进行分词 std::vector<std::string> title_words; ns_util::JiebaUtil::CutString2(doc.title,&title_words);//调用了CutString2 //对标题进行词频统计 for(auto &s : title_words){ boost::to_lower(s); word_map[s].title_cnt++; } //对文档内容进行分词 std::vector<std::string> content_words; ns_util::JiebaUtil::CutString2(doc.content,&content_words); //对内容进行词频统计 for(auto &s : content_words){ boost::to_lower(s); word_map[s].content_cnt++; } #define X 10 #define Y 1 //Hello.HELLO.hello(倒排索引的大小写要忽略) //根据文档内容,形成一个或者多个InvertedElem(倒排拉链) //因为当前我们是一个一个文档进行处理的,一个文档会包含多个“词”,都应当对应到当前的doc_id for(auto &word_pair : word_map){ InvertedElem item; item.doc_id = doc.doc_id; item.word = word_pair.first; item.weight = X*word_pair.second.title_cnt + Y*word_pair.second.content_cnt;//相关性 InvertedList &inverted_list = inverted_index[word_pair.first]; inverted_list.push_back(std::move(item)); } //1.需要对title && content都要先分词 //title: 吃/葡萄 
//content:吃/葡萄/不吐/葡萄皮 //词和文档的相关性(非常复杂,我们采用词频:在标题中出现的词,可以认为相关性更高一些,在内容中出现相关性低一些) //2.词频统计 //知道了在文档中,标题和内容每个词出现的次数 //3. 自定义相关性 //jieba的使用————cppjieba return true; } }; Index* Index::instance = nullptr; std::mutex Index::mtx; }
#pragma once
#include <cstdio>
#include <iostream>
#include <string>
#include <ctime>

#define NORMAL 1
#define WARNING 2
#define DEBUG 3
#define FATAL 4

// LOG2(LEVEL, MESSAGE): #LEVEL stringifies the level token (e.g. "FATAL"); the numeric values
// above are not used by the macro. MESSAGE must be convertible to std::string.
#define LOG2(LEVEL,MESSAGE) log(#LEVEL,MESSAGE,__FILE__,__LINE__)

// @brief: converts a UNIX timestamp to local time formatted as "YYYY-MM-DD HH:MM:SS".
// Returns "" for negative timestamps or if the conversion fails.
// Uses localtime_r and a per-call buffer: the previous localtime() + static buffer was shared
// mutable state, which is a data race when several server threads log at once.
static inline std::string getDateTimeFromTS(time_t ts)
{
    if (ts < 0) {
        return "";
    }
    struct tm tm_buf;
    if (localtime_r(&ts, &tm_buf) == nullptr) {
        return "";
    }
    char time_str[32] = {0};
    snprintf(time_str, sizeof(time_str), "%04d-%02d-%02d %02d:%02d:%02d",
             tm_buf.tm_year + 1900, tm_buf.tm_mon + 1, tm_buf.tm_mday,
             tm_buf.tm_hour, tm_buf.tm_min, tm_buf.tm_sec);
    return std::string(time_str);
}

// Writes "[level][date time][message][file:line]" to stdout.
// `inline` because this header is included from several files; a non-inline definition in a
// header violates the one-definition rule as soon as two translation units include it.
// The line is built first and emitted with a single insertion to reduce interleaving between threads.
inline void log(const std::string &level, const std::string &message, const std::string &file, int line)
{
    std::string out;
    out += "[" + level + "]";
    out += "[" + getDateTimeFromTS(time(nullptr)) + "]";
    out += "[" + message + "]";
    out += "[" + file + ":" + std::to_string(line) + "]";
    std::cout << out << std::endl;
}
#include <iostream> #include <string> #include <vector> #include <boost/filesystem.hpp> #include "util.hpp" #include "log.hpp" const std::string src_path = "data/input"; const std::string output = "data/raw_html/raw.txt";//结尾没有'/' typedef struct DocInfo{ std::string title;//文档的标题 std::string content;//文档内容 std::string url;//该文档在官网中的url }DocInfo_t; //const & 输入 //* 输出 //& 输入输出 bool EnumFile(const std::string &src_path,std::vector<std::string> *file_list); bool ParseHtml(const std::vector<std::string> &files_list,std::vector<DocInfo_t> *results); bool SaveHtml(const std::vector<DocInfo_t> &results,const std::string &output); int main() { std::vector<std::string> files_list; //第一步,递归式的把每个html文件名带路径,保存到files_list中,方便后期进行一个一个的文件进行读取 if(!EnumFile(src_path, &files_list)) { //std::cerr<<"enum file error!" <<std::endl; LOG2(FATAL,"enum file error!"); return 1; } //第二步,按照files_list读取每个文件的内容,并进行解析 std::vector<DocInfo_t> results; if(!ParseHtml(files_list,&results)) { //std::cerr <<"parse html error"<<std::endl; LOG2(FATAL,"parse html error"); return 2; } //第三步,把解析完毕的各个文件的内容,写入到output中,按照\3作为每个文档的分割符 if(!SaveHtml(results,output)) { //std::cerr<<"save html error"<<std::endl; LOG2(FATAL,"save html error"); return 3; } return 0; } bool EnumFile(const std::string &src_path,std::vector<std::string> *files_list) { namespace fs = boost::filesystem; fs::path root_path(src_path); //判断路径是否存在,不存在就没必要往后走了 if(!fs::exists(root_path)) { //std::cerr<< src_path<<"not exists"<<std::endl; LOG2(FATAL,"src_path not exists"); return false; } //定义一个空的迭代器,用来进行判断递归结束 fs::recursive_directory_iterator end; for(fs::recursive_directory_iterator iter(root_path);iter != end;iter++){ //判断文件是否是普通文件(html是普通文件) if(!fs::is_regular_file(*iter)) { continue; } if(iter->path().extension()!= ".html"){//判断文件路径名的后缀是否符合要求 path()提取路径字符串,是一个路径对象 ,extension()提取后缀(.以及之后的部分) continue; } //std::cout<<"debug: " <<iter->path().string()<<std::endl; //当前的路径一定是一个合法的,以.html结束的普通网页文件、 
files_list->push_back(iter->path().string());//将所有带路径的html保存到files_list,方便后续进行文本分析 } return true; } static bool ParseTitle(const std::string &file,std::string *title){ std::size_t begin = file.find("<title>"); if(begin == std::string::npos){ return false; } std::size_t end = file.find("</title>"); if(end==std::string::npos) { return false; } begin+=std::string("<title>").size(); if(begin>end){ return false; } *title = file.substr(begin,end-begin); return true; } static bool ParseContent(const std::string &file,std::string *content){ //去标签,基于一个简易的状态机编写 enum status{ LABLE, CONTENT }; enum status s=LABLE; for(char c :file){ switch(s) { case LABLE: if(c=='>') s= CONTENT; break; case CONTENT: if(c=='<') s= LABLE; else { //我们不想要保留原始文件中的‘\n’,因为我们想用\n作为html解析之后文本的分隔符 if(c=='\n')c=' '; content->push_back(c); } break; default: break; } } return true; } static bool ParseUrl(const std::string &file_path,std::string *url) { std::string url_head = "https://www.boost.org/doc/libs/1_78_0/doc/html"; std::string url_tail = file_path.substr(src_path.size());//越过长度截取 *url = url_head + url_tail; return true; } //for debug static void ShowDoc(const DocInfo_t &doc) { std::cout<<"title:"<<doc.title << std::endl; std::cout<<"content:"<<doc.content << std::endl; std::cout<<"url:"<<doc.url << std::endl; } bool ParseHtml(const std::vector<std::string> &files_list,std::vector<DocInfo_t> *results) { for(const std::string &file : files_list) { //1.读取文件,Read() std::string result; if(!ns_util::FileUtil::ReadFile(file,&result)){ continue; } //2.解析指定的文件,提取title DocInfo_t doc; if(!ParseTitle(result,&doc.title)){ continue; } //3.解析指定的文件,提取content if(!ParseContent(result,&doc.content)){ continue; } //4.解析指定的文件路径,构建url if(!ParseUrl(file,&doc.url)){ continue; } //done,一定是完成了解析任务,当前文档的相关结果都保存在doc中 results->push_back(std::move(doc)); //bug to do细节,本质会发生拷贝,效率可能会比较低 (move是细节) //std::cout<<1<<std::endl; //for debug //ShowDoc(doc); //break; } return true; } bool SaveHtml(const std::vector<DocInfo_t> 
&results,const std::string &output) { #define SEP '\3' //按照二进制方式进行写入 std::ofstream out(output,std::ios::out | std::ios::binary); if(!out.is_open()){ //std::cerr<<"open "<<output <<"failed!"<<std::endl; LOG2(FATAL,"open output failed!"); return false; } //就可以进行文件内容的写入了 for(auto &item : results) { std::string out_string; out_string = item.title; out_string+=SEP; out_string +=item.content; out_string +=SEP; out_string +=item.url; out_string+='\n'; out.write(out_string.c_str(),out_string.size()); } out.close(); return true; } //strstr 前闭后开
#pragma once #include "index.hpp" #include "util.hpp" #include <algorithm> #include <jsoncpp/json/json.h> #include "log.hpp" //#include <vector> namespace ns_searcher{ struct InvertedElemPrint{ uint64_t doc_id; int weight; std::vector<std::string> words; InvertedElemPrint():doc_id(0),weight(0){} }; class Searcher{ private: ns_index::Index *index; public: Searcher(){} ~Searcher(){} public: void InitSearcher(const std::string &input) { //1. 获取或者创建index对象 index = ns_index::Index::GetInstance(); //std::cout <<"获取index单例成功..."<<std::endl; LOG2(NORMAL,"获取index单例成功..."); //2. 根据index对象建立索引 index->BuildIndex(input);//CutString //std::cout<<"建立正排和倒排索引成功..."<<std::endl; LOG2(NORMAL,"建立正排和倒排索引成功..."); } //query:搜索关键字 //json_string:返回给用户浏览器的搜索结果 void Search(const std::string &query,std::string *json_string) { //1. [分词]:对我们的query进行按照searcher的要求进行分词 std::vector<std::string> words; ns_util::JiebaUtil::CutString(query,&words); //2. [触发]:就是根据分词的各个“词,进行index查找”,建立index是忽略大小写,所以搜索关键字也需要 //ns_index::InvertedList inverted_list_all; std::vector<InvertedElemPrint> inverted_list_all; std::unordered_map<uint64_t,InvertedElemPrint> tokens_map; for(std::string word : words) { boost::to_lower(word); ns_index::InvertedList *inverted_list = index->GetInvertedList(word); if(nullptr == inverted_list) { continue; } //不完美的地方(去重) //inverted_list_all.insert(inverted_list_all.end(),inverted_list->begin(),inverted_list->end()); for(const auto &elem : *inverted_list) { auto &item = tokens_map[elem.doc_id]; //item一定是doc_id相同的print节点 item.doc_id =elem.doc_id; item.weight += elem.weight; item.words.push_back(elem.word); } } for(const auto&item : tokens_map){ inverted_list_all.push_back(std::move(item.second)); } //3. 
[合并排序]:汇总查找结果,按照相关性(weight)降序排序 /*std::sort(inverted_list_all.begin(),inverted_list_all.end(),\ [](const ns_index::InvertedElem &e1,const ns_index::InvertedElem &e2){ return e1.weight>e2.weight; } ); */ std::sort(inverted_list_all.begin(),inverted_list_all.end(),\ [](const InvertedElemPrint&e1,const InvertedElemPrint& e2){ return e1.weight >e2.weight; }); //4. [构建]:根据查找出来的结果,构建json串————jsoncpp----通过jsoncpp完成序列化和反序列化 Json::Value root; for(auto &item : inverted_list_all){ ns_index::DocInfo *doc = index->GetForwardIndex(item.doc_id); if(nullptr == doc) { continue; } Json::Value elem; elem["title"] = doc->title; elem["desc"] = GetDesc(doc->content,item.words[0]); //content是文档的去标签的结果,但是不是我们想要的,我们要的是一部分 elem["url"] = doc->url; //foe debug //elem["id"]= (int)item.doc_id;//doc_id是64位的uint64_t //elem["weight"] = item.weight; root.append(elem); } //Json::StyledWriter writer; Json::FastWriter writer; *json_string = writer.write(root); } std::string GetDesc(const std::string &html_content,const std::string &word) { //找到word在html_content中的首次出现,然后往前找50个字节(如果没有,从begin开始),往后找100个字节(如果没有,到end就可以),截取出这部分内容 const std::size_t prev_step = 50; const std::size_t next_step =100; //1. 找到首次出现 auto iter = std::search(html_content.begin(),html_content.end(),word.begin(),word.end(),[](int x,int y){ return (std::tolower(x)==std::tolower(y)); }); if(iter == html_content.end()) { return "None1"; } std::size_t pos = std::distance(html_content.begin(),iter); /*std::size_t pos = html_content.find(word); if(pos == std::string::npos){ return "None1";//这种情况是不存在的 }*/ //2. 获取start,end //这里有一个大坑,就是std::size_t是一个无符号数,无符号数相减为正数 std::size_t start = 0; std::size_t end = html_content.size() - 1; //如果之前有50个字符,就更新开始位置 if(pos >start+ prev_step) start = pos -prev_step;//换成加法 if(pos + next_step <end) end = pos + next_step; //3. 截取子串,return if(start >= end)return "None2"; std::string desc = html_content.substr(start,end-start+1); std::string result="..." + desc + "..."; return result; } }; }
#pragma once #include <iostream> #include <string> #include <fstream> #include <vector> #include "boost_1_84_0/boost/algorithm/string.hpp" #include "../cppjieba/include/cppjieba/Jieba.hpp" //#include "cppjieba/jieba" #include "log.hpp" #include <mutex> #include <unordered_map> namespace ns_util{ class FileUtil { public: static bool ReadFile(const std::string &file_path,std::string *out) { std::ifstream in(file_path,std::ios::in); if(!in.is_open()) { //std::cerr << "open_file" << file_path <<"error" <<std::endl; LOG2(FATAL,"open_file error"); return false; } std::string line; while(std::getline(in,line)){//如何理解getline读取到文件结束呢??getline到返回值是一个&,while(bool),本质是因为重载了强制类型转换 *out += line; } in.close(); return true; } }; class StringUtil{ public: static void Split(const std::string&target,std::vector<std::string>*out,const std::string& sep) { //boost split boost::split(*out,target,boost::is_any_of(sep),boost::token_compress_on); } }; const char* const DICT_PATH = "./dict/jieba.dict.utf8"; const char* const HMM_PATH = "./dict/hmm_model.utf8"; const char* const USER_DICT_PATH = "./dict/user.dict.utf8"; const char* const IDF_PATH = "./dict/idf.utf8"; const char* const STOP_WORD_PATH = "./dict/stop_words.utf8"; class JiebaUtil { private: //static cppjieba::Jieba jieba; cppjieba::Jieba jieba; std::unordered_map<std::string,bool> stop_words; private: JiebaUtil():jieba(DICT_PATH,HMM_PATH,USER_DICT_PATH,IDF_PATH,STOP_WORD_PATH) {} JiebaUtil(const JiebaUtil&)=delete; JiebaUtil& operator=(JiebaUtil const&)=delete; static JiebaUtil *instance; public: static JiebaUtil*get_instance() { static std::mutex mtx; if(nullptr==instance){ mtx.lock(); if(nullptr ==instance){ instance = new JiebaUtil(); instance->InitJiebaUtil(); } mtx.unlock(); } return instance; } void InitJiebaUtil() { std::ifstream in(STOP_WORD_PATH); if(!in.is_open()) { LOG2(FATAL,"load stop words fill error"); return ; } std::string line; while(std::getline(in,line)) { stop_words.insert({line,true}); } in.close(); } void 
CutStringHelper(const std::string &src,std::vector<std::string>*out) { jieba.CutForSearch(src,*out); std::vector<std::string> v(*out); // //for debug // for(auto e : v) // { // std::cout<<"暂停词测试存储 v:"<<e<<"----"<<std::endl; // } for(auto iter=out->begin();iter!=out->end();){ auto it =stop_words.find(*iter); if(it!=stop_words.end()) { //说明当前的string是暂停词,需要去掉 iter = out->erase(iter); } else { iter++; } } if(out->empty()) { //std::cout<< "out为空"<<std::endl; *out = v; } //debug // std::cout<< out->empty()<<std::endl; // for(auto e : *out) // { // std::cout<<"暂停词测试out 后:"<<e<<"----"<<std::endl; // } } void CutString_has_stop_words(const std::string &src,std::vector<std::string>*out) { jieba.CutForSearch(src,*out); } public: static void CutString(const std::string &src,std::vector<std::string> *out) { //debug //std::cout<< "CutStringHelper" << std::endl; ns_util::JiebaUtil::get_instance()->CutStringHelper(src,out); //jieba.CutForSearch(src,*out); } static void CutString2(const std::string &src,std::vector<std::string> *out) { //debug //std::cout<< "CutString2()" << std::endl; ns_util::JiebaUtil::get_instance()->CutString_has_stop_words(src,out); } //cppjieba::Jieba JiebaUtil::jieba(DICT_PATH,HMM_PATH,USER_DICT_PATH,IDF_PATH,STOP_WORD_PATH); }; JiebaUtil *JiebaUtil::instance = nullptr; //加static是因为这个函数要被外部使用,加了static可以不创建对象就可以使用 }
#include "searcher.hpp"
#include "httplib.h"
#include "../http.hpp"
#include <sstream>
#include <string>

const std::string root_path = "./wwwroot";              // static resource root
const std::string input = "data/raw_html/raw.txt";      // parser output used to build the index
ns_searcher::Searcher search;

// Renders a request back as text (request line, params, headers, body) for debugging.
std::string RequestStr(const HttpRequest &req)
{
    std::stringstream ss;
    ss << req._method << " " << req._path << " " << req._version << "\r\n";
    // const& avoids copying every key/value pair.
    for (const auto &it : req._params) {
        ss << it.first << ": " << it.second << "\r\n";
        // %s needs a C string: passing std::string objects through varargs is undefined behavior.
        DBG_LOG("RequestStr_params: first:%s ,second:%s", it.first.c_str(), it.second.c_str());
    }
    for (const auto &it : req._headers) {
        ss << it.first << ": " << it.second << "\r\n";
        DBG_LOG("RequestStr_headers: first:%s ,second:%s", it.first.c_str(), it.second.c_str());
    }
    ss << "\r\n";
    ss << req._body;
    return ss.str();
}

// GET /s?word=... : runs a search and returns the JSON result list.
void Hello(const HttpRequest &req, HttpResponse *rsp)
{
    if (!req.HasParam("word")) {
        rsp->SetContent("必须要有搜索关键字!", "text/plain; charset=utf-8");
        return;
    }
    const std::string word = req.GetParam("word");  // value of the "word" query parameter
    std::string json_string;
    search.Search(word, &json_string);
    rsp->SetContent(json_string, "application/json");
}

// Echo handlers used for testing the HTTP server.
void Login(const HttpRequest &req, HttpResponse *rsp)
{
    rsp->SetContent(RequestStr(req), "text/plain");
}

void PutFile(const HttpRequest &req, HttpResponse *rsp)
{
    rsp->SetContent(RequestStr(req), "text/plain");
}

void DelFile(const HttpRequest &req, HttpResponse *rsp)
{
    rsp->SetContent(RequestStr(req), "text/plain");
}

int main()
{
    search.InitSearcher(input);
    HttpServer server(8085);
    server.SetThreadCount(3);
    // Static resource requests are served from this directory.
    server.SetBaseDir(root_path);
    server.Get("/s", Hello);
    server.Post("/login", Login);
    server.Put("/1234.txt", PutFile);
    server.Delete("/1234.txt", DelFile);
    server.Listen();
    return 0;
}
#include <iostream> #include <fstream> #include <string> #include <vector> #include <regex> #include <sys/stat.h> #include "../server.hpp" #define DEFALT_TIMEOUT 10 std::unordered_map<int, std::string> _statu_msg = { {100, "Continue"}, {101, "Switching Protocol"}, {102, "Processing"}, {103, "Early Hints"}, {200, "OK"}, {201, "Created"}, {202, "Accepted"}, {203, "Non-Authoritative Information"}, {204, "No Content"}, {205, "Reset Content"}, {206, "Partial Content"}, {207, "Multi-Status"}, {208, "Already Reported"}, {226, "IM Used"}, {300, "Multiple Choice"}, {301, "Moved Permanently"}, {302, "Found"}, {303, "See Other"}, {304, "Not Modified"}, {305, "Use Proxy"}, {306, "unused"}, {307, "Temporary Redirect"}, {308, "Permanent Redirect"}, {400, "Bad Request"}, {401, "Unauthorized"}, {402, "Payment Required"}, {403, "Forbidden"}, {404, "Not Found"}, {405, "Method Not Allowed"}, {406, "Not Acceptable"}, {407, "Proxy Authentication Required"}, {408, "Request Timeout"}, {409, "Conflict"}, {410, "Gone"}, {411, "Length Required"}, {412, "Precondition Failed"}, {413, "Payload Too Large"}, {414, "URI Too Long"}, {415, "Unsupported Media Type"}, {416, "Range Not Satisfiable"}, {417, "Expectation Failed"}, {418, "I'm a teapot"}, {421, "Misdirected Request"}, {422, "Unprocessable Entity"}, {423, "Locked"}, {424, "Failed Dependency"}, {425, "Too Early"}, {426, "Upgrade Required"}, {428, "Precondition Required"}, {429, "Too Many Requests"}, {431, "Request Header Fields Too Large"}, {451, "Unavailable For Legal Reasons"}, {501, "Not Implemented"}, {502, "Bad Gateway"}, {503, "Service Unavailable"}, {504, "Gateway Timeout"}, {505, "HTTP Version Not Supported"}, {506, "Variant Also Negotiates"}, {507, "Insufficient Storage"}, {508, "Loop Detected"}, {510, "Not Extended"}, {511, "Network Authentication Required"} }; std::unordered_map<std::string, std::string> _mime_msg = { {".aac", "audio/aac"}, {".abw", "application/x-abiword"}, {".arc", "application/x-freearc"}, {".avi", 
"video/x-msvideo"}, {".azw", "application/vnd.amazon.ebook"}, {".bin", "application/octet-stream"}, {".bmp", "image/bmp"}, {".bz", "application/x-bzip"}, {".bz2", "application/x-bzip2"}, {".csh", "application/x-csh"}, {".css", "text/css"}, {".csv", "text/csv"}, {".doc", "application/msword"}, {".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"}, {".eot", "application/vnd.ms-fontobject"}, {".epub", "application/epub+zip"}, {".gif", "image/gif"}, {".htm", "text/html"}, {".html", "text/html"}, {".ico", "image/vnd.microsoft.icon"}, {".ics", "text/calendar"}, {".jar", "application/java-archive"}, {".jpeg", "image/jpeg"}, {".jpg", "image/jpeg"}, {".js", "text/javascript"}, {".json", "application/json"}, {".jsonld", "application/ld+json"}, {".mid", "audio/midi"}, {".midi", "audio/x-midi"}, {".mjs", "text/javascript"}, {".mp3", "audio/mpeg"}, {".mpeg", "video/mpeg"}, {".mpkg", "application/vnd.apple.installer+xml"}, {".odp", "application/vnd.oasis.opendocument.presentation"}, {".ods", "application/vnd.oasis.opendocument.spreadsheet"}, {".odt", "application/vnd.oasis.opendocument.text"}, {".oga", "audio/ogg"}, {".ogv", "video/ogg"}, {".ogx", "application/ogg"}, {".otf", "font/otf"}, {".png", "image/png"}, {".pdf", "application/pdf"}, {".ppt", "application/vnd.ms-powerpoint"}, {".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation"}, {".rar", "application/x-rar-compressed"}, {".rtf", "application/rtf"}, {".sh", "application/x-sh"}, {".svg", "image/svg+xml"}, {".swf", "application/x-shockwave-flash"}, {".tar", "application/x-tar"}, {".tif", "image/tiff"}, {".tiff", "image/tiff"}, {".ttf", "font/ttf"}, {".txt", "text/plain"}, {".vsd", "application/vnd.visio"}, {".wav", "audio/wav"}, {".weba", "audio/webm"}, {".webm", "video/webm"}, {".webp", "image/webp"}, {".woff", "font/woff"}, {".woff2", "font/woff2"}, {".xhtml", "application/xhtml+xml"}, {".xls", "application/vnd.ms-excel"}, {".xlsx", 
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"}, {".xml", "application/xml"}, {".xul", "application/vnd.mozilla.xul+xml"}, {".zip", "application/zip"}, {".3gp", "video/3gpp"}, {".3g2", "video/3gpp2"}, {".7z", "application/x-7z-compressed"} }; class Util { public: //字符串分割函数,将src字符串按照sep字符进行分割,得到的各个字串放到arry中,最终返回字串的数量 static size_t Split(const std::string &src, const std::string &sep, std::vector<std::string> *arry) { size_t offset = 0; // 有10个字符,offset是查找的起始位置,范围应该是0~9,offset==10就代表已经越界了 while(offset < src.size()) { size_t pos = src.find(sep, offset);//在src字符串偏移量offset处,开始向后查找sep字符/字串,返回查找到的位置 if (pos == std::string::npos) {//没有找到特定的字符 //将剩余的部分当作一个字串,放入arry中 if(pos == src.size()) break; arry->push_back(src.substr(offset)); return arry->size(); } if (pos == offset) { offset = pos + sep.size(); continue;//当前字串是一个空的,没有内容 } arry->push_back(src.substr(offset, pos - offset)); offset = pos + sep.size(); } return arry->size(); } //读取文件的所有内容,将读取的内容放到一个Buffer中 static bool ReadFile(const std::string &filename, std::string *buf) { std::ifstream ifs(filename, std::ios::binary); if (ifs.is_open() == false) { printf("OPEN %s FILE FAILED!!", filename.c_str()); return false; } size_t fsize = 0; ifs.seekg(0, ifs.end);//跳转读写位置到末尾 fsize = ifs.tellg(); //获取当前读写位置相对于起始位置的偏移量,从末尾偏移刚好就是文件大小 ifs.seekg(0, ifs.beg);//跳转到起始位置 buf->resize(fsize); //开辟文件大小的空间 ifs.read(&(*buf)[0], fsize); if (ifs.good() == false) { printf("READ %s FILE FAILED!!", filename.c_str()); ifs.close(); return false; } ifs.close(); return true; } //向文件写入数据 static bool WriteFile(const std::string &filename, const std::string &buf) { std::ofstream ofs(filename, std::ios::binary | std::ios::trunc); if (ofs.is_open() == false) { printf("OPEN %s FILE FAILED!!", filename.c_str()); return false; } ofs.write(buf.c_str(), buf.size()); if (ofs.good() == false) { ERR_LOG("WRITE %s FILE FAILED!", filename.c_str()); ofs.close(); return false; } ofs.close(); return true; } 
//URL编码,避免URL中资源路径与查询字符串中的特殊字符与HTTP请求中特殊字符产生歧义 //编码格式:将特殊字符的ascii值,转换为两个16进制字符,前缀% C++ -> C%2B%2B // 不编码的特殊字符: RFC3986文档规定 . - _ ~ 字母,数字属于绝对不编码字符 //RFC3986文档规定,编码格式 %HH //W3C标准中规定,查询字符串中的空格,需要编码为+, 解码则是+转空格 static std::string UrlEncode(const std::string url, bool convert_space_to_plus) { std::string res; for (auto &c : url) { if (c == '.' || c == '-' || c == '_' || c == '~' || isalnum(c)) { res += c; continue; } if (c == ' ' && convert_space_to_plus == true) { res += '+'; continue; } //剩下的字符都是需要编码成为 %HH 格式 char tmp[4] = {0}; //snprintf 与 printf比较类似,都是格式化字符串,只不过一个是打印,一个是放到一块空间中 snprintf(tmp, 4, "%%%02X", c); res += tmp; } return res; } static char HEXTOI(char c) { if (c >= '0' && c <= '9') { return c - '0'; }else if (c >= 'a' && c <= 'z') { return c - 'a' + 10; }else if (c >= 'A' && c <= 'Z') { return c - 'A' + 10; } return -1; } static std::string UrlDecode(const std::string url, bool convert_plus_to_space) { //遇到了%,则将紧随其后的2个字符,转换为数字,第一个数字左移4位,然后加上第二个数字 + -> 2b %2b->2 << 4 + 11 std::string res; for (int i = 0; i < url.size(); i++) { if (url[i] == '+' && convert_plus_to_space == true) { res += ' '; continue; } if (url[i] == '%' && (i + 2) < url.size()) { char v1 = HEXTOI(url[i + 1]); char v2 = HEXTOI(url[i + 2]); char v = v1 * 16 + v2; res += v; i += 2; continue; } res += url[i]; } return res; } //响应状态码的描述信息获取 static std::string StatuDesc(int statu) { auto it = _statu_msg.find(statu); if (it != _statu_msg.end()) { return it->second; } return "Unknow"; } //根据文件后缀名获取文件mime static std::string ExtMime(const std::string &filename) { // a.b.txt 先获取文件扩展名 size_t pos = filename.find_last_of('.'); if (pos == std::string::npos) { return "application/octet-stream"; } //根据扩展名,获取mime std::string ext = filename.substr(pos); auto it = _mime_msg.find(ext); if (it == _mime_msg.end()) { return "application/octet-stream"; } return it->second; } //判断一个文件是否是一个目录 static bool IsDirectory(const std::string &filename) { struct stat st; int ret = stat(filename.c_str(), &st); if (ret < 0) { 
return false; } return S_ISDIR(st.st_mode); } //判断一个文件是否是一个普通文件 static bool IsRegular(const std::string &filename) { struct stat st; int ret = stat(filename.c_str(), &st); if (ret < 0) { return false; } return S_ISREG(st.st_mode); } //http请求的资源路径有效性判断 // /index.html --- 前边的/叫做相对根目录 映射的是某个服务器上的子目录 // 想表达的意思就是,客户端只能请求相对根目录中的资源,其他地方的资源都不予理会 // /../login, 这个路径中的..会让路径的查找跑到相对根目录之外,这是不合理的,不安全的 static bool ValidPath(const std::string &path) { //思想:按照/进行路径分割,根据有多少子目录,计算目录深度,有多少层,深度不能小于0 std::vector<std::string> subdir; Split(path, "/", &subdir); int level = 0; for (auto &dir : subdir) { if (dir == "..") { level--; //任意一层走出相对根目录,就认为有问题 if (level < 0) return false; continue; } level++; } return true; } }; class HttpRequest { public: std::string _method; //请求方法 std::string _path; //资源路径 std::string _version; //协议版本 std::string _body; //请求正文 std::smatch _matches; //资源路径的正则提取数据 std::unordered_map<std::string, std::string> _headers; //头部字段 std::unordered_map<std::string, std::string> _params; //查询字符串 public: HttpRequest():_version("HTTP/1.1") {} void ReSet() { _method.clear(); _path.clear(); _version = "HTTP/1.1"; _body.clear(); std::smatch match; _matches.swap(match); _headers.clear(); _params.clear(); } //插入头部字段 void SetHeader(const std::string &key, const std::string &val) { _headers.insert(std::make_pair(key, val)); } //判断是否存在指定头部字段 bool HasHeader(const std::string &key) const { auto it = _headers.find(key); if (it == _headers.end()) { return false; } return true; } //获取指定头部字段的值 std::string GetHeader(const std::string &key) const { auto it = _headers.find(key); if (it == _headers.end()) { return ""; } return it->second; } //插入查询字符串 void SetParam(const std::string &key, const std::string &val) { _params.insert(std::make_pair(key, val)); } //判断是否有某个指定的查询字符串 bool HasParam(const std::string &key) const { auto it = _params.find(key); if (it == _params.end()) { return false; } return true; } //获取指定的查询字符串 std::string GetParam(const std::string &key) const { auto it = 
_params.find(key); if (it == _params.end()) { return ""; } return it->second; } //获取正文长度 size_t ContentLength() const { // Content-Length: 1234\r\n bool ret = HasHeader("Content-Length"); if (ret == false) { return 0; } std::string clen = GetHeader("Content-Length"); return std::stol(clen); } //判断是否是短链接 bool Close() const { // 没有Connection字段,或者有Connection但是值是close,则都是短链接,否则就是长连接 if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive") { return false; } return true; } }; class HttpResponse { public: int _statu; bool _redirect_flag; std::string _body; std::string _redirect_url; std::unordered_map<std::string, std::string> _headers; public: HttpResponse():_redirect_flag(false), _statu(200) {} HttpResponse(int statu):_redirect_flag(false), _statu(statu) {} void ReSet() { _statu = 200; _redirect_flag = false; _body.clear(); _redirect_url.clear(); _headers.clear(); } //插入头部字段 void SetHeader(const std::string &key, const std::string &val) { _headers.insert(std::make_pair(key, val)); } //判断是否存在指定头部字段 bool HasHeader(const std::string &key) { auto it = _headers.find(key); if (it == _headers.end()) { return false; } return true; } //获取指定头部字段的值 std::string GetHeader(const std::string &key) { auto it = _headers.find(key); if (it == _headers.end()) { return ""; } return it->second; } void SetContent(const std::string &body, const std::string &type = "text/html") { _body = body; SetHeader("Content-Type", type); } void SetRedirect(const std::string &url, int statu = 302) { _statu = statu; _redirect_flag = true; _redirect_url = url; } //判断是否是短链接 bool Close() { // 没有Connection字段,或者有Connection但是值是close,则都是短链接,否则就是长连接 if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive") { return false; } return true; } }; typedef enum { RECV_HTTP_ERROR, RECV_HTTP_LINE, RECV_HTTP_HEAD, RECV_HTTP_BODY, RECV_HTTP_OVER }HttpRecvStatu; #define MAX_LINE 8192 class HttpContext { private: int _resp_statu; //响应状态码 HttpRecvStatu _recv_statu; //当前接收及解析的阶段状态 
HttpRequest _request; //已经解析得到的请求信息 private: bool ParseHttpLine(const std::string &line) { std::smatch matches; std::regex e("(GET|HEAD|POST|PUT|DELETE) ([^?]*)(?:\\?(.*))? (HTTP/1\\.[01])(?:\n|\r\n)?", std::regex::icase); bool ret = std::regex_match(line, matches, e); if (ret == false) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 400;//BAD REQUEST return false; } //0 : GET /bitejiuyeke/login?user=xiaoming&pass=123123 HTTP/1.1 //1 : GET //2 : /bitejiuyeke/login //3 : user=xiaoming&pass=123123 //4 : HTTP/1.1 //请求方法的获取 _request._method = matches[1]; std::transform(_request._method.begin(), _request._method.end(), _request._method.begin(), ::toupper); //资源路径的获取,需要进行URL解码操作,但是不需要+转空格 _request._path = Util::UrlDecode(matches[2], false); //协议版本的获取 _request._version = matches[4]; //查询字符串的获取与处理 std::vector<std::string> query_string_arry; std::string query_string = matches[3]; //查询字符串的格式 key=val&key=val....., 先以 & 符号进行分割,得到各个字串 Util::Split(query_string, "&", &query_string_arry); //针对各个字串,以 = 符号进行分割,得到key 和val, 得到之后也需要进行URL解码 for (auto &str : query_string_arry) { size_t pos = str.find("="); if (pos == std::string::npos) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 400;//BAD REQUEST return false; } std::string key = Util::UrlDecode(str.substr(0, pos), true); std::string val = Util::UrlDecode(str.substr(pos + 1), true); _request.SetParam(key, val); } return true; } bool RecvHttpLine(Buffer *buf) { if (_recv_statu != RECV_HTTP_LINE) return false; //1. 获取一行数据,带有末尾的换行 std::string line = buf->GetLineAndPop(); //2. 
需要考虑的一些要素:缓冲区中的数据不足一行, 获取的一行数据超大 if (line.size() == 0) { //缓冲区中的数据不足一行,则需要判断缓冲区的可读数据长度,如果很长了都不足一行,这是有问题的 if (buf->ReadAbleSize() > MAX_LINE) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 414;//URI TOO LONG return false; } //缓冲区中数据不足一行,但是也不多,就等等新数据的到来 return true; } if (line.size() > MAX_LINE) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 414;//URI TOO LONG return false; } bool ret = ParseHttpLine(line); if (ret == false) { return false; } //首行处理完毕,进入头部获取阶段 _recv_statu = RECV_HTTP_HEAD; return true; } bool RecvHttpHead(Buffer *buf) { if (_recv_statu != RECV_HTTP_HEAD) return false; //一行一行取出数据,直到遇到空行为止, 头部的格式 key: val\r\nkey: val\r\n.... while(1){ std::string line = buf->GetLineAndPop(); //2. 需要考虑的一些要素:缓冲区中的数据不足一行, 获取的一行数据超大 if (line.size() == 0) { //缓冲区中的数据不足一行,则需要判断缓冲区的可读数据长度,如果很长了都不足一行,这是有问题的 if (buf->ReadAbleSize() > MAX_LINE) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 414;//URI TOO LONG return false; } //缓冲区中数据不足一行,但是也不多,就等等新数据的到来 return true; } if (line.size() > MAX_LINE) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 414;//URI TOO LONG return false; } if (line == "\n" || line == "\r\n") { break; } bool ret = ParseHttpHead(line); if (ret == false) { return false; } } //头部处理完毕,进入正文获取阶段 _recv_statu = RECV_HTTP_BODY; return true; } bool ParseHttpHead(std::string &line) { //key: val\r\nkey: val\r\n.... if (line.back() == '\n') line.pop_back();//末尾是换行则去掉换行字符 if (line.back() == '\r') line.pop_back();//末尾是回车则去掉回车字符 size_t pos = line.find(": "); if (pos == std::string::npos) { _recv_statu = RECV_HTTP_ERROR; _resp_statu = 400;// return false; } std::string key = line.substr(0, pos); std::string val = line.substr(pos + 2); _request.SetHeader(key, val); return true; } bool RecvHttpBody(Buffer *buf) { if (_recv_statu != RECV_HTTP_BODY) return false; //1. 获取正文长度 size_t content_length = _request.ContentLength(); if (content_length == 0) { //没有正文,则请求接收解析完毕 _recv_statu = RECV_HTTP_OVER; return true; } //2. 
当前已经接收了多少正文,其实就是往 _request._body 中放了多少数据了 size_t real_len = content_length - _request._body.size();//实际还需要接收的正文长度 //3. 接收正文放到body中,但是也要考虑当前缓冲区中的数据,是否是全部的正文 // 3.1 缓冲区中数据,包含了当前请求的所有正文,则取出所需的数据 if (buf->ReadAbleSize() >= real_len) { _request._body.append(buf->ReadPosition(), real_len); buf->MoveReadOffset(real_len); _recv_statu = RECV_HTTP_OVER; return true; } // 3.2 缓冲区中数据,无法满足当前正文的需要,数据不足,取出数据,然后等待新数据到来 _request._body.append(buf->ReadPosition(), buf->ReadAbleSize()); buf->MoveReadOffset(buf->ReadAbleSize()); return true; } public: HttpContext():_resp_statu(200), _recv_statu(RECV_HTTP_LINE) {} void ReSet() { _resp_statu = 200; _recv_statu = RECV_HTTP_LINE; _request.ReSet(); } int RespStatu() { return _resp_statu; } HttpRecvStatu RecvStatu() { return _recv_statu; } HttpRequest &Request() { return _request; } //接收并解析HTTP请求 void RecvHttpRequest(Buffer *buf) { //不同的状态,做不同的事情,但是这里不要break, 因为处理完请求行后,应该立即处理头部,而不是退出等新数据 switch(_recv_statu) { case RECV_HTTP_LINE: RecvHttpLine(buf); case RECV_HTTP_HEAD: RecvHttpHead(buf); case RECV_HTTP_BODY: RecvHttpBody(buf); } return; } }; class HttpServer { private: using Handler = std::function<void(const HttpRequest &, HttpResponse *)>; using Handlers = std::vector<std::pair<std::regex, Handler>>; Handlers _get_route; Handlers _post_route; Handlers _put_route; Handlers _delete_route; std::string _basedir; //静态资源根目录 TcpServer _server; private: void ErrorHandler(const HttpRequest &req, HttpResponse *rsp) { //1. 组织一个错误展示页面 std::string body; body += "<html>"; body += "<head>"; body += "<meta http-equiv='Content-Type' content='text/html;charset=utf-8'>"; body += "</head>"; body += "<body>"; body += "<h1>"; body += std::to_string(rsp->_statu); body += " "; body += Util::StatuDesc(rsp->_statu); body += "</h1>"; body += "</body>"; body += "</html>"; //2. 将页面数据,当作响应正文,放入rsp中 rsp->SetContent(body, "text/html"); } //将HttpResponse中的要素按照http协议格式进行组织,发送 void WriteReponse(const PtrConnection &conn, const HttpRequest &req, HttpResponse &rsp) { //1. 
先完善头部字段 if (req.Close() == true) { rsp.SetHeader("Connection", "close"); }else { rsp.SetHeader("Connection", "keep-alive"); } if (rsp._body.empty() == false && rsp.HasHeader("Content-Length") == false) { rsp.SetHeader("Content-Length", std::to_string(rsp._body.size())); } if (rsp._body.empty() == false && rsp.HasHeader("Content-Type") == false) { rsp.SetHeader("Content-Type", "application/octet-stream"); } if (rsp._redirect_flag == true) { rsp.SetHeader("Location", rsp._redirect_url); } //2. 将rsp中的要素,按照http协议格式进行组织 std::stringstream rsp_str; rsp_str << req._version << " " << std::to_string(rsp._statu) << " " << Util::StatuDesc(rsp._statu) << "\r\n"; for (auto &head : rsp._headers) { rsp_str << head.first << ": " << head.second << "\r\n"; } rsp_str << "\r\n"; rsp_str << rsp._body; //3. 发送数据 conn->Send(rsp_str.str().c_str(), rsp_str.str().size()); } bool IsFileHandler(const HttpRequest &req) { // 1. 必须设置了静态资源根目录 if (_basedir.empty()) { return false; } // 2. 请求方法,必须是GET / HEAD请求方法 if (req._method != "GET" && req._method != "HEAD") { return false; } // 3. 请求的资源路径必须是一个合法路径 if (Util::ValidPath(req._path) == false) { return false; } // 4. 
请求的资源必须存在,且是一个普通文件 // 有一种请求比较特殊 -- 目录:/, /image/, 这种情况给后边默认追加一个 index.html // index.html /image/a.png // 不要忘了前缀的相对根目录,也就是将请求路径转换为实际存在的路径 /image/a.png -> ./wwwroot/image/a.png std::string req_path = _basedir + req._path;//为了避免直接修改请求的资源路径,因此定义一个临时对象 if (req._path.back() == '/') { req_path += "index.html"; } if (Util::IsRegular(req_path) == false) { return false; } return true; } //静态资源的请求处理 --- 将静态资源文件的数据读取出来,放到rsp的_body中, 并设置mime void FileHandler(const HttpRequest &req, HttpResponse *rsp) { std::string req_path = _basedir + req._path; if (req._path.back() == '/') { req_path += "index.html"; } bool ret = Util::ReadFile(req_path, &rsp->_body); if (ret == false) { return; } std::string mime = Util::ExtMime(req_path); rsp->SetHeader("Content-Type", mime); return; } //功能性请求的分类处理 void Dispatcher(HttpRequest &req, HttpResponse *rsp, Handlers &handlers) { //在对应请求方法的路由表中,查找是否含有对应资源请求的处理函数,有则调用,没有则发挥404 //思想:路由表存储的时键值对 -- 正则表达式 & 处理函数 //使用正则表达式,对请求的资源路径进行正则匹配,匹配成功就使用对应函数进行处理 // /numbers/(\d+) /numbers/12345 for (auto &handler : handlers) { const std::regex &re = handler.first; const Handler &functor = handler.second; bool ret = std::regex_match(req._path, req._matches, re); if (ret == false) { continue; } return functor(req, rsp);//传入请求信息,和空的rsp,执行处理函数 } rsp->_statu = 404; } void Route(HttpRequest &req, HttpResponse *rsp) { //1. 
对请求进行分辨,是一个静态资源请求,还是一个功能性请求 // 静态资源请求,则进行静态资源的处理 // 功能性请求,则需要通过几个请求路由表来确定是否有处理函数 // 既不是静态资源请求,也没有设置对应的功能性请求处理函数,就返回405 if (IsFileHandler(req) == true) { //是一个静态资源请求, 则进行静态资源请求的处理 return FileHandler(req, rsp); } if (req._method == "GET" || req._method == "HEAD") { return Dispatcher(req, rsp, _get_route); }else if (req._method == "POST") { return Dispatcher(req, rsp, _post_route); }else if (req._method == "PUT") { return Dispatcher(req, rsp, _put_route); }else if (req._method == "DELETE") { return Dispatcher(req, rsp, _delete_route); } rsp->_statu = 405;// Method Not Allowed return ; } //设置上下文 void OnConnected(const PtrConnection &conn) { conn->SetContext(HttpContext()); DBG_LOG("NEW CONNECTION %p", conn.get()); } //缓冲区数据解析+处理 void OnMessage(const PtrConnection &conn, Buffer *buffer) { while(buffer->ReadAbleSize() > 0){ //1. 获取上下文 HttpContext *context = conn->GetContext()->get<HttpContext>(); //2. 通过上下文对缓冲区数据进行解析,得到HttpRequest对象 // 1. 如果缓冲区的数据解析出错,就直接回复出错响应 // 2. 如果解析正常,且请求已经获取完毕,才开始去进行处理 context->RecvHttpRequest(buffer); HttpRequest &req = context->Request(); HttpResponse rsp(context->RespStatu()); if (context->RespStatu() >= 400) { //进行错误响应,关闭连接 ErrorHandler(req, &rsp);//填充一个错误显示页面数据到rsp中 WriteReponse(conn, req, rsp);//组织响应发送给客户端 context->ReSet(); buffer->MoveReadOffset(buffer->ReadAbleSize());//出错了就把缓冲区数据清空 conn->Shutdown();//关闭连接 return; } if (context->RecvStatu() != RECV_HTTP_OVER) { //当前请求还没有接收完整,则退出,等新数据到来再重新继续处理 return; } //3. 请求路由 + 业务处理 Route(req, &rsp); //4. 对HttpResponse进行组织发送 WriteReponse(conn, req, rsp); //5. 重置上下文 context->ReSet(); //6. 
根据长短连接判断是否关闭连接或者继续处理 if (rsp.Close() == true) conn->Shutdown();//短链接则直接关闭 } return; } public: HttpServer(int port, int timeout = DEFALT_TIMEOUT):_server(port) { _server.EnableInactiveRelease(timeout); _server.SetConnectedCallback(std::bind(&HttpServer::OnConnected, this, std::placeholders::_1)); _server.SetMessageCallback(std::bind(&HttpServer::OnMessage, this, std::placeholders::_1, std::placeholders::_2)); } void SetBaseDir(const std::string &path) { assert(Util::IsDirectory(path) == true); _basedir = path; } /*设置/添加,请求(请求的正则表达)与处理函数的映射关系*/ void Get(const std::string &pattern, const Handler &handler) { _get_route.push_back(std::make_pair(std::regex(pattern), handler)); } void Post(const std::string &pattern, const Handler &handler) { _post_route.push_back(std::make_pair(std::regex(pattern), handler)); } void Put(const std::string &pattern, const Handler &handler) { _put_route.push_back(std::make_pair(std::regex(pattern), handler)); } void Delete(const std::string &pattern, const Handler &handler) { _delete_route.push_back(std::make_pair(std::regex(pattern), handler)); } void SetThreadCount(int count) { _server.SetThreadCount(count); } void Listen() { _server.Start(); } };
#ifndef __M_SERVER_H__ #define __M_SERVER_H__ #include <iostream> #include <vector> #include <string> #include <cassert> #include <cstring> #include <ctime> #include <functional> #include <unordered_map> #include <thread> #include <mutex> #include <condition_variable> #include <memory> #include <typeinfo> #include <fcntl.h> #include <signal.h> #include <unistd.h> #include <netinet/in.h> #include <arpa/inet.h> #include <sys/socket.h> #include <sys/epoll.h> #include <sys/eventfd.h> #include <sys/timerfd.h> #define INF 0 #define DBG 1 #define ERR 2 #define LOG_LEVEL DBG #define LOG(level, format, ...) do{\ if (level < LOG_LEVEL) break;\ time_t t = time(NULL);\ struct tm *ltm = localtime(&t);\ char tmp[32] = {0};\ strftime(tmp, 31, "%H:%M:%S", ltm);\ fprintf(stdout, "[%p %s %s:%d] " format "\n", (void*)pthread_self(), tmp, __FILE__, __LINE__, ##__VA_ARGS__);\ }while(0) #define INF_LOG(format, ...) LOG(INF, format, ##__VA_ARGS__) #define DBG_LOG(format, ...) LOG(DBG, format, ##__VA_ARGS__) #define ERR_LOG(format, ...) 
LOG(ERR, format, ##__VA_ARGS__) #define BUFFER_DEFAULT_SIZE 1024 class Buffer { private: std::vector<char> _buffer; //使用vector进行内存空间管理 uint64_t _reader_idx; //读偏移 uint64_t _writer_idx; //写偏移 public: Buffer():_reader_idx(0), _writer_idx(0), _buffer(BUFFER_DEFAULT_SIZE){} char *Begin() { return &*_buffer.begin(); } //获取当前写入起始地址, _buffer的空间起始地址,加上写偏移量 char *WritePosition() { return Begin() + _writer_idx; } //获取当前读取起始地址 char *ReadPosition() { return Begin() + _reader_idx; } //获取缓冲区末尾空闲空间大小--写偏移之后的空闲空间, 总体空间大小减去写偏移 uint64_t TailIdleSize() { return _buffer.size() - _writer_idx; } //获取缓冲区起始空闲空间大小--读偏移之前的空闲空间 uint64_t HeadIdleSize() { return _reader_idx; } //获取可读数据大小 = 写偏移 - 读偏移 uint64_t ReadAbleSize() { return _writer_idx - _reader_idx; } //将读偏移向后移动 void MoveReadOffset(uint64_t len) { if (len == 0) return; //向后移动的大小,必须小于可读数据大小 assert(len <= ReadAbleSize()); _reader_idx += len; } //将写偏移向后移动 void MoveWriteOffset(uint64_t len) { //向后移动的大小,必须小于当前后边的空闲空间大小 assert(len <= TailIdleSize()); _writer_idx += len; } //确保可写空间足够(整体空闲空间够了就移动数据,否则就扩容) void EnsureWriteSpace(uint64_t len) { //如果末尾空闲空间大小足够,直接返回 if (TailIdleSize() >= len) { return; } //末尾空闲空间不够,则判断加上起始位置的空闲空间大小是否足够, 够了就将数据移动到起始位置 if (len <= TailIdleSize() + HeadIdleSize()) { //将数据移动到起始位置 uint64_t rsz = ReadAbleSize();//把当前数据大小先保存起来 std::copy(ReadPosition(), ReadPosition() + rsz, Begin());//把可读数据拷贝到起始位置 _reader_idx = 0; //将读偏移归0 _writer_idx = rsz; //将写位置置为可读数据大小, 因为当前的可读数据大小就是写偏移量 }else { //总体空间不够,则需要扩容,不移动数据,直接给写偏移之后扩容足够空间即可 DBG_LOG("RESIZE %ld", _writer_idx + len); _buffer.resize(_writer_idx + len); } } //写入数据 void Write(const void *data, uint64_t len) { //1. 保证有足够空间,2. 
拷贝数据进去 if (len == 0) return; EnsureWriteSpace(len); const char *d = (const char *)data; std::copy(d, d + len, WritePosition()); } void WriteAndPush(const void *data, uint64_t len) { Write(data, len); MoveWriteOffset(len); } void WriteString(const std::string &data) { return Write(data.c_str(), data.size()); } void WriteStringAndPush(const std::string &data) { WriteString(data); MoveWriteOffset(data.size()); } void WriteBuffer(Buffer &data) { return Write(data.ReadPosition(), data.ReadAbleSize()); } void WriteBufferAndPush(Buffer &data) { WriteBuffer(data); MoveWriteOffset(data.ReadAbleSize()); } //读取数据 void Read(void *buf, uint64_t len) { //要求要获取的数据大小必须小于可读数据大小 assert(len <= ReadAbleSize()); std::copy(ReadPosition(), ReadPosition() + len, (char*)buf); } void ReadAndPop(void *buf, uint64_t len) { Read(buf, len); MoveReadOffset(len); } std::string ReadAsString(uint64_t len) { //要求要获取的数据大小必须小于可读数据大小 assert(len <= ReadAbleSize()); std::string str; str.resize(len); Read(&str[0], len); return str; } std::string ReadAsStringAndPop(uint64_t len) { assert(len <= ReadAbleSize()); std::string str = ReadAsString(len); MoveReadOffset(len); return str; } char *FindCRLF() { char *res = (char*)memchr(ReadPosition(), '\n', ReadAbleSize()); return res; } /*通常获取一行数据,这种情况针对是*/ std::string GetLine() { char *pos = FindCRLF(); if (pos == NULL) { return ""; } // +1是为了把换行字符也取出来。 return ReadAsString(pos - ReadPosition() + 1); } std::string GetLineAndPop() { std::string str = GetLine(); MoveReadOffset(str.size()); return str; } //清空缓冲区 void Clear() { //只需要将偏移量归0即可 _reader_idx = 0; _writer_idx = 0; } }; #define MAX_LISTEN 1024 class Socket { private: int _sockfd; public: Socket():_sockfd(-1) {} Socket(int fd): _sockfd(fd) {} ~Socket() { Close(); } int Fd() { return _sockfd; } //创建套接字 bool Create() { // int socket(int domain, int type, int protocol) _sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (_sockfd < 0) { ERR_LOG("CREATE SOCKET FAILED!!"); return false; } return true; } //绑定地址信息 
bool Bind(const std::string &ip, uint16_t port) { struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = htons(port); addr.sin_addr.s_addr = inet_addr(ip.c_str()); socklen_t len = sizeof(struct sockaddr_in); // int bind(int sockfd, struct sockaddr*addr, socklen_t len); int ret = bind(_sockfd, (struct sockaddr*)&addr, len); if (ret < 0) { ERR_LOG("BIND ADDRESS FAILED!"); return false; } return true; } //开始监听 bool Listen(int backlog = MAX_LISTEN) { // int listen(int backlog) int ret = listen(_sockfd, backlog); if (ret < 0) { ERR_LOG("SOCKET LISTEN FAILED!"); return false; } return true; } //向服务器发起连接 bool Connect(const std::string &ip, uint16_t port) { struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = htons(port); addr.sin_addr.s_addr = inet_addr(ip.c_str()); socklen_t len = sizeof(struct sockaddr_in); // int connect(int sockfd, struct sockaddr*addr, socklen_t len); int ret = connect(_sockfd, (struct sockaddr*)&addr, len); if (ret < 0) { ERR_LOG("CONNECT SERVER FAILED!"); return false; } return true; } //获取新连接 int Accept() { // int accept(int sockfd, struct sockaddr *addr, socklen_t *len); int newfd = accept(_sockfd, NULL, NULL); if (newfd < 0) { ERR_LOG("SOCKET ACCEPT FAILED!"); return -1; } return newfd; } //接收数据 ssize_t Recv(void *buf, size_t len, int flag = 0) { // ssize_t recv(int sockfd, void *buf, size_t len, int flag); ssize_t ret = recv(_sockfd, buf, len, flag); if (ret <= 0) { //EAGAIN 当前socket的接收缓冲区中没有数据了,在非阻塞的情况下才会有这个错误 //EINTR 表示当前socket的阻塞等待,被信号打断了, if (errno == EAGAIN || errno == EINTR) { return 0;//表示这次接收没有接收到数据 } ERR_LOG("SOCKET RECV FAILED!!"); return -1; } return ret; //实际接收的数据长度 } ssize_t NonBlockRecv(void *buf, size_t len) { return Recv(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT 表示当前接收为非阻塞。 } //发送数据 ssize_t Send(const void *buf, size_t len, int flag = 0) { // ssize_t send(int sockfd, void *data, size_t len, int flag); ssize_t ret = send(_sockfd, buf, len, flag); if (ret < 0) { if (errno == EAGAIN || errno == EINTR) { 
return 0; } ERR_LOG("SOCKET SEND FAILED!!"); return -1; } return ret;//实际发送的数据长度 } ssize_t NonBlockSend(void *buf, size_t len) { if (len == 0) return 0; return Send(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT 表示当前发送为非阻塞。 } //关闭套接字 void Close() { if (_sockfd != -1) { close(_sockfd); _sockfd = -1; } } //创建一个服务端连接 bool CreateServer(uint16_t port, const std::string &ip = "0.0.0.0", bool block_flag = false) { //1. 创建套接字,2. 绑定地址,3. 开始监听,4. 设置非阻塞, 5. 启动地址重用 if (Create() == false) return false; if (block_flag) NonBlock(); if (Bind(ip, port) == false) return false; if (Listen() == false) return false; ReuseAddress(); return true; } //创建一个客户端连接 bool CreateClient(uint16_t port, const std::string &ip) { //1. 创建套接字,2.指向连接服务器 if (Create() == false) return false; if (Connect(ip, port) == false) return false; return true; } //设置套接字选项---开启地址端口重用 void ReuseAddress() { // int setsockopt(int fd, int leve, int optname, void *val, int vallen) int val = 1; setsockopt(_sockfd, SOL_SOCKET, SO_REUSEADDR, (void*)&val, sizeof(int)); val = 1; setsockopt(_sockfd, SOL_SOCKET, SO_REUSEPORT, (void*)&val, sizeof(int)); } //设置套接字阻塞属性-- 设置为非阻塞 void NonBlock() { //int fcntl(int fd, int cmd, ... 
/* arg */ ); int flag = fcntl(_sockfd, F_GETFL, 0); fcntl(_sockfd, F_SETFL, flag | O_NONBLOCK); } }; class Poller; class EventLoop; class Channel { private: int _fd; EventLoop *_loop; uint32_t _events; // 当前需要监控的事件 uint32_t _revents; // 当前连接触发的事件 using EventCallback = std::function<void()>; EventCallback _read_callback; //可读事件被触发的回调函数 EventCallback _write_callback; //可写事件被触发的回调函数 EventCallback _error_callback; //错误事件被触发的回调函数 EventCallback _close_callback; //连接断开事件被触发的回调函数 EventCallback _event_callback; //任意事件被触发的回调函数 public: Channel(EventLoop *loop, int fd):_fd(fd), _events(0), _revents(0), _loop(loop) {} int Fd() { return _fd; } uint32_t Events() { return _events; }//获取想要监控的事件 void SetREvents(uint32_t events) { _revents = events; }//设置实际就绪的事件 void SetReadCallback(const EventCallback &cb) { _read_callback = cb; } void SetWriteCallback(const EventCallback &cb) { _write_callback = cb; } void SetErrorCallback(const EventCallback &cb) { _error_callback = cb; } void SetCloseCallback(const EventCallback &cb) { _close_callback = cb; } void SetEventCallback(const EventCallback &cb) { _event_callback = cb; } //当前是否监控了可读 bool ReadAble() { return (_events & EPOLLIN); } //当前是否监控了可写 bool WriteAble() { return (_events & EPOLLOUT); } //启动读事件监控 void EnableRead() { _events |= EPOLLIN; Update(); } //启动写事件监控 void EnableWrite() { _events |= EPOLLOUT; Update(); } //关闭读事件监控 void DisableRead() { _events &= ~EPOLLIN; Update(); } //关闭写事件监控 void DisableWrite() { _events &= ~EPOLLOUT; Update(); } //关闭所有事件监控 void DisableAll() { _events = 0; Update(); } //移除监控 void Remove(); void Update(); //事件处理,一旦连接触发了事件,就调用这个函数,自己触发了什么事件如何处理自己决定 void HandleEvent() { if ((_revents & EPOLLIN) || (_revents & EPOLLRDHUP) || (_revents & EPOLLPRI)) { /*不管任何事件,都调用的回调函数*/ if (_read_callback) _read_callback(); } /*有可能会释放连接的操作事件,一次只处理一个*/ if (_revents & EPOLLOUT) { if (_write_callback) _write_callback(); }else if (_revents & EPOLLERR) { if (_error_callback) _error_callback();//一旦出错,就会释放连接,因此要放到前边调用任意回调 }else if 
(_revents & EPOLLHUP) { if (_close_callback) _close_callback(); } if (_event_callback) _event_callback(); } }; #define MAX_EPOLLEVENTS 1024 class Poller { private: int _epfd; struct epoll_event _evs[MAX_EPOLLEVENTS]; std::unordered_map<int, Channel *> _channels; private: //对epoll的直接操作 void Update(Channel *channel, int op) { // int epoll_ctl(int epfd, int op, int fd, struct epoll_event *ev); int fd = channel->Fd(); struct epoll_event ev; ev.data.fd = fd; ev.events = channel->Events(); int ret = epoll_ctl(_epfd, op, fd, &ev); if (ret < 0) { ERR_LOG("EPOLLCTL FAILED!"); } return; } //判断一个Channel是否已经添加了事件监控 bool HasChannel(Channel *channel) { auto it = _channels.find(channel->Fd()); if (it == _channels.end()) { return false; } return true; } public: Poller() { _epfd = epoll_create(MAX_EPOLLEVENTS); if (_epfd < 0) { ERR_LOG("EPOLL CREATE FAILED!!"); abort();//退出程序 } } //添加或修改监控事件 void UpdateEvent(Channel *channel) { bool ret = HasChannel(channel); if (ret == false) { //不存在则添加 _channels.insert(std::make_pair(channel->Fd(), channel)); return Update(channel, EPOLL_CTL_ADD); } return Update(channel, EPOLL_CTL_MOD); } //移除监控 void RemoveEvent(Channel *channel) { auto it = _channels.find(channel->Fd()); if (it != _channels.end()) { _channels.erase(it); } Update(channel, EPOLL_CTL_DEL); } //开始监控,返回活跃连接 void Poll(std::vector<Channel*> *active) { // int epoll_wait(int epfd, struct epoll_event *evs, int maxevents, int timeout) int nfds = epoll_wait(_epfd, _evs, MAX_EPOLLEVENTS, -1); if (nfds < 0) { if (errno == EINTR) { return ; } ERR_LOG("EPOLL WAIT ERROR:%s\n", strerror(errno)); abort();//退出程序 } for (int i = 0; i < nfds; i++) { auto it = _channels.find(_evs[i].data.fd); assert(it != _channels.end()); it->second->SetREvents(_evs[i].events);//设置实际就绪的事件 active->push_back(it->second); } return; } }; using TaskFunc = std::function<void()>; using ReleaseFunc = std::function<void()>; class TimerTask{ private: uint64_t _id; // 定时器任务对象ID uint32_t _timeout; //定时任务的超时时间 bool _canceled; // 
false-表示没有被取消, true-表示被取消 TaskFunc _task_cb; //定时器对象要执行的定时任务 ReleaseFunc _release; //用于删除TimerWheel中保存的定时器对象信息 public: TimerTask(uint64_t id, uint32_t delay, const TaskFunc &cb): _id(id), _timeout(delay), _task_cb(cb), _canceled(false) {} ~TimerTask() { if (_canceled == false) _task_cb(); _release(); } void Cancel() { _canceled = true; } void SetRelease(const ReleaseFunc &cb) { _release = cb; } uint32_t DelayTime() { return _timeout; } }; class TimerWheel { private: using WeakTask = std::weak_ptr<TimerTask>; using PtrTask = std::shared_ptr<TimerTask>; int _tick; //当前的秒针,走到哪里释放哪里,释放哪里,就相当于执行哪里的任务 int _capacity; //表盘最大数量---其实就是最大延迟时间 std::vector<std::vector<PtrTask>> _wheel; std::unordered_map<uint64_t, WeakTask> _timers; EventLoop *_loop; int _timerfd;//定时器描述符--可读事件回调就是读取计数器,执行定时任务 std::unique_ptr<Channel> _timer_channel; private: void RemoveTimer(uint64_t id) { auto it = _timers.find(id); if (it != _timers.end()) { _timers.erase(it); } } static int CreateTimerfd() { int timerfd = timerfd_create(CLOCK_MONOTONIC, 0); if (timerfd < 0) { ERR_LOG("TIMERFD CREATE FAILED!"); abort(); } //int timerfd_settime(int fd, int flags, struct itimerspec *new, struct itimerspec *old); struct itimerspec itime; itime.it_value.tv_sec = 1; itime.it_value.tv_nsec = 0;//第一次超时时间为1s后 itime.it_interval.tv_sec = 1; itime.it_interval.tv_nsec = 0; //第一次超时后,每次超时的间隔时 timerfd_settime(timerfd, 0, &itime, NULL); return timerfd; } int ReadTimefd() { uint64_t times; //有可能因为其他描述符的事件处理花费事件比较长,然后在处理定时器描述符事件的时候,有可能就已经超时了很多次 //read读取到的数据times就是从上一次read之后超时的次数 int ret = read(_timerfd, ×, 8); if (ret < 0) { ERR_LOG("READ TIMEFD FAILED!"); abort(); } return times; } //这个函数应该每秒钟被执行一次,相当于秒针向后走了一步 void RunTimerTask() { _tick = (_tick + 1) % _capacity; _wheel[_tick].clear();//清空指定位置的数组,就会把数组中保存的所有管理定时器对象的shared_ptr释放掉 } void OnTime() { //根据实际超时的次数,执行对应的超时任务 int times = ReadTimefd(); for (int i = 0; i < times; i++) { RunTimerTask(); } } void TimerAddInLoop(uint64_t id, uint32_t delay, const TaskFunc &cb) { PtrTask 
pt(new TimerTask(id, delay, cb)); pt->SetRelease(std::bind(&TimerWheel::RemoveTimer, this, id)); int pos = (_tick + delay) % _capacity; _wheel[pos].push_back(pt); _timers[id] = WeakTask(pt); } void TimerRefreshInLoop(uint64_t id) { //通过保存的定时器对象的weak_ptr构造一个shared_ptr出来,添加到轮子中 auto it = _timers.find(id); if (it == _timers.end()) { return;//没找着定时任务,没法刷新,没法延迟 } PtrTask pt = it->second.lock();//lock获取weak_ptr管理的对象对应的shared_ptr int delay = pt->DelayTime(); int pos = (_tick + delay) % _capacity; _wheel[pos].push_back(pt); } void TimerCancelInLoop(uint64_t id) { auto it = _timers.find(id); if (it == _timers.end()) { return;//没找着定时任务,没法刷新,没法延迟 } PtrTask pt = it->second.lock(); if (pt) pt->Cancel(); } public: TimerWheel(EventLoop *loop):_capacity(60), _tick(0), _wheel(_capacity), _loop(loop), _timerfd(CreateTimerfd()), _timer_channel(new Channel(_loop, _timerfd)) { _timer_channel->SetReadCallback(std::bind(&TimerWheel::OnTime, this)); _timer_channel->EnableRead();//启动读事件监控 } /*定时器中有个_timers成员,定时器信息的操作有可能在多线程中进行,因此需要考虑线程安全问题*/ /*如果不想加锁,那就把对定期的所有操作,都放到一个线程中进行*/ void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb); //刷新/延迟定时任务 void TimerRefresh(uint64_t id); void TimerCancel(uint64_t id); /*这个接口存在线程安全问题--这个接口实际上不能被外界使用者调用,只能在模块内,在对应的EventLoop线程内执行*/ bool HasTimer(uint64_t id) { auto it = _timers.find(id); if (it == _timers.end()) { return false; } return true; } }; class EventLoop { private: using Functor = std::function<void()>; std::thread::id _thread_id;//线程ID int _event_fd;//eventfd唤醒IO事件监控有可能导致的阻塞 std::unique_ptr<Channel> _event_channel; Poller _poller;//进行所有描述符的事件监控 std::vector<Functor> _tasks;//任务池 std::mutex _mutex;//实现任务池操作的线程安全 TimerWheel _timer_wheel;//定时器模块 public: //执行任务池中的所有任务 void RunAllTask() { std::vector<Functor> functor; { std::unique_lock<std::mutex> _lock(_mutex); _tasks.swap(functor); } for (auto &f : functor) { f(); } return ; } static int CreateEventFd() { int efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); if (efd < 0) { ERR_LOG("CREATE EVENTFD 
FAILED!!"); abort();//让程序异常退出 } return efd; } void ReadEventfd() { uint64_t res = 0; int ret = read(_event_fd, &res, sizeof(res)); if (ret < 0) { //EINTR -- 被信号打断; EAGAIN -- 表示无数据可读 if (errno == EINTR || errno == EAGAIN) { return; } ERR_LOG("READ EVENTFD FAILED!"); abort(); } return ; } void WeakUpEventFd() { uint64_t val = 1; int ret = write(_event_fd, &val, sizeof(val)); if (ret < 0) { if (errno == EINTR) { return; } ERR_LOG("READ EVENTFD FAILED!"); abort(); } return ; } public: EventLoop():_thread_id(std::this_thread::get_id()), _event_fd(CreateEventFd()), _event_channel(new Channel(this, _event_fd)), _timer_wheel(this) { //给eventfd添加可读事件回调函数,读取eventfd事件通知次数 _event_channel->SetReadCallback(std::bind(&EventLoop::ReadEventfd, this)); //启动eventfd的读事件监控 _event_channel->EnableRead(); } //三步走--事件监控-》就绪事件处理-》执行任务 void Start() { while(1) { //1. 事件监控, std::vector<Channel *> actives; _poller.Poll(&actives); //2. 事件处理。 for (auto &channel : actives) { channel->HandleEvent(); } //3. 执行任务 RunAllTask(); } } //用于判断当前线程是否是EventLoop对应的线程; bool IsInLoop() { return (_thread_id == std::this_thread::get_id()); } void AssertInLoop() { assert(_thread_id == std::this_thread::get_id()); } //判断将要执行的任务是否处于当前线程中,如果是则执行,不是则压入队列。 void RunInLoop(const Functor &cb) { if (IsInLoop()) { return cb(); } return QueueInLoop(cb); } //将操作压入任务池 void QueueInLoop(const Functor &cb) { { std::unique_lock<std::mutex> _lock(_mutex); _tasks.push_back(cb); } //唤醒有可能因为没有事件就绪,而导致的epoll阻塞; //其实就是给eventfd写入一个数据,eventfd就会触发可读事件 WeakUpEventFd(); } //添加/修改描述符的事件监控 void UpdateEvent(Channel *channel) { return _poller.UpdateEvent(channel); } //移除描述符的监控 void RemoveEvent(Channel *channel) { return _poller.RemoveEvent(channel); } void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb) { return _timer_wheel.TimerAdd(id, delay, cb); } void TimerRefresh(uint64_t id) { return _timer_wheel.TimerRefresh(id); } void TimerCancel(uint64_t id) { return _timer_wheel.TimerCancel(id); } bool HasTimer(uint64_t id) { return 
_timer_wheel.HasTimer(id); } }; class LoopThread { private: /*用于实现_loop获取的同步关系,避免线程创建了,但是_loop还没有实例化之前去获取_loop*/ std::mutex _mutex; // 互斥锁 std::condition_variable _cond; // 条件变量 EventLoop *_loop; // EventLoop指针变量,这个对象需要在线程内实例化 std::thread _thread; // EventLoop对应的线程 private: /*实例化 EventLoop 对象,唤醒_cond上有可能阻塞的线程,并且开始运行EventLoop模块的功能*/ void ThreadEntry() { EventLoop loop; { std::unique_lock<std::mutex> lock(_mutex);//加锁 _loop = &loop; _cond.notify_all(); } loop.Start(); } public: /*创建线程,设定线程入口函数*/ LoopThread():_loop(NULL), _thread(std::thread(&LoopThread::ThreadEntry, this)) {} /*返回当前线程关联的EventLoop对象指针*/ EventLoop *GetLoop() { EventLoop *loop = NULL; { std::unique_lock<std::mutex> lock(_mutex);//加锁 _cond.wait(lock, [&](){ return _loop != NULL; });//loop为NULL就一直阻塞 loop = _loop; } return loop; } }; class LoopThreadPool { private: int _thread_count; int _next_idx; EventLoop *_baseloop; std::vector<LoopThread*> _threads; std::vector<EventLoop *> _loops; public: LoopThreadPool(EventLoop *baseloop):_thread_count(0), _next_idx(0), _baseloop(baseloop) {} void SetThreadCount(int count) { _thread_count = count; } void Create() { if (_thread_count > 0) { _threads.resize(_thread_count); _loops.resize(_thread_count); for (int i = 0; i < _thread_count; i++) { _threads[i] = new LoopThread(); _loops[i] = _threads[i]->GetLoop(); } } return ; } EventLoop *NextLoop() { if (_thread_count == 0) { return _baseloop; } _next_idx = (_next_idx + 1) % _thread_count; return _loops[_next_idx]; } }; class Any{ private: class holder { public: virtual ~holder() {} virtual const std::type_info& type() = 0; virtual holder *clone() = 0; }; template<class T> class placeholder: public holder { public: placeholder(const T &val): _val(val) {} // 获取子类对象保存的数据类型 virtual const std::type_info& type() { return typeid(T); } // 针对当前的对象自身,克隆出一个新的子类对象 virtual holder *clone() { return new placeholder(_val); } public: T _val; }; holder *_content; public: Any():_content(NULL) {} template<class T> Any(const T 
&val):_content(new placeholder<T>(val)) {} Any(const Any &other):_content(other._content ? other._content->clone() : NULL) {} ~Any() { delete _content; } Any &swap(Any &other) { std::swap(_content, other._content); return *this; } // 返回子类对象保存的数据的指针 template<class T> T *get() { //想要获取的数据类型,必须和保存的数据类型一致 assert(typeid(T) == _content->type()); return &((placeholder<T>*)_content)->_val; } //赋值运算符的重载函数 template<class T> Any& operator=(const T &val) { //为val构造一个临时的通用容器,然后与当前容器自身进行指针交换,临时对象释放的时候,原先保存的数据也就被释放 Any(val).swap(*this); return *this; } Any& operator=(const Any &other) { Any(other).swap(*this); return *this; } }; class Connection; //DISCONECTED -- 连接关闭状态; CONNECTING -- 连接建立成功-待处理状态 //CONNECTED -- 连接建立完成,各种设置已完成,可以通信的状态; DISCONNECTING -- 待关闭状态 typedef enum { DISCONNECTED, CONNECTING, CONNECTED, DISCONNECTING}ConnStatu; using PtrConnection = std::shared_ptr<Connection>; class Connection : public std::enable_shared_from_this<Connection> { private: uint64_t _conn_id; // 连接的唯一ID,便于连接的管理和查找 //uint64_t _timer_id; //定时器ID,必须是唯一的,这块为了简化操作使用conn_id作为定时器ID int _sockfd; // 连接关联的文件描述符 bool _enable_inactive_release; // 连接是否启动非活跃销毁的判断标志,默认为false EventLoop *_loop; // 连接所关联的一个EventLoop ConnStatu _statu; // 连接状态 Socket _socket; // 套接字操作管理 Channel _channel; // 连接的事件管理 Buffer _in_buffer; // 输入缓冲区---存放从socket中读取到的数据 Buffer _out_buffer; // 输出缓冲区---存放要发送给对端的数据 Any _context; // 请求的接收处理上下文 /*这四个回调函数,是让服务器模块来设置的(其实服务器模块的处理回调也是组件使用者设置的)*/ /*换句话说,这几个回调都是组件使用者使用的*/ using ConnectedCallback = std::function<void(const PtrConnection&)>; using MessageCallback = std::function<void(const PtrConnection&, Buffer *)>; using ClosedCallback = std::function<void(const PtrConnection&)>; using AnyEventCallback = std::function<void(const PtrConnection&)>; ConnectedCallback _connected_callback; MessageCallback _message_callback; ClosedCallback _closed_callback; AnyEventCallback _event_callback; /*组件内的连接关闭回调--组件内设置的,因为服务器组件内会把所有的连接管理起来,一旦某个连接要关闭*/ /*就应该从管理的地方移除掉自己的信息*/ ClosedCallback _server_closed_callback; 
private: /*五个channel的事件回调函数*/ //描述符可读事件触发后调用的函数,接收socket数据放到接收缓冲区中,然后调用_message_callback void HandleRead() { //1. 接收socket的数据,放到缓冲区 char buf[65536]; ssize_t ret = _socket.NonBlockRecv(buf, 65535); if (ret < 0) { //出错了,不能直接关闭连接 return ShutdownInLoop(); } //这里的等于0表示的是没有读取到数据,而并不是连接断开了,连接断开返回的是-1 //将数据放入输入缓冲区,写入之后顺便将写偏移向后移动 _in_buffer.WriteAndPush(buf, ret); //2. 调用message_callback进行业务处理 if (_in_buffer.ReadAbleSize() > 0) { //shared_from_this--从当前对象自身获取自身的shared_ptr管理对象 return _message_callback(shared_from_this(), &_in_buffer); } } //描述符可写事件触发后调用的函数,将发送缓冲区中的数据进行发送 void HandleWrite() { //_out_buffer中保存的数据就是要发送的数据 ssize_t ret = _socket.NonBlockSend(_out_buffer.ReadPosition(), _out_buffer.ReadAbleSize()); if (ret < 0) { //发送错误就该关闭连接了, if (_in_buffer.ReadAbleSize() > 0) { _message_callback(shared_from_this(), &_in_buffer); } return Release();//这时候就是实际的关闭释放操作了。 } _out_buffer.MoveReadOffset(ret);//千万不要忘了,将读偏移向后移动 if (_out_buffer.ReadAbleSize() == 0) { _channel.DisableWrite();// 没有数据待发送了,关闭写事件监控 //如果当前是连接待关闭状态,则有数据,发送完数据释放连接,没有数据则直接释放 if (_statu == DISCONNECTING) { return Release(); } } return; } //描述符触发挂断事件 void HandleClose() { /*一旦连接挂断了,套接字就什么都干不了了,因此有数据待处理就处理一下,完毕关闭连接*/ if (_in_buffer.ReadAbleSize() > 0) { _message_callback(shared_from_this(), &_in_buffer); } return Release(); } //描述符触发出错事件 void HandleError() { return HandleClose(); } //描述符触发任意事件: 1. 刷新连接的活跃度--延迟定时销毁任务; 2. 调用组件使用者的任意事件回调 void HandleEvent() { if (_enable_inactive_release == true) { _loop->TimerRefresh(_conn_id); } if (_event_callback) { _event_callback(shared_from_this()); } } //连接获取之后,所处的状态下要进行各种设置(启动读监控,调用回调函数) void EstablishedInLoop() { // 1. 修改连接状态; 2. 启动读事件监控; 3. 调用回调函数 assert(_statu == CONNECTING);//当前的状态必须一定是上层的半连接状态 _statu = CONNECTED;//当前函数执行完毕,则连接进入已完成连接状态 // 一旦启动读事件监控就有可能会立即触发读事件,如果这时候启动了非活跃连接销毁 _channel.EnableRead(); if (_connected_callback) _connected_callback(shared_from_this()); } //这个接口才是实际的释放接口 void ReleaseInLoop() { //1. 修改连接状态,将其置为DISCONNECTED _statu = DISCONNECTED; //2. 
移除连接的事件监控 _channel.Remove(); //3. 关闭描述符 _socket.Close(); //4. 如果当前定时器队列中还有定时销毁任务,则取消任务 if (_loop->HasTimer(_conn_id)) CancelInactiveReleaseInLoop(); //5. 调用关闭回调函数,避免先移除服务器管理的连接信息导致Connection被释放,再去处理会出错,因此先调用用户的回调函数 if (_closed_callback) _closed_callback(shared_from_this()); //移除服务器内部管理的连接信息 if (_server_closed_callback) _server_closed_callback(shared_from_this()); } //这个接口并不是实际的发送接口,而只是把数据放到了发送缓冲区,启动了可写事件监控 void SendInLoop(Buffer &buf) { if (_statu == DISCONNECTED) return ; _out_buffer.WriteBufferAndPush(buf); if (_channel.WriteAble() == false) { _channel.EnableWrite(); } } //这个关闭操作并非实际的连接释放操作,需要判断还有没有数据待处理,待发送 void ShutdownInLoop() { _statu = DISCONNECTING;// 设置连接为半关闭状态 if (_in_buffer.ReadAbleSize() > 0) { if (_message_callback) _message_callback(shared_from_this(), &_in_buffer); } //要么就是写入数据的时候出错关闭,要么就是没有待发送数据,直接关闭 if (_out_buffer.ReadAbleSize() > 0) { if (_channel.WriteAble() == false) { _channel.EnableWrite(); } } if (_out_buffer.ReadAbleSize() == 0) { Release(); } } //启动非活跃连接超时释放规则 void EnableInactiveReleaseInLoop(int sec) { //1. 将判断标志 _enable_inactive_release 置为true _enable_inactive_release = true; //2. 如果当前定时销毁任务已经存在,那就刷新延迟一下即可 if (_loop->HasTimer(_conn_id)) { return _loop->TimerRefresh(_conn_id); } //3. 
如果不存在定时销毁任务,则新增 _loop->TimerAdd(_conn_id, sec, std::bind(&Connection::Release, this)); } void CancelInactiveReleaseInLoop() { _enable_inactive_release = false; if (_loop->HasTimer(_conn_id)) { _loop->TimerCancel(_conn_id); } } void UpgradeInLoop(const Any &context, const ConnectedCallback &conn, const MessageCallback &msg, const ClosedCallback &closed, const AnyEventCallback &event) { _context = context; _connected_callback = conn; _message_callback = msg; _closed_callback = closed; _event_callback = event; } public: Connection(EventLoop *loop, uint64_t conn_id, int sockfd):_conn_id(conn_id), _sockfd(sockfd), _enable_inactive_release(false), _loop(loop), _statu(CONNECTING), _socket(_sockfd), _channel(loop, _sockfd) { _channel.SetCloseCallback(std::bind(&Connection::HandleClose, this)); _channel.SetEventCallback(std::bind(&Connection::HandleEvent, this)); _channel.SetReadCallback(std::bind(&Connection::HandleRead, this)); _channel.SetWriteCallback(std::bind(&Connection::HandleWrite, this)); _channel.SetErrorCallback(std::bind(&Connection::HandleError, this)); } ~Connection() { DBG_LOG("RELEASE CONNECTION:%p", this); } //获取管理的文件描述符 int Fd() { return _sockfd; } //获取连接ID int Id() { return _conn_id; } //是否处于CONNECTED状态 bool Connected() { return (_statu == CONNECTED); } //设置上下文--连接建立完成时进行调用 void SetContext(const Any &context) { _context = context; } //获取上下文,返回的是指针 Any *GetContext() { return &_context; } void SetConnectedCallback(const ConnectedCallback&cb) { _connected_callback = cb; } void SetMessageCallback(const MessageCallback&cb) { _message_callback = cb; } void SetClosedCallback(const ClosedCallback&cb) { _closed_callback = cb; } void SetAnyEventCallback(const AnyEventCallback&cb) { _event_callback = cb; } void SetSrvClosedCallback(const ClosedCallback&cb) { _server_closed_callback = cb; } //连接建立就绪后,进行channel回调设置,启动读监控,调用_connected_callback void Established() { _loop->RunInLoop(std::bind(&Connection::EstablishedInLoop, this)); } //发送数据,将数据放到发送缓冲区,启动写事件监控 void 
Send(const char *data, size_t len) { //外界传入的data,可能是个临时的空间,我们现在只是把发送操作压入了任务池,有可能并没有被立即执行 //因此有可能执行的时候,data指向的空间有可能已经被释放了。 Buffer buf; buf.WriteAndPush(data, len); _loop->RunInLoop(std::bind(&Connection::SendInLoop, this, std::move(buf))); } //提供给组件使用者的关闭接口--并不实际关闭,需要判断有没有数据待处理 void Shutdown() { _loop->RunInLoop(std::bind(&Connection::ShutdownInLoop, this)); } void Release() { _loop->QueueInLoop(std::bind(&Connection::ReleaseInLoop, this)); } //启动非活跃销毁,并定义多长时间无通信就是非活跃,添加定时任务 void EnableInactiveRelease(int sec) { _loop->RunInLoop(std::bind(&Connection::EnableInactiveReleaseInLoop, this, sec)); } //取消非活跃销毁 void CancelInactiveRelease() { _loop->RunInLoop(std::bind(&Connection::CancelInactiveReleaseInLoop, this)); } //切换协议---重置上下文以及阶段性回调处理函数 -- 而是这个接口必须在EventLoop线程中立即执行 //防备新的事件触发后,处理的时候,切换任务还没有被执行--会导致数据使用原协议处理了。 void Upgrade(const Any &context, const ConnectedCallback &conn, const MessageCallback &msg, const ClosedCallback &closed, const AnyEventCallback &event) { _loop->AssertInLoop(); _loop->RunInLoop(std::bind(&Connection::UpgradeInLoop, this, context, conn, msg, closed, event)); } }; class Acceptor { private: Socket _socket;//用于创建监听套接字 EventLoop *_loop; //用于对监听套接字进行事件监控 Channel _channel; //用于对监听套接字进行事件管理 using AcceptCallback = std::function<void(int)>; AcceptCallback _accept_callback; private: /*监听套接字的读事件回调处理函数---获取新连接,调用_accept_callback函数进行新连接处理*/ void HandleRead() { int newfd = _socket.Accept(); if (newfd < 0) { return ; } if (_accept_callback) _accept_callback(newfd); } int CreateServer(int port) { bool ret = _socket.CreateServer(port); assert(ret == true); return _socket.Fd(); } public: /*不能将启动读事件监控,放到构造函数中,必须在设置回调函数后,再去启动*/ /*否则有可能造成启动监控后,立即有事件,处理的时候,回调函数还没设置:新连接得不到处理,且资源泄漏*/ Acceptor(EventLoop *loop, int port): _socket(CreateServer(port)), _loop(loop), _channel(loop, _socket.Fd()) { _channel.SetReadCallback(std::bind(&Acceptor::HandleRead, this)); } void SetAcceptCallback(const AcceptCallback &cb) { _accept_callback = cb; } void Listen() { 
_channel.EnableRead(); } }; class TcpServer { private: uint64_t _next_id; //这是一个自动增长的连接ID, int _port; int _timeout; //这是非活跃连接的统计时间---多长时间无通信就是非活跃连接 bool _enable_inactive_release;//是否启动了非活跃连接超时销毁的判断标志 EventLoop _baseloop; //这是主线程的EventLoop对象,负责监听事件的处理 Acceptor _acceptor; //这是监听套接字的管理对象 LoopThreadPool _pool; //这是从属EventLoop线程池 std::unordered_map<uint64_t, PtrConnection> _conns;//保存管理所有连接对应的shared_ptr对象 using ConnectedCallback = std::function<void(const PtrConnection&)>; using MessageCallback = std::function<void(const PtrConnection&, Buffer *)>; using ClosedCallback = std::function<void(const PtrConnection&)>; using AnyEventCallback = std::function<void(const PtrConnection&)>; using Functor = std::function<void()>; ConnectedCallback _connected_callback; MessageCallback _message_callback; ClosedCallback _closed_callback; AnyEventCallback _event_callback; private: void RunAfterInLoop(const Functor &task, int delay) { _next_id++; _baseloop.TimerAdd(_next_id, delay, task); } //为新连接构造一个Connection进行管理 void NewConnection(int fd) { _next_id++; PtrConnection conn(new Connection(_pool.NextLoop(), _next_id, fd)); conn->SetMessageCallback(_message_callback); conn->SetClosedCallback(_closed_callback); conn->SetConnectedCallback(_connected_callback); conn->SetAnyEventCallback(_event_callback); conn->SetSrvClosedCallback(std::bind(&TcpServer::RemoveConnection, this, std::placeholders::_1)); if (_enable_inactive_release) conn->EnableInactiveRelease(_timeout);//启动非活跃超时销毁 conn->Established();//就绪初始化 _conns.insert(std::make_pair(_next_id, conn)); } void RemoveConnectionInLoop(const PtrConnection &conn) { int id = conn->Id(); auto it = _conns.find(id); if (it != _conns.end()) { _conns.erase(it); } } //从管理Connection的_conns中移除连接信息 void RemoveConnection(const PtrConnection &conn) { _baseloop.RunInLoop(std::bind(&TcpServer::RemoveConnectionInLoop, this, conn)); } public: TcpServer(int port): _port(port), _next_id(0), _enable_inactive_release(false), _acceptor(&_baseloop, port), 
_pool(&_baseloop) { _acceptor.SetAcceptCallback(std::bind(&TcpServer::NewConnection, this, std::placeholders::_1)); _acceptor.Listen();//将监听套接字挂到baseloop上 } void SetThreadCount(int count) { return _pool.SetThreadCount(count); } void SetConnectedCallback(const ConnectedCallback&cb) { _connected_callback = cb; } void SetMessageCallback(const MessageCallback&cb) { _message_callback = cb; } void SetClosedCallback(const ClosedCallback&cb) { _closed_callback = cb; } void SetAnyEventCallback(const AnyEventCallback&cb) { _event_callback = cb; } void EnableInactiveRelease(int timeout) { _timeout = timeout; _enable_inactive_release = true; } //用于添加一个定时任务 void RunAfter(const Functor &task, int delay) { _baseloop.RunInLoop(std::bind(&TcpServer::RunAfterInLoop, this, task, delay)); } void Start() { _pool.Create(); _baseloop.Start(); } }; void Channel::Remove() { return _loop->RemoveEvent(this); } void Channel::Update() { return _loop->UpdateEvent(this); } void TimerWheel::TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb) { _loop->RunInLoop(std::bind(&TimerWheel::TimerAddInLoop, this, id, delay, cb)); } //刷新/延迟定时任务 void TimerWheel::TimerRefresh(uint64_t id) { _loop->RunInLoop(std::bind(&TimerWheel::TimerRefreshInLoop, this, id)); } void TimerWheel::TimerCancel(uint64_t id) { _loop->RunInLoop(std::bind(&TimerWheel::TimerCancelInLoop, this, id)); } class NetWork { public: NetWork() { DBG_LOG("SIGPIPE INIT"); signal(SIGPIPE, SIG_IGN); } }; static NetWork nw; #endif
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。