Files
window-axis-innovators-box1.17/src/main/java/org/tzd/lm/LM.java
tzdwindows 7 022446eb32 feat(AI工具): 添加本地AI执行工具
- 新增LocalWindow类,实现本地AI推理功能- 更新LM类,添加推理相关方法
- 在Main类中添加AI工具类别和本地AI执行工具项
- 优化FridaWindow类,添加作者注释
2025-02-07 18:03:51 +08:00

185 lines
7.7 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package org.tzd.lm;
import com.axis.innovators.box.tools.FolderCreator;
import com.axis.innovators.box.tools.LibraryLoad;
/**
* LM推理类
* @author tzdwindows 7
*/
public class LM {
public static boolean CUDA = true;
public final static String DEEP_SEEK = "G:/deepseek/deepseek/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf";
static {
if (!CUDA) {
LibraryLoad.loadLibrary("cpu/ggml-base");
LibraryLoad.loadLibrary("cpu/ggml-cpu");
LibraryLoad.loadLibrary("cpu/ggml");
LibraryLoad.loadLibrary("cpu/llama");
} else {
LibraryLoad.loadLibrary("cuda/ggml-base");
LibraryLoad.loadLibrary("cuda/ggml-cpu");
LibraryLoad.loadLibrary("cuda/ggml-rpc");
// cuda版本 cuda-cu12.4-x64确保你有
LibraryLoad.loadLibrary("cuda/ggml-cuda");
LibraryLoad.loadLibrary("cuda/ggml");
LibraryLoad.loadLibrary("cuda/llama");
}
LibraryLoad.loadLibrary("LM");
}
/**
* 加载模型
* @param pathModel 模型路径
* @return 模型句柄
*/
public static native long llamaLoadModelFromFile(String pathModel);
/**
* 释放模型资源
* @param modelHandle 模型句柄
*/
public static native void llamaFreeModel(long modelHandle);
public static long createContext(long modelHandle){
return createContext(modelHandle,
4096,
0,
0,
0,
0,
false,
false,
false,
true,
false
);
}
/**
* 上下文创建
*
* 创建一个新的上下文句柄,可以通过该句柄与模型进行交互。
*
* @param modelHandle 上下文句柄,表示已加载的模型实例的引用。
* @param nCtx 上下文的大小,决定了处理的最大上下文量。
* @param nBatch 批量大小,定义了每次处理的最大样本数。
* @param nSeqMax 最大序列数,对于递归模型有用。
* @param nThreads 用于生成的线程数。
* @param nThreadsBatch 用于批处理的线程数。
* @param logitsAll 是否计算所有 logits而不仅仅是最后一个弃用改用 `llama_batch.logits`)。
* @param embeddings 是否提取嵌入(同时提取 logits
* @param offloadKqv 是否将 KQV 操作(包括 KV 缓存)卸载到 GPU 上。
* @param flashAttn 是否启用闪存注意力(实验性功能)。
* @param noPerf 是否禁用性能计时。
* @return 返回创建的上下文句柄,如果创建失败,返回 0。
*/
public static native long createContext(long modelHandle,
int nCtx,
int nBatch,
int nSeqMax,
int nThreads,
int nThreadsBatch,
boolean logitsAll,
boolean embeddings,
boolean offloadKqv,
boolean flashAttn,
boolean noPerf
);
/**
* 释放上下文资源
* @param ctxHandle 上下文句柄
*/
public static native void llamaFreeContext(long ctxHandle);
public static String inference(long modelHandle ,
long ctxHandle,
float temperature,
String prompt,
MessageCallback messageCallback){
return inference(modelHandle,
ctxHandle,
temperature,
0.05f,
40,
0.95f,
0,
5,
2.0f,
0,
0,
prompt,
messageCallback
);
}
/**
* 推理方法,根据传入的参数进行推理并返回生成的结果。
*
* @param modelHandle 模型句柄,指向加载的推理模型实例。
* @param ctxHandle 上下文句柄,指向推理上下文,用于维持推理状态。
* @param temperature 控制生成的文本的多样性较高的值如0.8生成更加多样化的文本较低的值如0.2)生成更加确定的文本。
* @param minP 最小概率值,控制生成过程中所接受的最小概率的词汇,避免出现概率极低的词汇。
* @param topK 控制每次采样时从候选词中选取的最大数量,较小的值会限制生成的多样性。
* @param topP 控制每次采样时从累计概率最小的词汇池中选取的候选词比例0到1之间的值。
* @param dist 种子值决定生成文本的随机性通常设置为0或者使用某个固定值以获得确定性的结果。
* @param penaltyLastN 控制对生成的最后n个词进行惩罚0表示不惩罚-1表示禁用惩罚。
* @param penaltyRepeat 重复惩罚系数避免生成重复的词汇1.0表示禁用惩罚。
* @param penaltyFreq 频率惩罚系数避免过度重复高频词汇0.0表示禁用。
* @param penaltyPresent 当前词汇惩罚系数,避免过度使用相同词汇。
* @param prompt 输入的提示文本,用于引导模型生成相关内容。
* @param messageCallback 回调函数,当模型生成消息时调用该函数,传递生成的内容。
*
* @return 返回生成的文本结果类型为字符串。若推理过程中出现错误返回null。
*/
public static native String inference(long modelHandle ,
long ctxHandle,
float temperature,
float minP,
float topK,
float topP,
float dist,
int penaltyLastN, // 最后n个要惩罚的标记0 = disablepenalty-1 =上下文大小)
float penaltyRepeat, // 1.0 =禁用
float penaltyFreq, // 0.0 =禁用
float penaltyPresent,
String prompt,
MessageCallback messageCallback);
/**
* 回调接口
*/
public interface MessageCallback {
/**
* 接口回调
* @param message 消息
*/
void onMessage(String message);
}
public static void main(String[] args) {
// 加载模型
long modelHandle = llamaLoadModelFromFile(DEEP_SEEK);
// 创建新的上下文
long ctxHandle = createContext(modelHandle);
inference(modelHandle, ctxHandle, 0.2f, "Tell me how gpt works in an easy way, in Chinese", new MessageCallback() {
@Override
public void onMessage(String message) {
// 回调输出
System.out.print(message);
}
});
// 推理模型
//inference(modelHandle, ctxHandle, 0.2f, "谢谢你", new MessageCallback() {
// @Override
// public void onMessage(String message) {
// // 回调输出
// System.out.print(message);
// }
//});
// 清理上下文
llamaFreeContext(ctxHandle);
}
}