From 4d08fbeab038cb07a8395748c7a2fc31d002dfa8 Mon Sep 17 00:00:00 2001
From: tzdwindows 7 <3076584115@qq.com>
Date: Tue, 11 Feb 2025 13:21:45 +0800
Subject: [PATCH] =?UTF-8?q?feat(core):=20=E9=87=8D=E6=9E=84=E7=B1=BB?=
=?UTF-8?q?=E5=8A=A0=E8=BD=BD=E5=99=A8=E5=B9=B6=E6=B7=BB=E5=8A=A0=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E6=94=AF=E6=8C=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
-重构 BoxClassLoader 以支持插件加载
- 添加核心插件加载逻辑
- 实现类转换器和黑名单功能
- 优化工具类别和工具项的注册
- 修复日志输出和异常处理
---
build.gradle | 6 +
.../innovators/box/Log4j2OutputStream.java | 11 +-
.../java/com/axis/innovators/box/Main.java | 21 ++--
.../axis/innovators/box/gui/MainWindow.java | 95 ++++++++++++--
.../box/plugins/BoxClassLoader.java | 116 ++++++++++++++----
.../innovators/box/plugins/PluginLoader.java | 84 ++++++++-----
.../box/register/RegistrationTool.java | 2 +-
.../innovators/box/tools/FolderCreator.java | 11 +-
src/main/java/org/tzd/lm/LM.java | 28 +++--
9 files changed, 287 insertions(+), 87 deletions(-)
diff --git a/build.gradle b/build.gradle
index 32b0645..03f2c4a 100644
--- a/build.gradle
+++ b/build.gradle
@@ -80,6 +80,12 @@ sourceSets {
}
}
+tasks.withType(JavaExec).configureEach {
+ jvmArgs = [
+ '-Djava.system.class.loader=com.axis.innovators.box.plugins.BoxClassLoader'
+ ]
+}
+
// 单独打包文档
task packageOpenSourceDocs(type: Jar) {
archiveClassifier = 'docs'
diff --git a/src/main/java/com/axis/innovators/box/Log4j2OutputStream.java b/src/main/java/com/axis/innovators/box/Log4j2OutputStream.java
index 22ab8ed..e90a2e1 100644
--- a/src/main/java/com/axis/innovators/box/Log4j2OutputStream.java
+++ b/src/main/java/com/axis/innovators/box/Log4j2OutputStream.java
@@ -11,7 +11,7 @@ import java.io.PrintStream;
* @author tzdwindows 7
*/
public class Log4j2OutputStream extends OutputStream {
- private static final Logger logger = LogManager.getLogger(Log4j2OutputStream.class);
+ private static final Logger logger = LogManager.getLogger();
@Override
public void write(int b) {
@@ -20,14 +20,15 @@ public class Log4j2OutputStream extends OutputStream {
@Override
public void write(byte[] b, int off, int len) {
- logger.info(new String(b, off, len));
+ String message = new String(b, off, len).trim();
+ logger.info(message);
}
/**
* 重定向 System.out 和 System.err 到 Log4j2
*/
public static void redirectSystemStreams() {
- System.setOut(new PrintStream(new Log4j2OutputStream()));
- System.setErr(new PrintStream(new Log4j2OutputStream()));
+ System.setOut(new PrintStream(new Log4j2OutputStream(), true));
+ System.setErr(new PrintStream(new Log4j2OutputStream(), true));
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/com/axis/innovators/box/Main.java b/src/main/java/com/axis/innovators/box/Main.java
index 06e1145..8028ac2 100644
--- a/src/main/java/com/axis/innovators/box/Main.java
+++ b/src/main/java/com/axis/innovators/box/Main.java
@@ -36,7 +36,7 @@ public class Main {
private Thread thread;
private final String[] args;
private boolean isWindow = false;
- private RegistrationTool registrationTool;
+ private final RegistrationTool registrationTool = new RegistrationTool(this);
public Main(String[] args){
this.args = args;
@@ -118,9 +118,10 @@ public class Main {
main.thread = new Thread(() -> {
try {
// 主任务1:加载插件
+ logger.info("Loaded plugins Started");
main.progressBarManager.updateMainProgress(++main.completedTasks);
PluginLoader.loadPlugins();
- logger.info("Loaded plugins");
+ logger.info("Loaded plugins End");
main.progressBarManager.close();
@@ -182,13 +183,15 @@ public class Main {
// 主任务2:加载工具栏
progressBarManager.updateMainProgress(++completedTasks);
- for (int i = 0; i < registrationTool.getToolCategories()
- .size(); i++) {
- ex.addToolCategory(registrationTool.getToolCategories()
- .get(i));
- progressBarManager.updateSubProgress(
- "add tools",i,
- registrationTool.getToolCategories().size());
+ if (!registrationTool.getToolCategories().isEmpty()) {
+ for (int i = 0; i < registrationTool.getToolCategories()
+ .size(); i++) {
+ ex.addToolCategory(registrationTool.getToolCategories()
+ .get(i));
+ progressBarManager.updateSubProgress(
+ "add tools", i,
+ registrationTool.getToolCategories().size());
+ }
}
ex.initUI();
diff --git a/src/main/java/com/axis/innovators/box/gui/MainWindow.java b/src/main/java/com/axis/innovators/box/gui/MainWindow.java
index dfc2064..6de9dfc 100644
--- a/src/main/java/com/axis/innovators/box/gui/MainWindow.java
+++ b/src/main/java/com/axis/innovators/box/gui/MainWindow.java
@@ -210,6 +210,13 @@ public class MainWindow extends JFrame {
verticalScrollBar.setUI(new CustomScrollBarUI());
verticalScrollBar.setPreferredSize(new Dimension(10, 100));
+ if (category.getIcon() == null){
+ tabbedPane.addTab(
+ category.getName(),
+ category.getIconImage(),
+ scrollPane
+ );
+ }
tabbedPane.addTab(
category.getName(),
LoadIcon.loadIcon(category.getIcon(), 24),
@@ -312,8 +319,13 @@ public class MainWindow extends JFrame {
BorderFactory.createEmptyBorder(20, 20, 20, 20)
));
card.setCursor(Cursor.getPredefinedCursor(Cursor.HAND_CURSOR));
+ JLabel iconLabel;
+ if (tool.getIcon() == null){
+ iconLabel = new JLabel(tool.getImageIcon());
+ }else {
+ iconLabel = new JLabel(LoadIcon.loadIcon(tool.getIcon(), 64));
+ }
- JLabel iconLabel = new JLabel(LoadIcon.loadIcon(tool.icon(), 64));
iconLabel.setHorizontalAlignment(SwingConstants.CENTER);
// 文字面板
@@ -321,12 +333,12 @@ public class MainWindow extends JFrame {
textPanel.setLayout(new BoxLayout(textPanel, BoxLayout.Y_AXIS));
textPanel.setOpaque(false);
- JLabel titleLabel = new JLabel(tool.title());
+ JLabel titleLabel = new JLabel(tool.getTitle());
titleLabel.setFont(new Font("微软雅黑", Font.BOLD, 18));
titleLabel.setForeground(new Color(44, 62, 80));
titleLabel.setAlignmentX(Component.CENTER_ALIGNMENT);
- JTextArea descArea = new JTextArea(tool.description());
+ JTextArea descArea = new JTextArea(tool.getDescription());
descArea.setFont(new Font("微软雅黑", Font.PLAIN, 14));
descArea.setForeground(new Color(127, 140, 153));
descArea.setLineWrap(true);
@@ -374,7 +386,7 @@ public class MainWindow extends JFrame {
private String createToolTipHTML(ToolItem tool) {
return "
" +
"" + tool.getName() + "
" +
- "" + tool.description() + "
" +
+ "" + tool.getDescription() + "
" +
"";
}
@@ -382,6 +394,7 @@ public class MainWindow extends JFrame {
public static class ToolCategory {
private final String name;
private final String icon;
+ private final ImageIcon iconImage;
private final String description;
private final List tools = new ArrayList<>();
@@ -395,6 +408,14 @@ public class MainWindow extends JFrame {
this.name = name;
this.icon = icon;
this.description = description;
+ this.iconImage = null;
+ }
+
+ public ToolCategory(String name, ImageIcon icon, String description) {
+ this.name = name;
+ this.iconImage = icon;
+ this.description = description;
+ this.icon = null;
}
/**
@@ -420,19 +441,71 @@ public class MainWindow extends JFrame {
public List getTools() {
return tools;
}
+
+ public ImageIcon getIconImage() {
+ return iconImage;
+ }
}
// 工具项数据类
/**
* 工具注册类
- * @param title 工具的标题(显示名称)
- * @param icon 工具的图标(resources的路径)
- * @param description 工具的描述
- * @param id 工具的id(请不要重复注册相同id的工具)
- * @param action 工具的点击事件
*/
- public record ToolItem(String title, String icon, String description, int id, Action action) {
+ public static class ToolItem {
+ private final ImageIcon imageIcon;
+ private final String title;
+ private final String icon;
+
+ private final String description;
+ private final int id;
+ private final Action action;
+
+ public ToolItem(String title, String icon, String description, int id, Action action){
+ this.title = title;
+ this.icon = icon;
+ this.description = description;
+ this.id = id;
+ this.action = action;
+ this.imageIcon = null;
+ }
+
+ public ToolItem(String title, ImageIcon icon, String description, int id, Action action) {
+ this.title = title;
+ this.imageIcon = icon;
+ this.description = description;
+ this.id = id;
+ this.action = action;
+ this.icon = null;
+ }
+
+ public String getTitle() {
+ return title;
+ }
+
+ public ImageIcon icon() {
+ return imageIcon;
+ }
+
+ public String getIcon() {
+ return icon;
+ }
+
+ public Action getAction() {
+ return action;
+ }
+
+ public ImageIcon getImageIcon() {
+ return imageIcon;
+ }
+
+ public int getId() {
+ return id;
+ }
+
+ public String getDescription() {
+ return description;
+ }
public String getName() {
return title;
@@ -495,7 +568,7 @@ public class MainWindow extends JFrame {
@Override
public void mouseReleased(MouseEvent e) {
- startReleaseAnimation(() -> tool.action().actionPerformed(
+ startReleaseAnimation(() -> tool.getAction().actionPerformed(
new ActionEvent(card, ActionEvent.ACTION_PERFORMED, "")
));
}
diff --git a/src/main/java/com/axis/innovators/box/plugins/BoxClassLoader.java b/src/main/java/com/axis/innovators/box/plugins/BoxClassLoader.java
index b31efc5..db49ccc 100644
--- a/src/main/java/com/axis/innovators/box/plugins/BoxClassLoader.java
+++ b/src/main/java/com/axis/innovators/box/plugins/BoxClassLoader.java
@@ -1,21 +1,30 @@
package com.axis.innovators.box.plugins;
+import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.ClassWriter;
+
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
+import java.util.ArrayList;
+import java.util.List;
-/**
- * 自定义加载器
- * @author tzdwindows 7
- */
public class BoxClassLoader extends URLClassLoader {
- public BoxClassLoader(ClassLoader parent) {
- super(new URL[0], parent);
+ private static final List CLASS_TRANSFORMS = new ArrayList<>();
+ private static final List CLASS_BLACKLIST = new ArrayList<>();
+ private static final List CLASS_LOADING_LIST = new ArrayList<>();
+ private static final List> CLASS_LOADING_LIST_OBJECT = new ArrayList<>();
+
+ static {
+ // 添加黑名单,避免加载核心类
+ CLASS_BLACKLIST.add("java.");
+ CLASS_BLACKLIST.add("javax.");
+ CLASS_BLACKLIST.add("sun.");
}
- public BoxClassLoader(URL[] sources) {
- super(sources);
+ public BoxClassLoader(ClassLoader parent) {
+ super(new URL[]{}, parent);
}
public BoxClassLoader(URL[] urls, ClassLoader parent) {
@@ -23,18 +32,75 @@ public class BoxClassLoader extends URLClassLoader {
}
@Override
- protected Class> findClass(String name) throws ClassNotFoundException {
- byte[] classBytes;
- try {
- classBytes = getClassBytes(name);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- byte[] transformedClass = transformClass(name, null, classBytes);
- return defineClass(name, transformedClass, 0, transformedClass.length);
+ public void addURL(URL url) {
+ super.addURL(url);
}
- public byte[] getClassBytes(String className) throws IOException, ClassNotFoundException {
+ //@Override
+ //protected Class> loadClass(String name, boolean resolve) throws ClassNotFoundException {
+ // synchronized (getClassLoadingLock(name)) {
+ // // 1. 检查类是否已加载
+ // Class> c = findLoadedClass(name);
+ // if (c != null) {
+ // return c;
+ // }
+//
+ // // 2. 检查黑名单
+ // if (isBlacklisted(name)) {
+ // return super.loadClass(name, resolve);
+ // }
+//
+ // // 3. 尝试自定义加载
+ // try {
+ // c = findClass(name);
+ // } catch (ClassNotFoundException e) {
+ // // 4. 自定义加载失败则委托给父加载器
+ // c = super.loadClass(name, resolve);
+ // }
+//
+ // if (resolve) {
+ // resolveClass(c);
+ // }
+ // return c;
+ // }
+ //}
+
+ @Override
+ protected Class> findClass(String name) throws ClassNotFoundException {
+ // 检查是否在黑名单中
+ if (isBlacklisted(name)) {
+ throw new ClassNotFoundException("Class is blacklisted: " + name);
+ }
+
+ // 检查是否正在加载
+ if (CLASS_LOADING_LIST.contains(name)) {
+ throw new ClassNotFoundException("Class is already loading: " + name);
+ }
+
+ // 读取类字节码
+ byte[] clazzByte;
+ try {
+ clazzByte = getClassBytes(name);
+ } catch (IOException e) {
+ throw new ClassNotFoundException("Failed to load class bytes: " + name, e);
+ }
+
+ // 应用类转换器
+ for (IClassTransformer transformer : CLASS_TRANSFORMS) {
+ byte[] transformed = transformer.transform(name, transformer.getClass().getName(), clazzByte);
+ if (transformed != null) {
+ clazzByte = transformed;
+ }
+ }
+
+ // 定义类
+ CLASS_LOADING_LIST.add(name);
+ Class> clazz = defineClass(name, clazzByte, 0, clazzByte.length);
+ CLASS_LOADING_LIST_OBJECT.add(clazz);
+ return clazz;
+ }
+
+ private byte[] getClassBytes(String className) throws IOException, ClassNotFoundException {
String path = className.replace('.', '/') + ".class";
URL classUrl = getResource(path);
if (classUrl == null) {
@@ -45,7 +111,15 @@ public class BoxClassLoader extends URLClassLoader {
}
}
- public byte[] transformClass(String className, String transformedName, byte[] basicClass) {
- return PluginLoader.transformClass(className, transformedName, basicClass);
+ private boolean isBlacklisted(String className) {
+ return CLASS_BLACKLIST.stream().anyMatch(className::startsWith);
}
-}
+
+ public static void addClassTransformer(IClassTransformer transformer) {
+ CLASS_TRANSFORMS.add(transformer);
+ }
+
+ public static List> getClassList() {
+ return CLASS_LOADING_LIST_OBJECT;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/axis/innovators/box/plugins/PluginLoader.java b/src/main/java/com/axis/innovators/box/plugins/PluginLoader.java
index 94180eb..bfc7315 100644
--- a/src/main/java/com/axis/innovators/box/plugins/PluginLoader.java
+++ b/src/main/java/com/axis/innovators/box/plugins/PluginLoader.java
@@ -6,6 +6,8 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.*;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
import java.net.*;
import java.util.*;
import java.util.jar.*;
@@ -15,20 +17,26 @@ import java.util.jar.*;
* @author tzdwindows 7
*/
public class PluginLoader {
- private static final Logger logger = LogManager.getLogger(PluginLoader.class);
+ private static Logger logger;
public static final String PLUGIN_PATH = FolderCreator.getPluginFolder();
private static final List loadedPlugins = new ArrayList<>();
private static final List transformers = new ArrayList<>();
-
public static void loadPlugins() throws IOException {
- File pluginDir = new File(PLUGIN_PATH);
- File[] jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar"));
+ logger = LogManager.getLogger(PluginLoader.class);
+ File pluginDir = null;
+ if (PLUGIN_PATH != null) {
+ pluginDir = new File(PLUGIN_PATH);
+ }
+ File[] jars = null;
+ if (pluginDir != null) {
+ jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar"));
+ }
if (jars == null) {
return;
}
for (int i = 0; i < jars.length; i++) {
- processJarFile(jars[i]);
+ processJarFile(jars[i],false);
Main.getMain().progressBarManager.updateSubProgress(
"Loading Plugin " + i,
i,
@@ -36,24 +44,47 @@ public class PluginLoader {
}
}
- private static void processJarFile(File jarFile) throws IOException {
+ private static void processJarFile(File jarFile, boolean isCorePlugin) throws IOException {
try (JarFile jar = new JarFile(jarFile)) {
- JarEntry pluginFile = jar.getJarEntry("plug-in.box");
- if (pluginFile != null) {
- processWithManifest(jar, pluginFile, jarFile);
- } else {
- processWithAnnotations(jar, jarFile);
- }
-
// Check for CorePlugin in MANIFEST.MF
- Attributes attributes = jar.getManifest().getMainAttributes();
- String corePluginClass = attributes.getValue("CorePlugin");
- if (corePluginClass != null) {
- processCorePlugin(jarFile, corePluginClass);
+ if (isCorePlugin) {
+ Attributes attributes = jar.getManifest().getMainAttributes();
+ String corePluginClass = attributes.getValue("CorePlugin");
+ if (corePluginClass != null) {
+ processCorePlugin(jarFile, corePluginClass);
+ }
+ } else {
+ JarEntry pluginFile = jar.getJarEntry("plug-in.box");
+ if (pluginFile != null) {
+ processWithManifest(jar, pluginFile, jarFile);
+ } else {
+ processWithAnnotations(jar, jarFile);
+ }
}
}
}
+ /**
+ * 加载核心插件
+ * @throws IOException 插件加载失败
+ */
+ public static void loadCorePlugin() throws IOException {
+ File pluginDir = null;
+ if (PLUGIN_PATH != null) {
+ pluginDir = new File(PLUGIN_PATH);
+ }
+ File[] jars = null;
+ if (pluginDir != null) {
+ jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar"));
+ }
+ if (jars == null) {
+ return;
+ }
+ for (File jar : jars) {
+ processJarFile(jar, true);
+ }
+ }
+
private static void processCorePlugin(File jarFile, String corePluginClass) {
try (URLClassLoader classLoader = new URLClassLoader(
new URL[]{jarFile.toURI().toURL()},
@@ -62,22 +93,23 @@ public class PluginLoader {
Class> coreClass = classLoader.loadClass(corePluginClass);
if (LoadingCorePlugin.class.isAssignableFrom(coreClass)) {
LoadingCorePlugin corePlugin = (LoadingCorePlugin) coreClass.getDeclaredConstructor().newInstance();
- registerTransformers(corePlugin);
+ registerTransformers(corePlugin, classLoader);
}
} catch (Exception e) {
logger.error("Failed to load core plugin: {}", corePluginClass, e);
}
}
- private static void registerTransformers(LoadingCorePlugin corePlugin) {
+ private static void registerTransformers(LoadingCorePlugin corePlugin, ClassLoader classLoader) {
String[] transformerClasses = corePlugin.getASMTransformerClass();
if (transformerClasses != null) {
for (String transformerClass : transformerClasses) {
try {
- Class> transformerClazz = Class.forName(transformerClass);
+ Class> transformerClazz = classLoader.loadClass(transformerClass);
if (IClassTransformer.class.isAssignableFrom(transformerClazz)) {
IClassTransformer transformer = (IClassTransformer) transformerClazz.getDeclaredConstructor().newInstance();
transformers.add(transformer);
+ BoxClassLoader.addClassTransformer(transformer);
}
} catch (Exception e) {
logger.error("Failed to register transformer: {}", transformerClass, e);
@@ -113,11 +145,12 @@ public class PluginLoader {
private static void processWithAnnotations(JarFile jar, File jarFile) {
URLClassLoader classLoader = createClassLoader(jarFile);
Enumeration entries = jar.entries();
-
while (entries.hasMoreElements()) {
JarEntry entry = entries.nextElement();
if (entry.getName().endsWith(".class")) {
- processClassEntry(entry, classLoader);
+ if (classLoader != null) {
+ processClassEntry(entry, classLoader);
+ }
}
}
}
@@ -183,11 +216,4 @@ public class PluginLoader {
public static List getLoadedPlugins() {
return Collections.unmodifiableList(loadedPlugins);
}
-
- public static byte[] transformClass(String name, String transformedName, byte[] basicClass) {
- for (IClassTransformer transformer : transformers) {
- basicClass = transformer.transform(name, transformedName, basicClass);
- }
- return basicClass;
- }
}
\ No newline at end of file
diff --git a/src/main/java/com/axis/innovators/box/register/RegistrationTool.java b/src/main/java/com/axis/innovators/box/register/RegistrationTool.java
index 772ce5f..18cca7b 100644
--- a/src/main/java/com/axis/innovators/box/register/RegistrationTool.java
+++ b/src/main/java/com/axis/innovators/box/register/RegistrationTool.java
@@ -25,7 +25,7 @@ public class RegistrationTool {
* 注册ToolCategory
* @param toolCategory ToolCategory
*/
- void addToolCategory(MainWindow.ToolCategory toolCategory){
+ public void addToolCategory(MainWindow.ToolCategory toolCategory){
if (!main.isWindow()) {
toolCategories.add(toolCategory);
} else {
diff --git a/src/main/java/com/axis/innovators/box/tools/FolderCreator.java b/src/main/java/com/axis/innovators/box/tools/FolderCreator.java
index b6dcb57..4ab7b5d 100644
--- a/src/main/java/com/axis/innovators/box/tools/FolderCreator.java
+++ b/src/main/java/com/axis/innovators/box/tools/FolderCreator.java
@@ -10,7 +10,7 @@ import java.io.File;
* @author tzdwindows 7
*/
public class FolderCreator {
- private static final Logger logger = LogManager.getLogger(FolderCreator.class);
+ private static Logger logger;
public static final String LIBRARY_NAME = "library";
public static final String MODEL_PATH = "model";
public static final String PLUGIN_PATH = "plug-in";
@@ -27,6 +27,9 @@ public class FolderCreator {
public static String getModelFolder() {
String folder = createFolder(MODEL_PATH);
if (folder == null) {
+ if (logger == null) {
+ logger = LogManager.getLogger();
+ }
logger.error("Model folder creation failure, please use administrator privileges to execute this procedure");
return null;
}
@@ -36,6 +39,9 @@ public class FolderCreator {
public static String getLibraryFolder() {
String folder = createFolder(LIBRARY_NAME);
if (folder == null) {
+ if (logger == null) {
+ logger = LogManager.getLogger();
+ }
logger.error("Library folder creation failed, please use administrator privileges to execute this procedure");
return null;
}
@@ -52,6 +58,9 @@ public class FolderCreator {
File folder = new File(jarDir, folderName);
if (!folder.exists()) {
if (!folder.mkdir()) {
+ if (logger == null) {
+ logger = LogManager.getLogger();
+ }
logger.error("Folder creation failure");
return null;
}
diff --git a/src/main/java/org/tzd/lm/LM.java b/src/main/java/org/tzd/lm/LM.java
index bbc3d57..df1844c 100644
--- a/src/main/java/org/tzd/lm/LM.java
+++ b/src/main/java/org/tzd/lm/LM.java
@@ -10,17 +10,28 @@ import org.apache.logging.log4j.Logger;
* @author tzdwindows 7
*/
public class LM {
- public static boolean CUDA = false;
+ public static boolean CUDA = true;
public final static String DEEP_SEEK = FolderCreator.getModelFolder() + "\\DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf";
private static final Logger logger = LogManager.getLogger(LM.class);
static {
- if (!CUDA) {
+ loadLibrary(CUDA);
+ }
+
+ private static void loadLibrary(boolean cuda){
+ if (!cuda) {
logger.warn("The cpu will be used for inference");
- LibraryLoad.loadLibrary("cpu/ggml-base");
- LibraryLoad.loadLibrary("cpu/ggml-cpu");
- LibraryLoad.loadLibrary("cpu/ggml");
- LibraryLoad.loadLibrary("cpu/llama");
+ try {
+ LibraryLoad.loadLibrary("cpu/ggml-base");
+ LibraryLoad.loadLibrary("cpu/ggml-cpu");
+ LibraryLoad.loadLibrary("cpu/ggml");
+ LibraryLoad.loadLibrary("cpu/llama");
+ } catch (UnsatisfiedLinkError e) {
+ logger.error("Unable to load library: " + e.getMessage(), e);
+ logger.error("Missing dependency: " + e.getMessage());
+ } catch (Exception e) {
+ logger.error("Unable to load cpu, please try updating driver", e);
+ }
} else {
try {
LibraryLoad.loadLibrary("cuda/ggml-base");
@@ -33,10 +44,7 @@ public class LM {
} catch (Exception e) {
CUDA = false;
logger.warn("The cuda library could not be loaded, the cpu will be used for inference");
- LibraryLoad.loadLibrary("cpu/ggml-base");
- LibraryLoad.loadLibrary("cpu/ggml-cpu");
- LibraryLoad.loadLibrary("cpu/ggml");
- LibraryLoad.loadLibrary("cpu/llama");
+ loadLibrary(false);
}
}
LibraryLoad.loadLibrary("LM");