- 新增LocalWindow类,实现本地AI推理功能- 更新LM类,添加推理相关方法 - 在Main类中添加AI工具类别和本地AI执行工具项 - 优化FridaWindow类,添加作者注释
185 lines
7.7 KiB
Java
185 lines
7.7 KiB
Java
package org.tzd.lm;
|
||
|
||
import com.axis.innovators.box.tools.FolderCreator;
|
||
import com.axis.innovators.box.tools.LibraryLoad;
|
||
|
||
/**
|
||
* LM推理类
|
||
* @author tzdwindows 7
|
||
*/
|
||
public class LM {
|
||
public static boolean CUDA = true;
|
||
public final static String DEEP_SEEK = "G:/deepseek/deepseek/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf";
|
||
|
||
static {
|
||
if (!CUDA) {
|
||
LibraryLoad.loadLibrary("cpu/ggml-base");
|
||
LibraryLoad.loadLibrary("cpu/ggml-cpu");
|
||
LibraryLoad.loadLibrary("cpu/ggml");
|
||
LibraryLoad.loadLibrary("cpu/llama");
|
||
} else {
|
||
LibraryLoad.loadLibrary("cuda/ggml-base");
|
||
LibraryLoad.loadLibrary("cuda/ggml-cpu");
|
||
LibraryLoad.loadLibrary("cuda/ggml-rpc");
|
||
// cuda版本 cuda-cu12.4-x64(确保你有)
|
||
LibraryLoad.loadLibrary("cuda/ggml-cuda");
|
||
LibraryLoad.loadLibrary("cuda/ggml");
|
||
LibraryLoad.loadLibrary("cuda/llama");
|
||
}
|
||
LibraryLoad.loadLibrary("LM");
|
||
}
|
||
/**
|
||
* 加载模型
|
||
* @param pathModel 模型路径
|
||
* @return 模型句柄
|
||
*/
|
||
public static native long llamaLoadModelFromFile(String pathModel);
|
||
|
||
/**
|
||
* 释放模型资源
|
||
* @param modelHandle 模型句柄
|
||
*/
|
||
public static native void llamaFreeModel(long modelHandle);
|
||
|
||
public static long createContext(long modelHandle){
|
||
return createContext(modelHandle,
|
||
4096,
|
||
0,
|
||
0,
|
||
0,
|
||
0,
|
||
false,
|
||
false,
|
||
false,
|
||
true,
|
||
false
|
||
);
|
||
}
|
||
|
||
/**
|
||
* 上下文创建
|
||
*
|
||
* 创建一个新的上下文句柄,可以通过该句柄与模型进行交互。
|
||
*
|
||
* @param modelHandle 上下文句柄,表示已加载的模型实例的引用。
|
||
* @param nCtx 上下文的大小,决定了处理的最大上下文量。
|
||
* @param nBatch 批量大小,定义了每次处理的最大样本数。
|
||
* @param nSeqMax 最大序列数,对于递归模型有用。
|
||
* @param nThreads 用于生成的线程数。
|
||
* @param nThreadsBatch 用于批处理的线程数。
|
||
* @param logitsAll 是否计算所有 logits,而不仅仅是最后一个(弃用,改用 `llama_batch.logits`)。
|
||
* @param embeddings 是否提取嵌入(同时提取 logits)。
|
||
* @param offloadKqv 是否将 KQV 操作(包括 KV 缓存)卸载到 GPU 上。
|
||
* @param flashAttn 是否启用闪存注意力(实验性功能)。
|
||
* @param noPerf 是否禁用性能计时。
|
||
* @return 返回创建的上下文句柄,如果创建失败,返回 0。
|
||
*/
|
||
public static native long createContext(long modelHandle,
|
||
int nCtx,
|
||
int nBatch,
|
||
int nSeqMax,
|
||
int nThreads,
|
||
int nThreadsBatch,
|
||
boolean logitsAll,
|
||
boolean embeddings,
|
||
boolean offloadKqv,
|
||
boolean flashAttn,
|
||
boolean noPerf
|
||
);
|
||
|
||
/**
|
||
* 释放上下文资源
|
||
* @param ctxHandle 上下文句柄
|
||
*/
|
||
public static native void llamaFreeContext(long ctxHandle);
|
||
|
||
public static String inference(long modelHandle ,
|
||
long ctxHandle,
|
||
float temperature,
|
||
String prompt,
|
||
MessageCallback messageCallback){
|
||
return inference(modelHandle,
|
||
ctxHandle,
|
||
temperature,
|
||
0.05f,
|
||
40,
|
||
0.95f,
|
||
0,
|
||
5,
|
||
2.0f,
|
||
0,
|
||
0,
|
||
prompt,
|
||
messageCallback
|
||
);
|
||
}
|
||
|
||
/**
|
||
* 推理方法,根据传入的参数进行推理并返回生成的结果。
|
||
*
|
||
* @param modelHandle 模型句柄,指向加载的推理模型实例。
|
||
* @param ctxHandle 上下文句柄,指向推理上下文,用于维持推理状态。
|
||
* @param temperature 控制生成的文本的多样性,较高的值(如0.8)生成更加多样化的文本,较低的值(如0.2)生成更加确定的文本。
|
||
* @param minP 最小概率值,控制生成过程中所接受的最小概率的词汇,避免出现概率极低的词汇。
|
||
* @param topK 控制每次采样时从候选词中选取的最大数量,较小的值会限制生成的多样性。
|
||
* @param topP 控制每次采样时从累计概率最小的词汇池中选取的候选词比例,0到1之间的值。
|
||
* @param dist 种子值,决定生成文本的随机性,通常设置为0或者使用某个固定值以获得确定性的结果。
|
||
* @param penaltyLastN 控制对生成的最后n个词进行惩罚,0表示不惩罚,-1表示禁用惩罚。
|
||
* @param penaltyRepeat 重复惩罚系数,避免生成重复的词汇,1.0表示禁用惩罚。
|
||
* @param penaltyFreq 频率惩罚系数,避免过度重复高频词汇,0.0表示禁用。
|
||
* @param penaltyPresent 当前词汇惩罚系数,避免过度使用相同词汇。
|
||
* @param prompt 输入的提示文本,用于引导模型生成相关内容。
|
||
* @param messageCallback 回调函数,当模型生成消息时调用该函数,传递生成的内容。
|
||
*
|
||
* @return 返回生成的文本结果,类型为字符串。若推理过程中出现错误,返回null。
|
||
*/
|
||
public static native String inference(long modelHandle ,
|
||
long ctxHandle,
|
||
float temperature,
|
||
float minP,
|
||
float topK,
|
||
float topP,
|
||
float dist,
|
||
int penaltyLastN, // 最后n个要惩罚的标记(0 = disablepenalty-1 =上下文大小)
|
||
float penaltyRepeat, // 1.0 =禁用
|
||
float penaltyFreq, // 0.0 =禁用
|
||
float penaltyPresent,
|
||
String prompt,
|
||
MessageCallback messageCallback);
|
||
|
||
/**
|
||
* 回调接口
|
||
*/
|
||
public interface MessageCallback {
|
||
/**
|
||
* 接口回调
|
||
* @param message 消息
|
||
*/
|
||
void onMessage(String message);
|
||
}
|
||
|
||
public static void main(String[] args) {
|
||
// 加载模型
|
||
long modelHandle = llamaLoadModelFromFile(DEEP_SEEK);
|
||
// 创建新的上下文
|
||
long ctxHandle = createContext(modelHandle);
|
||
inference(modelHandle, ctxHandle, 0.2f, "Tell me how gpt works in an easy way, in Chinese", new MessageCallback() {
|
||
@Override
|
||
public void onMessage(String message) {
|
||
// 回调输出
|
||
System.out.print(message);
|
||
}
|
||
});
|
||
// 推理模型
|
||
//inference(modelHandle, ctxHandle, 0.2f, "谢谢你", new MessageCallback() {
|
||
// @Override
|
||
// public void onMessage(String message) {
|
||
// // 回调输出
|
||
// System.out.print(message);
|
||
// }
|
||
//});
|
||
// 清理上下文
|
||
llamaFreeContext(ctxHandle);
|
||
}
|
||
}
|