ctai/src/ctai_curl.cpp

239 lines
7.5 KiB
C++

#include "ctai_curl.h"

#include <exception>
#include <string>
/// Construct an idle client with no libcurl resources attached.
///
/// curl_init() creates the easy handle and header list later. Both pointers are
/// nulled here so the destructor's null checks are well-defined even when
/// curl_init() is never called. This is harmless if the header already gives
/// them default member initialisers.
ctai_curl::ctai_curl()
{
    m_curl = nullptr;
    m_headers = nullptr;
}
/// Release the libcurl easy handle and the HTTP header list owned by this object.
ctai_curl::~ctai_curl()
{
    if (m_curl != nullptr)
    {
        curl_easy_cleanup(m_curl);
        m_curl = nullptr; // optional; avoids leaving a dangling pointer behind

        // curl_init() always (re)sets m_headers when it creates m_curl, so the list is
        // valid here. It is freed only after the easy handle is gone because the handle
        // references it through CURLOPT_HTTPHEADER. curl_slist_free_all(nullptr) is a no-op.
        curl_slist_free_all(m_headers);
        m_headers = nullptr;
    }
}
// File-scope pointer to the ctai_curl that most recently ran curl_init().
// NOTE(review): nothing in this file reads it (the write callback uses
// call_back_context::instance instead), and it is never reset when the object
// is destroyed, so it can dangle.
static ctai_curl *instance = nullptr;
/// libcurl CURLOPT_WRITEFUNCTION callback for chat requests.
///
/// @param buffer   received bytes (not NUL-terminated)
/// @param sz       element size reported by libcurl
/// @param nmemb    element count reported by libcurl
/// @param userdata the call_back_context installed via CURLOPT_WRITEDATA
/// @return the number of bytes consumed; 0 aborts the transfer (CURLE_WRITE_ERROR)
///
/// In stream mode each chunk is handed to send_stream(), which emits signals.
/// Otherwise the bytes are accumulated in ctx->m_data.model_data for parsing
/// after curl_easy_perform() returns.
size_t curl_callback(void *buffer, size_t sz, size_t nmemb, void *userdata)
{
    auto ctx = static_cast<call_back_context *>(userdata);
    const size_t size = sz * nmemb;
    // This function is called from libcurl's C code: an exception escaping it would
    // unwind through C frames (undefined behaviour). Catch everything and abort the
    // transfer instead.
    try
    {
        std::lock_guard<std::mutex> lock(ctx->instance->m_mutex);
        if (ctx->m_data.steam_mode)
        {
            std::string chunk(static_cast<char *>(buffer), size);
            qDebug() << "info:" << chunk;
            ctx->instance->send_stream(ctx->m_data, chunk);
        }
        else
        {
            ctx->m_data.model_data.append(static_cast<char *>(buffer), size);
        }
    }
    catch (const std::exception &e)
    {
        qDebug() << "curl_callback: aborting transfer:" << e.what();
        return 0;
    }
    catch (...)
    {
        qDebug() << "curl_callback: aborting transfer: unknown exception";
        return 0;
    }
    return size;
}
/// Initialise libcurl and create the easy handle used for all requests.
///
/// @param _send_args connection settings (headers, URL, TLS, timeout) applied to
///                   the new handle via set_send_post_option().
///
/// If curl_easy_init() fails, m_error is set to CURLE_FAILED_INIT and no options are
/// applied. send_post_response() then does nothing because m_curl is null.
///
/// NOTE(review): curl_global_init() runs on every call and is never balanced by
/// curl_global_cleanup() in this file. Calling curl_init() twice also replaces
/// m_curl/m_headers without releasing the previous ones.
void ctai_curl::curl_init(send_data _send_args)
{
    curl_global_init(CURL_GLOBAL_ALL);
    m_curl = curl_easy_init();
    m_headers = nullptr;
    instance = this;
    if (m_curl == nullptr)
    {
        m_error = CURLE_FAILED_INIT;
        qDebug() << "curl_easy_init failed";
        return;
    }
    // Presumably a "no request performed yet" placeholder; overwritten by
    // send_post_response(). TODO confirm intent.
    m_error = CURLE_QUOTE_ERROR;
    set_send_post_option(_send_args);
}
/// Apply per-connection settings from @p m_args to the easy handle.
///
/// Adds the content-type, accept and API-key headers to m_headers, then installs the
/// URL, TLS settings, timeout and header list. It also forces HTTP/2 and routes
/// response bytes to curl_callback.
void ctai_curl::set_send_post_option(send_data m_args)
{
    const auto add_header = [this](const std::string &line) { set_send_post_headers(line); };
    add_header(m_args.content_header);
    add_header(m_args.accept_header);
    add_header(m_args.api_key);

    CURL *const handle = m_curl;
    curl_easy_setopt(handle, CURLOPT_URL, m_args.api_url.c_str());
    set_send_post_ssl(m_args.ssl_state);
    set_send_post_timeout(m_args.timeout);

    // The header list must be complete before it is handed to libcurl.
    curl_easy_setopt(handle, CURLOPT_HTTPHEADER, m_headers);
    curl_easy_setopt(handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_0);
    curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, curl_callback);
}
void ctai_curl::set_send_post_headers(std::string head_str)
{
m_headers = curl_slist_append(m_headers, head_str.c_str());
}
/// Enable or disable TLS certificate verification on the easy handle.
///
/// @param ssl_state true: verify the peer certificate against the bundle at
///                  <current working directory>/cert/cacert.pem and check the host name;
///                  false: skip all verification (insecure).
void ctai_curl::set_send_post_ssl(bool ssl_state)
{
    if (ssl_state)
    {
        // CURLOPT_SSL_VERIFYPEER is an on/off switch (1). CURLOPT_SSL_VERIFYHOST expects 2
        // to verify the certificate name; the value 1 is deprecated and was rejected by
        // some libcurl versions.
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYPEER, 1L);
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYHOST, 2L);
        // NOTE(review): relative to the process working directory, not the executable.
        const std::string ca_bundle = QDir::currentPath().toStdString() + "/cert/cacert.pem";
        // libcurl copies string options, so passing a local's c_str() is safe.
        curl_easy_setopt(m_curl, CURLOPT_CAINFO, ca_bundle.c_str());
    }
    else
    {
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYPEER, 0L);
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYHOST, 0L);
    }
}
/// Set the whole-transfer timeout in seconds (CURLOPT_TIMEOUT).
///
/// curl_easy_setopt() is variadic and reads this option as a `long`. Passing a plain
/// `int` is undefined where int and long differ in size (e.g. 64-bit Linux/macOS), so
/// the value is widened explicitly.
void ctai_curl::set_send_post_timeout(int time)
{
    curl_easy_setopt(m_curl, CURLOPT_TIMEOUT, static_cast<long>(time));
}
/// Turn automatic following of HTTP redirects on or off (CURLOPT_FOLLOWLOCATION).
void ctai_curl::set_send_post_followlocation(bool followlocation_state)
{
    const long follow = followlocation_state ? 1L : 0L;
    curl_easy_setopt(m_curl, CURLOPT_FOLLOWLOCATION, follow);
}
/// Turn libcurl's verbose diagnostic output on or off (CURLOPT_VERBOSE).
void ctai_curl::set_send_post_debug_verbose(bool verbose_state)
{
    const long verbose = verbose_state ? 1L : 0L;
    curl_easy_setopt(m_curl, CURLOPT_VERBOSE, verbose);
}
/// Enable or disable TCP keep-alive probing on the connection.
///
/// When enabled, probing starts after 60 s of idleness, repeats every 10 s, and gives
/// up after 3 unanswered probes.
void ctai_curl::set_send_post_tcp_keepalive(bool keepalive_state)
{
    if (!keepalive_state)
    {
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPALIVE, 0L);
        return;
    }
    constexpr long idle_seconds = 60L;     // idle time before the first probe
    constexpr long interval_seconds = 10L; // gap between probes
    constexpr long max_probes = 3L;        // probes before the connection is dropped
    curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPALIVE, 1L);
    curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPIDLE, idle_seconds);
    curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPINTVL, interval_seconds);
    curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPCNT, max_probes);
}
/// Serialise the chat request for @p _data to a JSON string.
///
/// The body has the fields model, messages (one user message carrying
/// _data.user_data), temperature (fixed at 0.7) and stream (_data.steam_mode).
std::string ctai_curl::send_request_body(send_data &_data)
{
    json user_message = {{"role", "user"}, {"content", _data.user_data}};
    json messages = json::array();
    messages.push_back(user_message);

    json request;
    request["model"] = _data.user_model;
    request["messages"] = messages;
    request["temperature"] = 0.7;
    request["stream"] = _data.steam_mode;
    return request.dump();
}
/// Extract the payload of every complete server-sent event in @p input.
///
/// Matches "data:<payload>" terminated by a blank line. Per the SSE spec, a single
/// space after the colon is not part of the payload. Without stripping it,
/// "data: [DONE]" would yield " [DONE]" and slip past the [DONE] check into the
/// JSON parser. CRLF line endings are accepted as well as LF. Incomplete events
/// (no terminating blank line yet) are not returned.
///
/// @return payloads in the order they appear in @p input
std::vector<std::string> ctai_curl::steam_extract(const std::string &input)
{
    // Compiled once: building a std::regex is costly and this runs for every chunk.
    static const std::regex pattern(R"(data: ?(.*?)\r?\n\r?\n)");

    std::vector<std::string> result;
    const std::sregex_iterator end;
    for (std::sregex_iterator it(input.begin(), input.end(), pattern); it != end; ++it)
    {
        result.push_back((*it).str(1));
    }
    return result;
}
std::string ctai_curl::send_timestamp_to_time(time_t timestamp)
{
// 将时间戳转换为 QDateTime 对象
QDateTime dateTime = QDateTime::fromSecsSinceEpoch(timestamp);
return dateTime.toString("yyyy-MM-dd hh:mm:ss").toStdString();
}
/// Handle one received chunk of a streaming (SSE) chat response.
///
/// For each complete event in @p response_data, except empty ones and "[DONE]", the
/// event JSON is parsed. data.model_data is set to choices[0].delta.content (empty
/// when the event carries no text) and data.time to the formatted "created"
/// timestamp. send_post_out_data(data) is then emitted once per event.
///
/// Runs inside the libcurl write callback, so it must not throw. JSON is parsed in
/// non-throwing mode and every field is type-checked before conversion.
void ctai_curl::send_stream(send_data &data, std::string response_data)
{
    const std::vector<std::string> events = steam_extract(response_data);
    for (const std::string &raw : events)
    {
        // Tolerate whitespace around the payload (e.g. "data: [DONE]").
        const std::size_t first = raw.find_first_not_of(" \t\r\n");
        if (first == std::string::npos)
        {
            continue;
        }
        const std::size_t last = raw.find_last_not_of(" \t\r\n");
        const std::string str = raw.substr(first, last - first + 1);
        if (str == "[DONE]")
        {
            continue;
        }
        qDebug() << "info:" << str;

        const json response = json::parse(str, nullptr, false);
        if (response.is_discarded() || !response.is_object())
        {
            qDebug() << "send_stream: skipping non-JSON event";
            continue;
        }

        // Each emitted event carries only its own delta. Otherwise a content-less chunk
        // (e.g. the final finish_reason chunk) would re-emit the previous token.
        data.model_data.clear();
        const auto choices = response.find("choices");
        if (choices != response.end() && choices->is_array() && !choices->empty())
        {
            const json &first_choice = choices->front();
            const auto delta = first_choice.find("delta");
            if (delta != first_choice.end())
            {
                const auto content = delta->find("content");
                // content may be null or absent; converting that to a string would throw.
                if (content != delta->end() && content->is_string())
                {
                    data.model_data = content->get<std::string>();
                }
            }
        }

        // "created" is a Unix timestamp; store it as a readable time string.
        const auto created = response.find("created");
        if (created != response.end() && created->is_number())
        {
            data.time = send_timestamp_to_time(created->get<time_t>());
        }
        emit send_post_out_data(data);
    }
}
/// Fill @p data from a complete (non-streaming) chat-completion response body.
///
/// Copies these fields when present with the expected JSON type:
/// - choices[0].message.content → model_data
/// - created → time (formatted)
/// - model → server_model
/// - usage.{prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens,
///   prompt_cache_miss_tokens} → the same-named fields
///
/// Fields that are missing or of the wrong type are left unchanged. A body that is
/// not a JSON object (e.g. an HTML error page) is logged and ignored instead of
/// throwing.
void ctai_curl::send_not_stream(send_data &data, std::string response_data)
{
    const json response = json::parse(response_data, nullptr, false);
    if (response.is_discarded() || !response.is_object())
    {
        qDebug() << "send_not_stream: response is not a JSON object:" << response_data.c_str();
        return;
    }

    // Reply text. content can be null (e.g. tool-call replies), so check its type
    // rather than letting the string conversion throw.
    const auto choices = response.find("choices");
    if (choices != response.end() && choices->is_array() && !choices->empty())
    {
        const json &first_choice = choices->front();
        const auto message = first_choice.find("message");
        if (message != first_choice.end())
        {
            const auto content = message->find("content");
            if (content != message->end() && content->is_string())
            {
                data.model_data = content->get<std::string>();
            }
        }
    }

    // "created" is a Unix timestamp; store it as a readable time string.
    const auto created = response.find("created");
    if (created != response.end() && created->is_number())
    {
        data.time = send_timestamp_to_time(created->get<time_t>());
    }

    // Name of the model that actually served the request.
    const auto model = response.find("model");
    if (model != response.end() && model->is_string())
    {
        data.server_model = *model;
    }

    // Token accounting; only numeric values are copied.
    const auto usage = response.find("usage");
    if (usage != response.end() && usage->is_object())
    {
        const json &u = *usage;
        const auto has_number = [&u](const char *key) {
            const auto it = u.find(key);
            return it != u.end() && it->is_number();
        };
        if (has_number("prompt_tokens"))
        {
            data.prompt_tokens = u.at("prompt_tokens");
        }
        if (has_number("completion_tokens"))
        {
            data.completion_tokens = u.at("completion_tokens");
        }
        if (has_number("total_tokens"))
        {
            data.total_tokens = u.at("total_tokens");
        }
        if (has_number("prompt_cache_hit_tokens"))
        {
            data.prompt_cache_hit_tokens = u.at("prompt_cache_hit_tokens");
        }
        if (has_number("prompt_cache_miss_tokens"))
        {
            // Previously written into prompt_cache_hit_tokens by mistake.
            data.prompt_cache_miss_tokens = u.at("prompt_cache_miss_tokens");
        }
    }
}
/// POST the chat request for @p data and deliver the reply.
///
/// Blocks in curl_easy_perform().
/// - Stream mode: results are emitted chunk by chunk from the write callback.
/// - Non-stream mode: the accumulated body is parsed into @p data after the transfer,
///   then send_post_out_data(data) is emitted once.
///
/// The libcurl result is stored in m_error. Does nothing if curl_init() did not
/// create a handle.
void ctai_curl::send_post_response(send_data &data)
{
    if (m_curl == nullptr)
    {
        return;
    }

    call_back_context ctx;
    ctx.m_data = data;
    ctx.instance = this;
    // ctx and the body must outlive curl_easy_perform(); both are locals of this call,
    // and perform() is synchronous.
    const std::string request_body_str = send_request_body(data);
    curl_easy_setopt(m_curl, CURLOPT_WRITEDATA, &ctx);
    curl_easy_setopt(m_curl, CURLOPT_POST, 1L);
    curl_easy_setopt(m_curl, CURLOPT_POSTFIELDS, request_body_str.c_str());
    // The size goes through varargs: CURLOPT_POSTFIELDSIZE reads a long, not a size_t.
    // The _LARGE variant with an explicit curl_off_t is unambiguous on every platform.
    curl_easy_setopt(m_curl, CURLOPT_POSTFIELDSIZE_LARGE,
                     static_cast<curl_off_t>(request_body_str.size()));

    m_error = curl_easy_perform(m_curl);
    if (m_error != CURLE_OK)
    {
        qDebug() << "failed" << m_error << curl_easy_strerror(m_error);
        return;
    }
    if (!data.steam_mode)
    {
        send_not_stream(data, ctx.m_data.model_data);
        emit send_post_out_data(data);
    }
}