#include "ctai_curl.h"

ctai_curl::ctai_curl()
{
    // Start with null handles so the destructor's checks are well-defined
    // even if curl_init() is never called.
    m_curl = nullptr;
    m_headers = nullptr;
}

ctai_curl::~ctai_curl()
{
    if (m_headers != nullptr)
    {
        // Free the header list built up by curl_slist_append().
        curl_slist_free_all(m_headers);
        m_headers = nullptr;
    }
    if (m_curl != nullptr)
    {
        curl_easy_cleanup(m_curl);
        m_curl = nullptr; // optional: avoid a dangling pointer
    }
}

// File-scope pointer to the most recently initialised instance (set in curl_init()).
static ctai_curl *instance;

size_t curl_callback(void *buffer, size_t sz, size_t nmemb, void *userdata)
{
    auto ctx = static_cast<call_back_context *>(userdata);
    std::unique_lock<std::mutex> lock(ctx->instance->m_mutex);
    size_t size = sz * nmemb;
    if (ctx->m_data.steam_mode)
    {
        // Streaming mode: forward each received chunk immediately.
        std::string chunk(static_cast<char *>(buffer), size);
        qDebug() << "info:" << QString::fromStdString(chunk);
        ctx->instance->send_stream(ctx->m_data, chunk);
    }
    else
    {
        // Non-streaming mode: accumulate the whole response body.
        ctx->m_data.model_data.append(static_cast<char *>(buffer), size);
    }
    lock.unlock();
    return size;
}

void ctai_curl::curl_init(send_data _send_args)
{
    curl_global_init(CURL_GLOBAL_ALL);
    m_curl = curl_easy_init();
    // Non-OK sentinel until the first request has actually run.
    m_error = CURLE_QUOTE_ERROR;
    m_headers = nullptr;
    set_send_post_option(_send_args);
    instance = this;
}

void ctai_curl::set_send_post_option(send_data m_args)
{
    set_send_post_headers(m_args.content_header);
    set_send_post_headers(m_args.accept_header);
    set_send_post_headers(m_args.api_key);
    curl_easy_setopt(m_curl, CURLOPT_URL, m_args.api_url.c_str());
    set_send_post_ssl(m_args.ssl_state);
    set_send_post_timeout(m_args.timeout);
    curl_easy_setopt(m_curl, CURLOPT_HTTPHEADER, m_headers);
    // CURLOPT_HTTP_VERSION expects a long.
    curl_easy_setopt(m_curl, CURLOPT_HTTP_VERSION, static_cast<long>(CURL_HTTP_VERSION_2_0));
    curl_easy_setopt(m_curl, CURLOPT_WRITEFUNCTION, curl_callback);
}

void ctai_curl::set_send_post_headers(std::string head_str)
{
    m_headers = curl_slist_append(m_headers, head_str.c_str());
}

void ctai_curl::set_send_post_ssl(bool ssl_state)
{
    std::string ca_bundle = QDir::currentPath().toStdString() + "/cert/cacert.pem";
    if (ssl_state)
    {
        // Verify the peer certificate (1L) and check that the certificate's
        // name matches the host (2L) against the bundled CA file.
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYPEER, 1L);
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYHOST, 2L);
        curl_easy_setopt(m_curl, CURLOPT_CAINFO, ca_bundle.c_str());
    }
    else
    {
        // Disable certificate verification entirely.
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYPEER, 0L);
        curl_easy_setopt(m_curl, CURLOPT_SSL_VERIFYHOST, 0L);
    }
}

void ctai_curl::set_send_post_timeout(int time)
{
    // CURLOPT_TIMEOUT expects a long (seconds).
    curl_easy_setopt(m_curl, CURLOPT_TIMEOUT, static_cast<long>(time));
}

void ctai_curl::set_send_post_followlocation(bool followlocation_state)
{
    curl_easy_setopt(m_curl, CURLOPT_FOLLOWLOCATION, followlocation_state ? 1L : 0L);
}

void ctai_curl::set_send_post_debug_verbose(bool verbose_state)
{
    curl_easy_setopt(m_curl, CURLOPT_VERBOSE, verbose_state ? 1L : 0L);
}

void ctai_curl::set_send_post_tcp_keepalive(bool keepalive_state)
{
    if (keepalive_state)
    {
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPALIVE, 1L);
        // Start probing after 60 seconds of idle time
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPIDLE, 60L);
        // Send a probe every 10 seconds
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPINTVL, 10L);
        // Give up after 3 failed probes
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPCNT, 3L);
    }
    else
    {
        curl_easy_setopt(m_curl, CURLOPT_TCP_KEEPALIVE, 0L);
    }
}

std::string ctai_curl::send_request_body(send_data &_data)
{
    // Build an OpenAI-style chat completions request body.
    json _request_body = {
        {"model", _data.user_model},
        {"messages", {{{"role", "user"}, {"content", _data.user_data}}}},
        {"temperature", 0.7},
        {"stream", _data.steam_mode}};
    return _request_body.dump();
}

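// For reference, the body produced by send_request_body() has roughly the
// following shape (field values are placeholders, and the key order of the
// actual dump() output may differ):
//
//   {
//     "model": "<user_model>",
//     "messages": [{"role": "user", "content": "<user_data>"}],
//     "temperature": 0.7,
//     "stream": false
//   }
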
std::vector<std::string> ctai_curl::steam_extract(const std::string &input)
{
    std::vector<std::string> result;
    // Regex capturing the payload between "data:" and the blank line that
    // terminates each server-sent event.
    std::regex pattern(R"(data:(.*?)\n\n)");
    auto words_begin = std::sregex_iterator(input.begin(), input.end(), pattern);
    auto words_end = std::sregex_iterator();

    // Walk over every match and collect the captured payloads.
    for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
        std::smatch match = *i;
        // Extract the first capture group (the event payload).
        result.push_back(match.str(1));
    }
    return result;
}

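// A minimal sketch of what steam_extract() does, assuming the server uses the
// usual "data: ..." SSE framing (the JSON shown is illustrative only):
//
//   input:  "data: {\"id\":\"1\"}\n\ndata: [DONE]\n\n"
//   output: {" {\"id\":\"1\"}", " [DONE]"}
//
// The captured payload keeps the space that follows "data:", so downstream
// parsing has to tolerate leading whitespace.
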
std::string ctai_curl::send_timestamp_to_time(time_t timestamp)
{
    // Convert the Unix timestamp into a QDateTime and format it.
    QDateTime dateTime = QDateTime::fromSecsSinceEpoch(timestamp);
    return dateTime.toString("yyyy-MM-dd hh:mm:ss").toStdString();
}

void ctai_curl::send_stream(send_data &data, std::string response_data)
{
    std::vector<std::string> _info_data = steam_extract(response_data);
    for (const auto &str : _info_data)
    {
        // Skip empty payloads and the terminating "[DONE]" marker (the
        // extracted payload may still carry the space that follows "data:").
        if (str.empty() || str.find("[DONE]") != std::string::npos)
        {
            continue;
        }
        qDebug() << "info:" << QString::fromStdString(str);
        // Parse without exceptions; malformed or partial chunks are skipped.
        json response = json::parse(str, nullptr, false);
        if (response.is_discarded())
        {
            continue;
        }
        if (response.contains("choices") && !response["choices"].empty())
        {
            // The first delta of a stream may carry only the role, so make
            // sure "content" is present before reading it.
            auto &delta = response["choices"][0]["delta"];
            if (delta.contains("content") && !delta["content"].is_null())
            {
                data.model_data = delta["content"];
            }
        }
        // Convert the "created" Unix timestamp to a readable time string
        if (response.contains("created") && !response["created"].empty())
        {
            data.time = send_timestamp_to_time(response["created"]);
        }
        emit send_post_out_data(data);
    }
}

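// send_stream() expects OpenAI-style streaming chunks; an illustrative payload
// (placeholder values) that it can consume looks like:
//
//   {"created": 1700000000,
//    "choices": [{"delta": {"content": "Hello"}, "index": 0}]}
//
// Each such chunk updates data.model_data / data.time and is emitted through
// send_post_out_data().
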
void ctai_curl::send_not_stream(send_data &data, std::string response_data)
{
    json response = json::parse(response_data);
    // Returned message content
    if (response.contains("choices") && !response["choices"].empty())
    {
        data.model_data = response["choices"][0]["message"]["content"];
    }
    // Convert the "created" Unix timestamp to a readable time string
    if (response.contains("created") && !response["created"].empty())
    {
        data.time = send_timestamp_to_time(response["created"]);
    }
    // Model that actually generated the response
    if (response.contains("model") && !response["model"].empty())
    {
        data.server_model = response["model"];
    }
    // Token usage statistics
    if (response.contains("usage") && !response["usage"].empty())
    {
        if (response["usage"].contains("prompt_tokens"))
        {
            data.prompt_tokens = response["usage"]["prompt_tokens"];
        }
        if (response["usage"].contains("completion_tokens"))
        {
            data.completion_tokens = response["usage"]["completion_tokens"];
        }
        if (response["usage"].contains("total_tokens"))
        {
            data.total_tokens = response["usage"]["total_tokens"];
        }
        if (response["usage"].contains("prompt_cache_hit_tokens"))
        {
            data.prompt_cache_hit_tokens = response["usage"]["prompt_cache_hit_tokens"];
        }
if (response["usage"].contains("prompt_cache_miss_tokens"))
|
|
{
|
|
data.prompt_cache_hit_tokens = response["usage"]["prompt_cache_miss_tokens"];
|
|
}
|
|
    }
}

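// send_not_stream() consumes a complete OpenAI-style chat completion; an
// illustrative response (placeholder values) covering the fields read above:
//
//   {"created": 1700000000,
//    "model": "<server_model>",
//    "choices": [{"message": {"role": "assistant", "content": "Hello"}}],
//    "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}
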
void ctai_curl::send_post_response(send_data &data)
{
    if (m_curl)
    {
        call_back_context ctx;
        ctx.m_data = data;
        ctx.instance = this;
        std::string request_body_str = send_request_body(data);
        curl_easy_setopt(m_curl, CURLOPT_WRITEDATA, &ctx);
        curl_easy_setopt(m_curl, CURLOPT_POST, 1L);
        curl_easy_setopt(m_curl, CURLOPT_POSTFIELDS, request_body_str.c_str());
        // CURLOPT_POSTFIELDSIZE expects a long.
        curl_easy_setopt(m_curl, CURLOPT_POSTFIELDSIZE, static_cast<long>(request_body_str.length()));
        m_error = curl_easy_perform(m_curl);
        if (m_error == CURLE_OK)
        {
            if (!data.steam_mode)
            {
                // The callback accumulated the full body into ctx.m_data.model_data.
                send_not_stream(data, ctx.m_data.model_data);
                emit send_post_out_data(data);
            }
        }
        else
        {
            qDebug() << "failed:" << curl_easy_strerror(m_error);
        }
    }
}
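
// A minimal usage sketch (not part of this file): send_data field names are
// taken from how they are used above; URL, key format and model name are
// placeholders, and the receiver of send_post_out_data is application-specific.
//
//   ctai_curl client;
//   send_data args;
//   args.api_url        = "https://api.example.com/v1/chat/completions";
//   args.api_key        = "Authorization: Bearer <key>";
//   args.content_header = "Content-Type: application/json";
//   args.accept_header  = "Accept: application/json";
//   args.ssl_state      = true;
//   args.timeout        = 60;
//   args.steam_mode     = false;
//   args.user_model     = "<model name>";
//   args.user_data      = "Hello";
//   client.curl_init(args);
//   // connect(&client, &ctai_curl::send_post_out_data, ...) to receive results
//   client.send_post_response(args);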