This repository has been archived by the owner on Jun 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
binding.cpp
100 lines (82 loc) · 2.62 KB
/
binding.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include "gpt4all-j/llmodel/llmodel_c.h"
#include "gpt4all-j/llmodel/llmodel.h"
#include "gpt4all-j/llmodel/llama.cpp/llama.h"
#include "gpt4all-j/llmodel/llamamodel.h"
#include "gpt4all-j/llmodel/gptj.h"
#include "binding.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <iostream>
#include <unistd.h>
void* binding_load_gptj_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_gptj_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr;
}
return gptj;
}
// Response text accumulated by the response callback in binding_model_prompt;
// that function resets it at the start of every call.
std::string res;
// Model pointer that the response callback hands to bindingTokenCallback.
// Kept at file scope because the callback lambdas have no captures.
void *mm = nullptr;
void binding_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase)
{
llmodel_model* model = (llmodel_model*) m;
// std::string res = "";
auto lambda_prompt = [](int token_id, const char *promptchars) {
return true;
};
mm=model;
res="";
auto lambda_response = [](int token_id, const char *responsechars) {
res.append((char*)responsechars);
return !!bindingTokenCallback(mm, (char*)responsechars);
};
auto lambda_recalculate = [](bool is_recalculating) {
// You can handle recalculation requests here if needed
return is_recalculating;
};
llmodel_prompt_context* prompt_context = new llmodel_prompt_context{
.logits = NULL,
.logits_size = 0,
.tokens = NULL,
.tokens_size = 0,
.n_past = 0,
.n_ctx = 1024,
.n_predict = 50,
.top_k = 10,
.top_p = 0.9,
.temp = 1.0,
.n_batch = 1,
.repeat_penalty = 1.2,
.repeat_last_n = 10,
.context_erase = 0.5
};
prompt_context->n_predict = tokens;
prompt_context->repeat_last_n = repeat_last_n;
prompt_context->repeat_penalty = repeat_penalty;
prompt_context->n_ctx = n_ctx;
prompt_context->top_k = top_k;
prompt_context->context_erase = ctx_erase;
prompt_context->top_p = top_p;
prompt_context->temp = temp;
prompt_context->n_batch = n_batch;
llmodel_prompt(model, prompt,
lambda_prompt,
lambda_response,
lambda_recalculate,
prompt_context );
strcpy(result, res.c_str());
free(prompt_context);
}
void binding_gptj_free_model(void *state_ptr) {
llmodel_model* ctx = (llmodel_model*) state_ptr;
llmodel_llama_destroy(ctx);
}