feat(core): 重构类加载器并添加插件支持

-重构 BoxClassLoader 以支持插件加载
- 添加核心插件加载逻辑
- 实现类转换器和黑名单功能
- 优化工具类别和工具项的注册
- 修复日志输出和异常处理
This commit is contained in:
tzdwindows 7
2025-02-11 13:21:45 +08:00
parent fcc4115638
commit 4d08fbeab0
9 changed files with 287 additions and 87 deletions

View File

@@ -80,6 +80,12 @@ sourceSets {
} }
} }
tasks.withType(JavaExec).configureEach {
jvmArgs = [
'-Djava.system.class.loader=com.axis.innovators.box.plugins.BoxClassLoader'
]
}
// 单独打包文档 // 单独打包文档
task packageOpenSourceDocs(type: Jar) { task packageOpenSourceDocs(type: Jar) {
archiveClassifier = 'docs' archiveClassifier = 'docs'

View File

@@ -11,7 +11,7 @@ import java.io.PrintStream;
* @author tzdwindows 7 * @author tzdwindows 7
*/ */
public class Log4j2OutputStream extends OutputStream { public class Log4j2OutputStream extends OutputStream {
private static final Logger logger = LogManager.getLogger(Log4j2OutputStream.class); private static final Logger logger = LogManager.getLogger();
@Override @Override
public void write(int b) { public void write(int b) {
@@ -20,14 +20,15 @@ public class Log4j2OutputStream extends OutputStream {
@Override @Override
public void write(byte[] b, int off, int len) { public void write(byte[] b, int off, int len) {
logger.info(new String(b, off, len)); String message = new String(b, off, len).trim();
logger.info(message);
} }
/** /**
* 重定向 System.out 和 System.err 到 Log4j2 * 重定向 System.out 和 System.err 到 Log4j2
*/ */
public static void redirectSystemStreams() { public static void redirectSystemStreams() {
System.setOut(new PrintStream(new Log4j2OutputStream())); System.setOut(new PrintStream(new Log4j2OutputStream(), true));
System.setErr(new PrintStream(new Log4j2OutputStream())); System.setErr(new PrintStream(new Log4j2OutputStream(), true));
} }
} }

View File

@@ -36,7 +36,7 @@ public class Main {
private Thread thread; private Thread thread;
private final String[] args; private final String[] args;
private boolean isWindow = false; private boolean isWindow = false;
private RegistrationTool registrationTool; private final RegistrationTool registrationTool = new RegistrationTool(this);
public Main(String[] args){ public Main(String[] args){
this.args = args; this.args = args;
@@ -118,9 +118,10 @@ public class Main {
main.thread = new Thread(() -> { main.thread = new Thread(() -> {
try { try {
// 主任务1加载插件 // 主任务1加载插件
logger.info("Loaded plugins Started");
main.progressBarManager.updateMainProgress(++main.completedTasks); main.progressBarManager.updateMainProgress(++main.completedTasks);
PluginLoader.loadPlugins(); PluginLoader.loadPlugins();
logger.info("Loaded plugins"); logger.info("Loaded plugins End");
main.progressBarManager.close(); main.progressBarManager.close();
@@ -182,14 +183,16 @@ public class Main {
// 主任务2加载工具栏 // 主任务2加载工具栏
progressBarManager.updateMainProgress(++completedTasks); progressBarManager.updateMainProgress(++completedTasks);
if (!registrationTool.getToolCategories().isEmpty()) {
for (int i = 0; i < registrationTool.getToolCategories() for (int i = 0; i < registrationTool.getToolCategories()
.size(); i++) { .size(); i++) {
ex.addToolCategory(registrationTool.getToolCategories() ex.addToolCategory(registrationTool.getToolCategories()
.get(i)); .get(i));
progressBarManager.updateSubProgress( progressBarManager.updateSubProgress(
"add tools",i, "add tools", i,
registrationTool.getToolCategories().size()); registrationTool.getToolCategories().size());
} }
}
ex.initUI(); ex.initUI();
isWindow = true; isWindow = true;

View File

@@ -210,6 +210,13 @@ public class MainWindow extends JFrame {
verticalScrollBar.setUI(new CustomScrollBarUI()); verticalScrollBar.setUI(new CustomScrollBarUI());
verticalScrollBar.setPreferredSize(new Dimension(10, 100)); verticalScrollBar.setPreferredSize(new Dimension(10, 100));
if (category.getIcon() == null){
tabbedPane.addTab(
category.getName(),
category.getIconImage(),
scrollPane
);
}
tabbedPane.addTab( tabbedPane.addTab(
category.getName(), category.getName(),
LoadIcon.loadIcon(category.getIcon(), 24), LoadIcon.loadIcon(category.getIcon(), 24),
@@ -312,8 +319,13 @@ public class MainWindow extends JFrame {
BorderFactory.createEmptyBorder(20, 20, 20, 20) BorderFactory.createEmptyBorder(20, 20, 20, 20)
)); ));
card.setCursor(Cursor.getPredefinedCursor(Cursor.HAND_CURSOR)); card.setCursor(Cursor.getPredefinedCursor(Cursor.HAND_CURSOR));
JLabel iconLabel;
if (tool.getIcon() == null){
iconLabel = new JLabel(tool.getImageIcon());
}else {
iconLabel = new JLabel(LoadIcon.loadIcon(tool.getIcon(), 64));
}
JLabel iconLabel = new JLabel(LoadIcon.loadIcon(tool.icon(), 64));
iconLabel.setHorizontalAlignment(SwingConstants.CENTER); iconLabel.setHorizontalAlignment(SwingConstants.CENTER);
// 文字面板 // 文字面板
@@ -321,12 +333,12 @@ public class MainWindow extends JFrame {
textPanel.setLayout(new BoxLayout(textPanel, BoxLayout.Y_AXIS)); textPanel.setLayout(new BoxLayout(textPanel, BoxLayout.Y_AXIS));
textPanel.setOpaque(false); textPanel.setOpaque(false);
JLabel titleLabel = new JLabel(tool.title()); JLabel titleLabel = new JLabel(tool.getTitle());
titleLabel.setFont(new Font("微软雅黑", Font.BOLD, 18)); titleLabel.setFont(new Font("微软雅黑", Font.BOLD, 18));
titleLabel.setForeground(new Color(44, 62, 80)); titleLabel.setForeground(new Color(44, 62, 80));
titleLabel.setAlignmentX(Component.CENTER_ALIGNMENT); titleLabel.setAlignmentX(Component.CENTER_ALIGNMENT);
JTextArea descArea = new JTextArea(tool.description()); JTextArea descArea = new JTextArea(tool.getDescription());
descArea.setFont(new Font("微软雅黑", Font.PLAIN, 14)); descArea.setFont(new Font("微软雅黑", Font.PLAIN, 14));
descArea.setForeground(new Color(127, 140, 153)); descArea.setForeground(new Color(127, 140, 153));
descArea.setLineWrap(true); descArea.setLineWrap(true);
@@ -374,7 +386,7 @@ public class MainWindow extends JFrame {
private String createToolTipHTML(ToolItem tool) { private String createToolTipHTML(ToolItem tool) {
return "<html><body style='width: 300px; padding: 10px;'>" + return "<html><body style='width: 300px; padding: 10px;'>" +
"<h3 style='color: #2c3e50; margin: 0 0 8px 0;'>" + tool.getName() + "</h3>" + "<h3 style='color: #2c3e50; margin: 0 0 8px 0;'>" + tool.getName() + "</h3>" +
"<p style='color: #7f8c99; margin: 0;'>" + tool.description() + "</p>" + "<p style='color: #7f8c99; margin: 0;'>" + tool.getDescription() + "</p>" +
"</body></html>"; "</body></html>";
} }
@@ -382,6 +394,7 @@ public class MainWindow extends JFrame {
public static class ToolCategory { public static class ToolCategory {
private final String name; private final String name;
private final String icon; private final String icon;
private final ImageIcon iconImage;
private final String description; private final String description;
private final List<ToolItem> tools = new ArrayList<>(); private final List<ToolItem> tools = new ArrayList<>();
@@ -395,6 +408,14 @@ public class MainWindow extends JFrame {
this.name = name; this.name = name;
this.icon = icon; this.icon = icon;
this.description = description; this.description = description;
this.iconImage = null;
}
public ToolCategory(String name, ImageIcon icon, String description) {
this.name = name;
this.iconImage = icon;
this.description = description;
this.icon = null;
} }
/** /**
@@ -420,19 +441,71 @@ public class MainWindow extends JFrame {
public List<ToolItem> getTools() { public List<ToolItem> getTools() {
return tools; return tools;
} }
public ImageIcon getIconImage() {
return iconImage;
}
} }
// 工具项数据类 // 工具项数据类
/** /**
* 工具注册类 * 工具注册类
* @param title 工具的标题(显示名称)
* @param icon 工具的图标resources的路径
* @param description 工具的描述
* @param id 工具的id请不要重复注册相同id的工具
* @param action 工具的点击事件
*/ */
public record ToolItem(String title, String icon, String description, int id, Action action) { public static class ToolItem {
private final ImageIcon imageIcon;
private final String title;
private final String icon;
private final String description;
private final int id;
private final Action action;
public ToolItem(String title, String icon, String description, int id, Action action){
this.title = title;
this.icon = icon;
this.description = description;
this.id = id;
this.action = action;
this.imageIcon = null;
}
public ToolItem(String title, ImageIcon icon, String description, int id, Action action) {
this.title = title;
this.imageIcon = icon;
this.description = description;
this.id = id;
this.action = action;
this.icon = null;
}
public String getTitle() {
return title;
}
public ImageIcon icon() {
return imageIcon;
}
public String getIcon() {
return icon;
}
public Action getAction() {
return action;
}
public ImageIcon getImageIcon() {
return imageIcon;
}
public int getId() {
return id;
}
public String getDescription() {
return description;
}
public String getName() { public String getName() {
return title; return title;
@@ -495,7 +568,7 @@ public class MainWindow extends JFrame {
@Override @Override
public void mouseReleased(MouseEvent e) { public void mouseReleased(MouseEvent e) {
startReleaseAnimation(() -> tool.action().actionPerformed( startReleaseAnimation(() -> tool.getAction().actionPerformed(
new ActionEvent(card, ActionEvent.ACTION_PERFORMED, "") new ActionEvent(card, ActionEvent.ACTION_PERFORMED, "")
)); ));
} }

View File

@@ -1,21 +1,30 @@
package com.axis.innovators.box.plugins; package com.axis.innovators.box.plugins;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URL; import java.net.URL;
import java.net.URLClassLoader; import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
/**
* 自定义加载器
* @author tzdwindows 7
*/
public class BoxClassLoader extends URLClassLoader { public class BoxClassLoader extends URLClassLoader {
public BoxClassLoader(ClassLoader parent) { private static final List<IClassTransformer> CLASS_TRANSFORMS = new ArrayList<>();
super(new URL[0], parent); private static final List<String> CLASS_BLACKLIST = new ArrayList<>();
private static final List<String> CLASS_LOADING_LIST = new ArrayList<>();
private static final List<Class<?>> CLASS_LOADING_LIST_OBJECT = new ArrayList<>();
static {
// 添加黑名单,避免加载核心类
CLASS_BLACKLIST.add("java.");
CLASS_BLACKLIST.add("javax.");
CLASS_BLACKLIST.add("sun.");
} }
public BoxClassLoader(URL[] sources) { public BoxClassLoader(ClassLoader parent) {
super(sources); super(new URL[]{}, parent);
} }
public BoxClassLoader(URL[] urls, ClassLoader parent) { public BoxClassLoader(URL[] urls, ClassLoader parent) {
@@ -23,18 +32,75 @@ public class BoxClassLoader extends URLClassLoader {
} }
@Override @Override
protected Class<?> findClass(String name) throws ClassNotFoundException { public void addURL(URL url) {
byte[] classBytes; super.addURL(url);
try {
classBytes = getClassBytes(name);
} catch (IOException e) {
throw new RuntimeException(e);
}
byte[] transformedClass = transformClass(name, null, classBytes);
return defineClass(name, transformedClass, 0, transformedClass.length);
} }
public byte[] getClassBytes(String className) throws IOException, ClassNotFoundException { //@Override
//protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
// synchronized (getClassLoadingLock(name)) {
// // 1. 检查类是否已加载
// Class<?> c = findLoadedClass(name);
// if (c != null) {
// return c;
// }
//
// // 2. 检查黑名单
// if (isBlacklisted(name)) {
// return super.loadClass(name, resolve);
// }
//
// // 3. 尝试自定义加载
// try {
// c = findClass(name);
// } catch (ClassNotFoundException e) {
// // 4. 自定义加载失败则委托给父加载器
// c = super.loadClass(name, resolve);
// }
//
// if (resolve) {
// resolveClass(c);
// }
// return c;
// }
//}
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
// 检查是否在黑名单中
if (isBlacklisted(name)) {
throw new ClassNotFoundException("Class is blacklisted: " + name);
}
// 检查是否正在加载
if (CLASS_LOADING_LIST.contains(name)) {
throw new ClassNotFoundException("Class is already loading: " + name);
}
// 读取类字节码
byte[] clazzByte;
try {
clazzByte = getClassBytes(name);
} catch (IOException e) {
throw new ClassNotFoundException("Failed to load class bytes: " + name, e);
}
// 应用类转换器
for (IClassTransformer transformer : CLASS_TRANSFORMS) {
byte[] transformed = transformer.transform(name, transformer.getClass().getName(), clazzByte);
if (transformed != null) {
clazzByte = transformed;
}
}
// 定义类
CLASS_LOADING_LIST.add(name);
Class<?> clazz = defineClass(name, clazzByte, 0, clazzByte.length);
CLASS_LOADING_LIST_OBJECT.add(clazz);
return clazz;
}
private byte[] getClassBytes(String className) throws IOException, ClassNotFoundException {
String path = className.replace('.', '/') + ".class"; String path = className.replace('.', '/') + ".class";
URL classUrl = getResource(path); URL classUrl = getResource(path);
if (classUrl == null) { if (classUrl == null) {
@@ -45,7 +111,15 @@ public class BoxClassLoader extends URLClassLoader {
} }
} }
public byte[] transformClass(String className, String transformedName, byte[] basicClass) { private boolean isBlacklisted(String className) {
return PluginLoader.transformClass(className, transformedName, basicClass); return CLASS_BLACKLIST.stream().anyMatch(className::startsWith);
}
public static void addClassTransformer(IClassTransformer transformer) {
CLASS_TRANSFORMS.add(transformer);
}
public static List<Class<?>> getClassList() {
return CLASS_LOADING_LIST_OBJECT;
} }
} }

View File

@@ -6,6 +6,8 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import java.io.*; import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.*; import java.net.*;
import java.util.*; import java.util.*;
import java.util.jar.*; import java.util.jar.*;
@@ -15,20 +17,26 @@ import java.util.jar.*;
* @author tzdwindows 7 * @author tzdwindows 7
*/ */
public class PluginLoader { public class PluginLoader {
private static final Logger logger = LogManager.getLogger(PluginLoader.class); private static Logger logger;
public static final String PLUGIN_PATH = FolderCreator.getPluginFolder(); public static final String PLUGIN_PATH = FolderCreator.getPluginFolder();
private static final List<PluginDescriptor> loadedPlugins = new ArrayList<>(); private static final List<PluginDescriptor> loadedPlugins = new ArrayList<>();
private static final List<IClassTransformer> transformers = new ArrayList<>(); private static final List<IClassTransformer> transformers = new ArrayList<>();
public static void loadPlugins() throws IOException { public static void loadPlugins() throws IOException {
File pluginDir = new File(PLUGIN_PATH); logger = LogManager.getLogger(PluginLoader.class);
File[] jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar")); File pluginDir = null;
if (PLUGIN_PATH != null) {
pluginDir = new File(PLUGIN_PATH);
}
File[] jars = null;
if (pluginDir != null) {
jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar"));
}
if (jars == null) { if (jars == null) {
return; return;
} }
for (int i = 0; i < jars.length; i++) { for (int i = 0; i < jars.length; i++) {
processJarFile(jars[i]); processJarFile(jars[i],false);
Main.getMain().progressBarManager.updateSubProgress( Main.getMain().progressBarManager.updateSubProgress(
"Loading Plugin " + i, "Loading Plugin " + i,
i, i,
@@ -36,24 +44,47 @@ public class PluginLoader {
} }
} }
private static void processJarFile(File jarFile) throws IOException { private static void processJarFile(File jarFile, boolean isCorePlugin) throws IOException {
try (JarFile jar = new JarFile(jarFile)) { try (JarFile jar = new JarFile(jarFile)) {
// Check for CorePlugin in MANIFEST.MF
if (isCorePlugin) {
Attributes attributes = jar.getManifest().getMainAttributes();
String corePluginClass = attributes.getValue("CorePlugin");
if (corePluginClass != null) {
processCorePlugin(jarFile, corePluginClass);
}
} else {
JarEntry pluginFile = jar.getJarEntry("plug-in.box"); JarEntry pluginFile = jar.getJarEntry("plug-in.box");
if (pluginFile != null) { if (pluginFile != null) {
processWithManifest(jar, pluginFile, jarFile); processWithManifest(jar, pluginFile, jarFile);
} else { } else {
processWithAnnotations(jar, jarFile); processWithAnnotations(jar, jarFile);
} }
// Check for CorePlugin in MANIFEST.MF
Attributes attributes = jar.getManifest().getMainAttributes();
String corePluginClass = attributes.getValue("CorePlugin");
if (corePluginClass != null) {
processCorePlugin(jarFile, corePluginClass);
} }
} }
} }
/**
* 加载核心插件
* @throws IOException 插件加载失败
*/
public static void loadCorePlugin() throws IOException {
File pluginDir = null;
if (PLUGIN_PATH != null) {
pluginDir = new File(PLUGIN_PATH);
}
File[] jars = null;
if (pluginDir != null) {
jars = pluginDir.listFiles((dir, name) -> name.toLowerCase().endsWith(".jar"));
}
if (jars == null) {
return;
}
for (File jar : jars) {
processJarFile(jar, true);
}
}
private static void processCorePlugin(File jarFile, String corePluginClass) { private static void processCorePlugin(File jarFile, String corePluginClass) {
try (URLClassLoader classLoader = new URLClassLoader( try (URLClassLoader classLoader = new URLClassLoader(
new URL[]{jarFile.toURI().toURL()}, new URL[]{jarFile.toURI().toURL()},
@@ -62,22 +93,23 @@ public class PluginLoader {
Class<?> coreClass = classLoader.loadClass(corePluginClass); Class<?> coreClass = classLoader.loadClass(corePluginClass);
if (LoadingCorePlugin.class.isAssignableFrom(coreClass)) { if (LoadingCorePlugin.class.isAssignableFrom(coreClass)) {
LoadingCorePlugin corePlugin = (LoadingCorePlugin) coreClass.getDeclaredConstructor().newInstance(); LoadingCorePlugin corePlugin = (LoadingCorePlugin) coreClass.getDeclaredConstructor().newInstance();
registerTransformers(corePlugin); registerTransformers(corePlugin, classLoader);
} }
} catch (Exception e) { } catch (Exception e) {
logger.error("Failed to load core plugin: {}", corePluginClass, e); logger.error("Failed to load core plugin: {}", corePluginClass, e);
} }
} }
private static void registerTransformers(LoadingCorePlugin corePlugin) { private static void registerTransformers(LoadingCorePlugin corePlugin, ClassLoader classLoader) {
String[] transformerClasses = corePlugin.getASMTransformerClass(); String[] transformerClasses = corePlugin.getASMTransformerClass();
if (transformerClasses != null) { if (transformerClasses != null) {
for (String transformerClass : transformerClasses) { for (String transformerClass : transformerClasses) {
try { try {
Class<?> transformerClazz = Class.forName(transformerClass); Class<?> transformerClazz = classLoader.loadClass(transformerClass);
if (IClassTransformer.class.isAssignableFrom(transformerClazz)) { if (IClassTransformer.class.isAssignableFrom(transformerClazz)) {
IClassTransformer transformer = (IClassTransformer) transformerClazz.getDeclaredConstructor().newInstance(); IClassTransformer transformer = (IClassTransformer) transformerClazz.getDeclaredConstructor().newInstance();
transformers.add(transformer); transformers.add(transformer);
BoxClassLoader.addClassTransformer(transformer);
} }
} catch (Exception e) { } catch (Exception e) {
logger.error("Failed to register transformer: {}", transformerClass, e); logger.error("Failed to register transformer: {}", transformerClass, e);
@@ -113,14 +145,15 @@ public class PluginLoader {
private static void processWithAnnotations(JarFile jar, File jarFile) { private static void processWithAnnotations(JarFile jar, File jarFile) {
URLClassLoader classLoader = createClassLoader(jarFile); URLClassLoader classLoader = createClassLoader(jarFile);
Enumeration<JarEntry> entries = jar.entries(); Enumeration<JarEntry> entries = jar.entries();
while (entries.hasMoreElements()) { while (entries.hasMoreElements()) {
JarEntry entry = entries.nextElement(); JarEntry entry = entries.nextElement();
if (entry.getName().endsWith(".class")) { if (entry.getName().endsWith(".class")) {
if (classLoader != null) {
processClassEntry(entry, classLoader); processClassEntry(entry, classLoader);
} }
} }
} }
}
private static URLClassLoader createClassLoader(File jarFile) { private static URLClassLoader createClassLoader(File jarFile) {
try { try {
@@ -183,11 +216,4 @@ public class PluginLoader {
public static List<PluginDescriptor> getLoadedPlugins() { public static List<PluginDescriptor> getLoadedPlugins() {
return Collections.unmodifiableList(loadedPlugins); return Collections.unmodifiableList(loadedPlugins);
} }
public static byte[] transformClass(String name, String transformedName, byte[] basicClass) {
for (IClassTransformer transformer : transformers) {
basicClass = transformer.transform(name, transformedName, basicClass);
}
return basicClass;
}
} }

View File

@@ -25,7 +25,7 @@ public class RegistrationTool {
* 注册ToolCategory * 注册ToolCategory
* @param toolCategory ToolCategory * @param toolCategory ToolCategory
*/ */
void addToolCategory(MainWindow.ToolCategory toolCategory){ public void addToolCategory(MainWindow.ToolCategory toolCategory){
if (!main.isWindow()) { if (!main.isWindow()) {
toolCategories.add(toolCategory); toolCategories.add(toolCategory);
} else { } else {

View File

@@ -10,7 +10,7 @@ import java.io.File;
* @author tzdwindows 7 * @author tzdwindows 7
*/ */
public class FolderCreator { public class FolderCreator {
private static final Logger logger = LogManager.getLogger(FolderCreator.class); private static Logger logger;
public static final String LIBRARY_NAME = "library"; public static final String LIBRARY_NAME = "library";
public static final String MODEL_PATH = "model"; public static final String MODEL_PATH = "model";
public static final String PLUGIN_PATH = "plug-in"; public static final String PLUGIN_PATH = "plug-in";
@@ -27,6 +27,9 @@ public class FolderCreator {
public static String getModelFolder() { public static String getModelFolder() {
String folder = createFolder(MODEL_PATH); String folder = createFolder(MODEL_PATH);
if (folder == null) { if (folder == null) {
if (logger == null) {
logger = LogManager.getLogger();
}
logger.error("Model folder creation failure, please use administrator privileges to execute this procedure"); logger.error("Model folder creation failure, please use administrator privileges to execute this procedure");
return null; return null;
} }
@@ -36,6 +39,9 @@ public class FolderCreator {
public static String getLibraryFolder() { public static String getLibraryFolder() {
String folder = createFolder(LIBRARY_NAME); String folder = createFolder(LIBRARY_NAME);
if (folder == null) { if (folder == null) {
if (logger == null) {
logger = LogManager.getLogger();
}
logger.error("Library folder creation failed, please use administrator privileges to execute this procedure"); logger.error("Library folder creation failed, please use administrator privileges to execute this procedure");
return null; return null;
} }
@@ -52,6 +58,9 @@ public class FolderCreator {
File folder = new File(jarDir, folderName); File folder = new File(jarDir, folderName);
if (!folder.exists()) { if (!folder.exists()) {
if (!folder.mkdir()) { if (!folder.mkdir()) {
if (logger == null) {
logger = LogManager.getLogger();
}
logger.error("Folder creation failure"); logger.error("Folder creation failure");
return null; return null;
} }

View File

@@ -10,17 +10,28 @@ import org.apache.logging.log4j.Logger;
* @author tzdwindows 7 * @author tzdwindows 7
*/ */
public class LM { public class LM {
public static boolean CUDA = false; public static boolean CUDA = true;
public final static String DEEP_SEEK = FolderCreator.getModelFolder() + "\\DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf"; public final static String DEEP_SEEK = FolderCreator.getModelFolder() + "\\DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf";
private static final Logger logger = LogManager.getLogger(LM.class); private static final Logger logger = LogManager.getLogger(LM.class);
static { static {
if (!CUDA) { loadLibrary(CUDA);
}
private static void loadLibrary(boolean cuda){
if (!cuda) {
logger.warn("The cpu will be used for inference"); logger.warn("The cpu will be used for inference");
try {
LibraryLoad.loadLibrary("cpu/ggml-base"); LibraryLoad.loadLibrary("cpu/ggml-base");
LibraryLoad.loadLibrary("cpu/ggml-cpu"); LibraryLoad.loadLibrary("cpu/ggml-cpu");
LibraryLoad.loadLibrary("cpu/ggml"); LibraryLoad.loadLibrary("cpu/ggml");
LibraryLoad.loadLibrary("cpu/llama"); LibraryLoad.loadLibrary("cpu/llama");
} catch (UnsatisfiedLinkError e) {
logger.error("Unable to load library: " + e.getMessage(), e);
logger.error("Missing dependency: " + e.getMessage());
} catch (Exception e) {
logger.error("Unable to load cpu, please try updating driver", e);
}
} else { } else {
try { try {
LibraryLoad.loadLibrary("cuda/ggml-base"); LibraryLoad.loadLibrary("cuda/ggml-base");
@@ -33,10 +44,7 @@ public class LM {
} catch (Exception e) { } catch (Exception e) {
CUDA = false; CUDA = false;
logger.warn("The cuda library could not be loaded, the cpu will be used for inference"); logger.warn("The cuda library could not be loaded, the cpu will be used for inference");
LibraryLoad.loadLibrary("cpu/ggml-base"); loadLibrary(false);
LibraryLoad.loadLibrary("cpu/ggml-cpu");
LibraryLoad.loadLibrary("cpu/ggml");
LibraryLoad.loadLibrary("cpu/llama");
} }
} }
LibraryLoad.loadLibrary("LM"); LibraryLoad.loadLibrary("LM");