Commit 7137368

Author: Vaijanath Rao (committed)

adding re-ranking
1 parent a1a7474 commit 7137368

5 files changed, +163 -12 lines

pom.xml

Lines changed: 23 additions & 11 deletions
@@ -1,14 +1,16 @@
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <modelVersion>4.0.0</modelVersion>
 
     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>4.0.0</version>
+    <version>4.0.1</version>
     <packaging>jar</packaging>
 
     <name>${project.groupId}:${project.artifactId}</name>
-    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++.</description>
+    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model
+        in C/C++.</description>
     <url>https://github.com/kherud/java-llama.cpp</url>
 
     <licenses>
@@ -39,7 +41,8 @@
         </snapshotRepository>
         <repository>
             <id>ossrh</id>
-            <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+            <url>
+                https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
         </repository>
     </distributionManagement>
 
@@ -62,6 +65,7 @@
             <version>24.1.0</version>
             <scope>compile</scope>
         </dependency>
+
     </dependencies>
 
     <build>
@@ -71,17 +75,21 @@
                 <artifactId>maven-compiler-plugin</artifactId>
                 <version>3.13.0</version>
                 <executions>
-                    <!-- We have to perform a separate build pass for cuda classifier -->
+                    <!-- We have to perform a separate build pass for cuda
+                        classifier -->
                     <execution>
                         <id>gpu</id>
                         <phase>compile</phase>
-                        <goals><goal>compile</goal></goals>
+                        <goals>
+                            <goal>compile</goal>
+                        </goals>
                         <configuration>
                             <compilerArgs>
                                 <arg>-h</arg>
                                 <arg>src/main/cpp</arg>
                             </compilerArgs>
-                            <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                            <outputDirectory>
+                                ${project.build.outputDirectory}_cuda</outputDirectory>
                         </configuration>
                     </execution>
                 </executions>
@@ -98,10 +106,12 @@
                             <goal>copy-resources</goal>
                         </goals>
                         <configuration>
-                            <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                            <outputDirectory>
+                                ${project.build.outputDirectory}_cuda</outputDirectory>
                             <resources>
                                 <resource>
-                                    <directory>${basedir}/src/main/resources_linux_cuda/</directory>
+                                    <directory>
+                                        ${basedir}/src/main/resources_linux_cuda/</directory>
                                     <includes>
                                         <include>**/*.*</include>
                                     </includes>
@@ -176,7 +186,8 @@
                 <artifactId>maven-jar-plugin</artifactId>
                 <version>3.4.2</version>
                 <executions>
-                    <!-- Pick class files AND libs from custom output directory -->
+                    <!-- Pick class files AND libs from custom output
+                        directory -->
                     <execution>
                         <id>cuda</id>
                         <phase>package</phase>
@@ -185,7 +196,8 @@
                         </goals>
                         <configuration>
                             <classifier>cuda12-linux-x86-64</classifier>
-                            <classesDirectory>${project.build.outputDirectory}_cuda</classesDirectory>
+                            <classesDirectory>
+                                ${project.build.outputDirectory}_cuda</classesDirectory>
                         </configuration>
                     </execution>
                 </executions>

src/main/cpp/jllama.cpp

Lines changed: 110 additions & 1 deletion
@@ -112,6 +112,26 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) {
     return result;
 }
 
+std::vector<std::string> parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, const jsize length) {
+    std::vector<std::string> result;
+    result.reserve(length); // Reserve memory for efficiency
+
+    for (jsize i = 0; i < length; i++) {
+        jstring javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
+        if (javaString == nullptr) continue;
+
+        const char *cString = env->GetStringUTFChars(javaString, nullptr);
+        if (cString != nullptr) {
+            result.emplace_back(cString); // Add to vector
+            env->ReleaseStringUTFChars(javaString, cString);
+        }
+
+        env->DeleteLocalRef(javaString); // Avoid memory leaks
+    }
+
+    return result;
+}
+
 void free_string_array(char **array, jsize length) {
     if (array != nullptr) {
         for (jsize i = 0; i < length; i++) {
@@ -239,6 +259,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
     cc_integer = env->GetMethodID(c_integer, "<init>", "(I)V");
     cc_float = env->GetMethodID(c_float, "<init>", "(F)V");
 
+
     if (!(cc_output && cc_hash_map && cc_integer && cc_float)) {
         goto error;
     }
@@ -634,7 +655,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
     json error = nullptr;
 
     server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
-    ctx_server->queue_results.remove_waiting_task_id(id_task);
 
     json response_str = result->to_json();
     if (result->is_error()) {
@@ -643,6 +663,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
         env->ThrowNew(c_llama_error, response.c_str());
         return nullptr;
     }
+
+    if (result->is_stop()) {
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+    }
+
 
     const auto out_res = result->to_json();
 
@@ -679,6 +704,90 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
     return j_embedding;
 }
 
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, jobjectArray documents) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
+        env->ThrowNew(c_llama_error,
+                      "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        return nullptr;
+    }
+
+    const std::string prompt = parse_jstring(env, jprompt);
+
+    const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
+
+    json responses = json::array();
+    bool error = false;
+
+    std::vector<server_task> tasks;
+    const jsize argc = env->GetArrayLength(documents);
+    std::vector<std::string> documentsArray = parse_string_array_for_rerank(env, documents, argc);
+
+    std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true);
+
+    tasks.reserve(tokenized_docs.size());
+    for (size_t i = 0; i < tokenized_docs.size(); i++) {
+        server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+        task.id = ctx_server->queue_tasks.get_new_id();
+        task.index = i;
+        task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
+        tasks.push_back(task);
+    }
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    // get the result
+    std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+    std::vector<server_task_result_ptr> results(task_ids.size());
+
+    // Create a new HashMap instance
+    jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
+    if (o_probabilities == nullptr) {
+        env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
+        return nullptr;
+    }
+
+    for (int i = 0; i < (int)task_ids.size(); i++) {
+        server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+        if (result->is_error()) {
+            std::string response = result->to_json()["message"].get<std::string>();
+            for (const int id_task : task_ids) {
+                ctx_server->queue_results.remove_waiting_task_id(id_task);
+            }
+            env->ThrowNew(c_llama_error, response.c_str());
+            return nullptr;
+        }
+
+        const auto out_res = result->to_json();
+
+        std::cout << out_res.dump(4) << std::endl;
+
+        if (result->is_stop()) {
+            for (const int id_task : task_ids) {
+                ctx_server->queue_results.remove_waiting_task_id(id_task);
+            }
+        }
+
+        int index = out_res["index"].get<int>();
+        float score = out_res["score"].get<float>();
+        std::string tok_str = documentsArray[index];
+        jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+        jobject jprob = env->NewObject(c_float, cc_float, score);
+        env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
+        env->DeleteLocalRef(jtok_str);
+        env->DeleteLocalRef(jprob);
+    }
+    jbyteArray jbytes = parse_jbytes(env, prompt);
+    return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
+}
+
 JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
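Note on the result shape: the native rerank call above collects the per-document scores into an unordered Java HashMap keyed by the document text, so ordering the candidates is left to the caller. A minimal caller-side sketch (not part of the commit) of turning such a map into a descending ranking:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Sketch only, not part of the commit: order the document -> score map
// produced by the rerank binding from most to least relevant.
public final class RerankRanking {

    public static List<Map.Entry<String, Float>> byDescendingScore(Map<String, Float> scores) {
        List<Map.Entry<String, Float>> ranked = new ArrayList<>(scores.entrySet());
        // Highest relevance score first.
        ranked.sort(Map.Entry.<String, Float>comparingByValue(Comparator.reverseOrder()));
        return ranked;
    }
}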

src/main/cpp/jllama.h

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default.

src/main/java/de/kherud/llama/LlamaModel.java

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import java.lang.annotation.Native;
 import java.nio.charset.StandardCharsets;
+import java.util.List;
 import java.util.function.BiConsumer;
 
 /**
@@ -137,4 +138,6 @@ public void close() {
     public static String jsonSchemaToGrammar(String schema) {
         return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8);
     }
+
+    public native LlamaOutput rerank(String query, String... documents);
 }
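For context, a hedged usage sketch of the new binding (not part of the commit). Only rerank(String, String...) and the LlamaModel/LlamaOutput types come from this diff; the model path and the setModel(...) builder call are illustrative assumptions, enableReRanking() is the setup hint given in the test's Javadoc below, and the probabilities field name is inferred from the LlamaOutput constructor used by the JNI code above.

import de.kherud.llama.LlamaModel;
import de.kherud.llama.LlamaOutput;
import de.kherud.llama.ModelParameters;

// Sketch only, not part of the commit. Builder calls and the model path are assumptions.
public final class RerankExample {

    public static void main(String[] args) {
        ModelParameters params = new ModelParameters()
                .setModel("models/jina-reranker-v1-tiny-en.Q4_K_M.gguf") // assumed path and setter name
                .enableReRanking();                                      // named in the test's Javadoc below
        try (LlamaModel model = new LlamaModel(params)) {
            LlamaOutput scores = model.rerank("Machine learning is",
                    "Machine learning studies statistical algorithms that learn from data.",
                    "Paris is the capital of France.");
            // Inferred: LlamaOutput exposes the document -> relevance-score map as `probabilities`.
            scores.probabilities.forEach((doc, score) -> System.out.println(score + "  " + doc));
        }
    }
}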

src/test/java/de/kherud/llama/LlamaModelTest.java

Lines changed: 20 additions & 0 deletions
@@ -158,6 +158,26 @@ public void testEmbedding() {
     float[] embedding = model.embed(prefix);
     Assert.assertEquals(4096, embedding.length);
 }
+
+
+@Ignore
+/**
+ * To run this test, download the model from https://huggingface.co/mradermacher/jina-reranker-v1-tiny-en-GGUF/tree/main,
+ * remove .enableEmbedding() from the model setup, add .enableReRanking(), and then enable the test.
+ */
+public void testReRanking() {
+
+    String query = "Machine learning is";
+    String[] TEST_DOCUMENTS = new String[] {
+            "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
+            "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
+            "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
+            "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
+    };
+    LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3]);
+
+    System.out.println(llamaOutput);
+}
 
 @Test
 public void testTokenization() {
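The added test only prints the output. A small follow-up sketch (not part of the commit) of assertions that could go inside testReRanking() once it is enabled, based on what the JNI code above builds: one score per distinct document, keyed by the document text. The probabilities accessor on LlamaOutput is an inference from its constructor call in the native code.

// Sketch only, not part of the commit: possible assertions inside testReRanking().
// `probabilities` is inferred from the LlamaOutput constructor used by the native rerank code.
LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS);
Assert.assertEquals(TEST_DOCUMENTS.length, llamaOutput.probabilities.size());
for (String document : TEST_DOCUMENTS) {
    Assert.assertTrue(llamaOutput.probabilities.containsKey(document));
}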
