refactor(ai):重构分割模型包装类继承结构- 将 Anime2ModelWrapper、Anime2VividModelWrapper 和 AnimeModelWrapper 改为继承自 VividModelWrapper 基类

- 移除重复的 ResultFiles 内部类和相关工具方法实现
- Anime2Segmenter 和 AnimeSegmenter 继承自抽象基类 Segmenter
- Anime2SegmentationResult与 AnimeSegmentationResult 继承 SegmentationResult
- 重命名 LabelPalette 为 BiSeNetLabelPalette 并调整其引用
- 更新模型路径配置以匹配新的文件命名约定
- 删除冗余的 getLabels() 和 getPalette() 方法定义
- 简化 segmentAndSave 方法中的类型转换逻辑- 移除已被继承方法替代的手动资源管理代码
- 调整 import 语句以反映包结构调整- 清理不再需要的独立主测试函数入口点- 修改字段访问权限以符合继承设计模式
- 替换具体的返回类型为更通用的 SegmentationResult 接口- 整合公共功能至基类减少子类间重复代码
- 统一分割后处理流程提高模块复用性
- 引入泛型支持增强 Wrapper 类型安全性
- 更新注释文档保持与最新架构同步
- 优化异常处理策略统一关闭资源方式
- 规范文件命名规则便于未来维护扩展
- 提取共通逻辑到父类降低耦合度
- 完善类型检查避免运行时 ClassCastException 风险
This commit is contained in:
tzdwindows 7
2025-10-31 09:25:18 +08:00
parent a725e7eb23
commit e06c59c8d1
20 changed files with 700 additions and 1150 deletions

View File

@@ -0,0 +1,221 @@
package com.chuangzhou.vivid2D.ai;
import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper;
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* 模型管理器 - 负责模型的注册、分类和检索
*/
public class ModelManagement {
private final Map<String, Class<?>> models = new ConcurrentHashMap<>();
private final Map<String, List<String>> modelsByCategory = new ConcurrentHashMap<>();
private final List<String> modelDisplayNames = new ArrayList<>();
private final Map<String, String> displayNameToRegistrationName = new ConcurrentHashMap<>();
private ModelManagement() {
initializeDefaultCategories();
registerDefaultModels();
}
/**
* 初始化默认分类
*/
private void initializeDefaultCategories() {
modelsByCategory.put("Image Segmentation", new ArrayList<>());
modelsByCategory.put("Image Processing", new ArrayList<>());
modelsByCategory.put("Image Generation", new ArrayList<>());
modelsByCategory.put("Image Inpainting", new ArrayList<>());
modelsByCategory.put("Image Completion", new ArrayList<>());
modelsByCategory.put("Face Analysis", new ArrayList<>());
}
/**
* 注册默认模型
*/
private void registerDefaultModels() {
registerModel("segmentation:anime_face", "Anime Face Segmentation",
AnimeModelWrapper.class, "Image Segmentation");
registerModel("segmentation:anime", "Anime Image Segmentation",
Anime2VividModelWrapper.class, "Image Segmentation");
registerModel("segmentation:face_parsing", "Face Parsing",
BiSeNetVividModelWrapper.class, "Image Segmentation");
}
/**
* 注册模型
* @param modelRegistrationName 注册名称,格式必须为 "category:model_name"
* @param modelDisplayName 模型显示名称
* @param modelClass 模型类
* @param category 模型类别
*/
public void registerModel(String modelRegistrationName, String modelDisplayName,
Class<?> modelClass, String category) {
if (!isValidRegistrationName(modelRegistrationName)) {
throw new IllegalArgumentException(
"Invalid registration name format. Expected 'category:model_name', got: " + modelRegistrationName);
}
if (models.containsKey(modelRegistrationName)) {
throw new IllegalArgumentException(
"Model registration name already exists: " + modelRegistrationName);
}
if (displayNameToRegistrationName.containsKey(modelDisplayName)) {
throw new IllegalArgumentException(
"Model display name already exists: " + modelDisplayName);
}
if (!modelsByCategory.containsKey(category)) {
modelsByCategory.put(category, new ArrayList<>());
}
models.put(modelRegistrationName, modelClass);
displayNameToRegistrationName.put(modelDisplayName, modelRegistrationName);
modelDisplayNames.add(modelDisplayName);
modelsByCategory.get(category).add(modelRegistrationName);
}
/**
* 验证注册名称格式
*/
private boolean isValidRegistrationName(String name) {
return name != null && name.matches("^[a-zA-Z0-9_]+:[a-zA-Z0-9_]+$");
}
/**
* 通过显示名称获取模型类
*/
public Class<?> getModel(String modelDisplayName) {
String registrationName = displayNameToRegistrationName.get(modelDisplayName);
return registrationName != null ? models.get(registrationName) : null;
}
/**
* 通过索引获取模型类
*/
public Class<?> getModel(int modelIndex) {
if (modelIndex >= 0 && modelIndex < modelDisplayNames.size()) {
String displayName = modelDisplayNames.get(modelIndex);
return getModel(displayName);
}
return null;
}
/**
* 通过注册名称获取模型类
*/
public Class<?> getModelByRegistrationName(String registrationName) {
return models.get(registrationName);
}
/**
* 通过类名获取模型类
*/
public Class<?> getModelByClassName(String className) {
for (Class<?> modelClass : models.values()) {
if (modelClass.getName().equals(className)) {
return modelClass;
}
}
return null;
}
/**
* 获取所有模型的显示名称
*/
public List<String> getAllModelDisplayNames() {
return Collections.unmodifiableList(modelDisplayNames);
}
/**
* 获取所有模型的注册名称
*/
public Set<String> getAllModelRegistrationNames() {
return Collections.unmodifiableSet(models.keySet());
}
/**
* 按类别获取模型注册名称
*/
public List<String> getModelsByCategory(String category) {
return Collections.unmodifiableList(
modelsByCategory.getOrDefault(category, new ArrayList<>())
);
}
/**
* 获取所有可用的类别
*/
public Set<String> getAllCategories() {
return Collections.unmodifiableSet(modelsByCategory.keySet());
}
/**
* 获取模型数量
*/
public int getModelCount() {
return modelDisplayNames.size();
}
/**
* 获取模型显示名称对应的注册名称
*/
public String getRegistrationName(String modelDisplayName) {
return displayNameToRegistrationName.get(modelDisplayName);
}
/**
* 获取模型注册名称对应的显示名称
*/
public String getDisplayName(String registrationName) {
for (Map.Entry<String, String> entry : displayNameToRegistrationName.entrySet()) {
if (entry.getValue().equals(registrationName)) {
return entry.getKey();
}
}
return null;
}
/**
* 检查模型是否存在
*/
public boolean containsModel(String modelDisplayName) {
return displayNameToRegistrationName.containsKey(modelDisplayName);
}
/**
* 检查注册名称是否存在
*/
public boolean containsRegistrationName(String registrationName) {
return models.containsKey(registrationName);
}
/**
* 移除模型
*/
public boolean removeModel(String modelDisplayName) {
String registrationName = displayNameToRegistrationName.get(modelDisplayName);
if (registrationName != null) {
// 从所有存储中移除
models.remove(registrationName);
displayNameToRegistrationName.remove(modelDisplayName);
modelDisplayNames.remove(modelDisplayName);
// 从类别中移除
for (List<String> categoryModels : modelsByCategory.values()) {
categoryModels.remove(registrationName);
}
return true;
}
return false;
}
private static final class InstanceHolder {
private static final ModelManagement instance = new ModelManagement();
}
public static ModelManagement getInstance() {
return InstanceHolder.instance;
}
}

View File

@@ -1,11 +1,8 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
package com.chuangzhou.vivid2D.ai;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
* 分割结果容器
*/
public class SegmentationResult {
// 分割掩码图每个像素的颜色为对应类别颜色
private final BufferedImage maskImage;

View File

@@ -0,0 +1,154 @@
package com.chuangzhou.vivid2D.ai;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
public abstract class Segmenter implements AutoCloseable {
// 内部类用于从Translator安全地传出数据
public static class SegmentationData {
public final int[] indices;
public final long[] shape;
public SegmentationData(int[] indices, long[] shape) {
this.indices = indices;
this.shape = shape;
}
}
private String engine = "PyTorch";
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
protected final List<String> labels;
protected final Map<String, Integer> palette;
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
this.labels = new ArrayList<>(labels);
this.palette = BiSeNetLabelPalette.defaultPalette();
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
return Segmenter.this.processInput(ctx, input);
}
@Override
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
return Segmenter.this.processOutput(ctx, list);
}
@Override
public Batchifier getBatchifier() {
return Segmenter.this.getBatchifier();
}
};
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
.setTypes(Image.class, Segmenter.SegmentationData.class)
.optModelPath(modelDir)
.optEngine(engine)
.optTranslator(translator)
.build();
this.modelWrapper = criteria.loadModel();
this.predictor = modelWrapper.newPredictor();
}
/**
* 处理模型输入
* @param ctx translator 上下文
* @param input 图片
* @return 模型输入
*/
public abstract NDList processInput(TranslatorContext ctx, Image input);
/**
* 处理模型输出
* @param ctx translator 上下文
* @param list 模型输出
* @return 模型输出
*/
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
/**
* 获取批量处理方式
* @return 批量处理方式
*/
public Batchifier getBatchifier(){
return null;
}
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
// predict 方法现在直接返回安全的 Java 对象
Segmenter.SegmentationData data = predictor.predict(img);
long[] shp = data.shape;
int[] indices = data.indices;
int height, width;
if (shp.length == 2) {
height = (int) shp[0];
width = (int) shp[1];
} else {
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
}
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
Map<Integer, String> labelsMap = new HashMap<>();
for (int i = 0; i < labels.size(); i++) {
labelsMap.put(i, labels.get(i));
}
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = indices[y * width + x];
String label = labelsMap.getOrDefault(idx, "unknown");
int argb = palette.getOrDefault(label, 0xFF00FF00);
mask.setRGB(x, y, argb);
}
}
return new SegmentationResult(mask, labelsMap, palette);
}
public void setEngine(String engine) {
this.engine = engine;
}
@Override
public void close() {
try {
predictor.close();
} catch (Exception ignore) {
}
try {
modelWrapper.close();
} catch (Exception ignore) {
}
}
}

View File

@@ -0,0 +1,113 @@
package com.chuangzhou.vivid2D.ai;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
public abstract class VividModelWrapper<s extends Segmenter> implements AutoCloseable{
protected final s segmenter;
protected final List<String> labels; // index -> name
protected final Map<String, Integer> palette; // name -> ARGB
protected VividModelWrapper(s segmenter, List<String> labels, Map<String, Integer> palette) {
this.segmenter = segmenter;
this.labels = labels;
this.palette = palette;
}
public List<String> getLabels() {
return Collections.unmodifiableList(labels);
}
public Map<String, Integer> getPalette() {
return Collections.unmodifiableMap(palette);
}
/**
* 直接返回分割结果SegmentationResult
*/
public SegmentationResult segment(File inputImage) throws Exception {
return segmenter.segment(inputImage);
}
/**
* 把指定 targets标签名集合从输入图片中分割并保存到 outDir。
* 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。
* <p>
* 返回值Map<labelName, ResultFiles>ResultFiles 包含 maskFile、overlayFile两个 PNG
*/
public abstract Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception;
protected static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
protected static Set<String> parseTargetsSet(Set<String> in) {
if (in == null || in.isEmpty()) return Collections.emptySet();
// 若包含单个 "all"
if (in.size() == 1) {
String only = in.iterator().next();
if ("all".equalsIgnoreCase(only.trim())) {
return Set.of("all");
}
}
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
Set<String> out = new LinkedHashSet<>();
for (String s : in) {
if (s != null) out.add(s.trim());
}
return out;
}
/**
* 关闭底层资源
*/
@Override
public void close() {
try {
segmenter.close();
} catch (Exception ignore) {}
}
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
protected static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
/**
* 存放结果文件路径
*/
public static class ResultFiles {
private final File maskFile;
private final File overlayFile;
public ResultFiles(File maskFile, File overlayFile) {
this.maskFile = maskFile;
this.overlayFile = overlayFile;
}
public File getMaskFile() {
return maskFile;
}
public File getOverlayFile() {
return overlayFile;
}
}
}

View File

@@ -1,5 +1,7 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
@@ -14,16 +16,10 @@ import java.util.List;
/**
* AnimeModelWrapper - 专门为 Anime-Face-Segmentation 模型封装的 Wrapper
*/
public class AnimeModelWrapper implements AutoCloseable {
private final AnimeSegmenter segmenter;
private final List<String> labels; // index -> name
private final Map<String, Integer> palette; // name -> ARGB
public class AnimeModelWrapper extends VividModelWrapper<AnimeSegmenter> {
private AnimeModelWrapper(AnimeSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
this.segmenter = segmenter;
this.labels = labels;
this.palette = palette;
super(segmenter, labels, palette);
}
/**
@@ -36,14 +32,6 @@ public class AnimeModelWrapper implements AutoCloseable {
return new AnimeModelWrapper(segmenter, labels, palette);
}
public List<String> getLabels() {
return Collections.unmodifiableList(labels);
}
public Map<String, Integer> getPalette() {
return Collections.unmodifiableMap(palette);
}
/**
* 直接返回分割结果(在丢给底层 segmenter 前会做通用预处理RGB 转换 + 等比 letterbox 缩放到模型输入尺寸)
*/
@@ -152,28 +140,6 @@ public class AnimeModelWrapper implements AutoCloseable {
return saved;
}
private static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
private static Set<String> parseTargetsSet(Set<String> in) {
if (in == null || in.isEmpty()) return Collections.emptySet();
// 若包含单个 "all"
if (in.size() == 1) {
String only = in.iterator().next();
if ("all".equalsIgnoreCase(only.trim())) {
// 返回所有标签
return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
}
}
// 直接返回 trim 后的集合
Set<String> out = new LinkedHashSet<>();
for (String s : in) {
if (s != null) out.add(s.trim());
}
return out;
}
/**
* 专门提取眼睛的方法(在丢给底层 segmenter 前做预处理)
*/
@@ -246,44 +212,6 @@ public class AnimeModelWrapper implements AutoCloseable {
} catch (Exception ignore) {}
}
/**
* 存放结果文件路径
*/
public static class ResultFiles {
private final File maskFile;
private final File overlayFile;
public ResultFiles(File maskFile, File overlayFile) {
this.maskFile = maskFile;
this.overlayFile = overlayFile;
}
public File getMaskFile() {
return maskFile;
}
public File getOverlayFile() {
return overlayFile;
}
}
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
// ========== 新增:预处理并保存到临时文件 ==========
private File preprocessAndSave(File inputImage) throws IOException {
BufferedImage img = ImageIO.read(inputImage);
@@ -391,36 +319,4 @@ public class AnimeModelWrapper implements AutoCloseable {
return new int[]{w, h};
}
/* ================= convenience 主方法(快速测试) ================= */
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.out.println("用法: AnimeModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face");
System.out.println("标签: " + AnimeLabelPalette.defaultLabels());
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
Path out = Path.of(args[2]);
String targetsArg = args[3];
Set<String> targets;
if ("all".equalsIgnoreCase(targetsArg.trim())) {
targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
} else {
String[] parts = targetsArg.split(",");
targets = new LinkedHashSet<>();
for (String p : parts) {
if (!p.trim().isEmpty()) targets.add(p.trim());
}
}
try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) {
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
m.forEach((k, v) -> {
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
});
}
}
}

View File

@@ -1,12 +1,14 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
* 动漫分割结果容器
*/
public class AnimeSegmentationResult {
public class AnimeSegmentationResult extends SegmentationResult {
// 分割掩码图(每个像素的颜色为对应类别颜色)
private final BufferedImage maskImage;
@@ -21,6 +23,7 @@ public class AnimeSegmentationResult {
public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap,
Map<Integer, String> labels, Map<String, Integer> palette) {
super(maskImage, labels, palette);
this.maskImage = maskImage;
this.probabilityMap = probabilityMap;
this.labels = labels;

View File

@@ -16,6 +16,7 @@ import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.Segmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
@@ -27,7 +28,7 @@ import java.util.*;
/**
* AnimeSegmenter: 专门为 Anime-Face-Segmentation UNet 模型设计的分割器
*/
public class AnimeSegmenter implements AutoCloseable {
public class AnimeSegmenter extends Segmenter {
// 模型默认输入大小(与训练时一致)。若模型不同可以修改为实际值或让 caller 通过构造参数传入。
private static final int MODEL_INPUT_W = 512;
@@ -48,11 +49,10 @@ public class AnimeSegmenter implements AutoCloseable {
private final ZooModel<Image, SegmentationData> modelWrapper;
private final Predictor<Image, SegmentationData> predictor;
private final List<String> labels;
private final Map<String, Integer> palette;
public AnimeSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
this.labels = new ArrayList<>(labels);
super(modelDir, labels);
this.palette = AnimeLabelPalette.defaultPalette();
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
@@ -137,6 +137,16 @@ public class AnimeSegmenter implements AutoCloseable {
this.predictor = modelWrapper.newPredictor();
}
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
return null;
}
@Override
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
return null;
}
public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException {
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
@@ -201,30 +211,4 @@ public class AnimeSegmenter implements AutoCloseable {
} catch (Exception ignore) {
}
}
// 测试主函数
public static void main(String[] args) throws Exception {
if (args.length < 3) {
System.out.println("用法: java AnimeSegmenter <modelDir> <inputImage> <outputMaskPng>");
System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
File out = new File(args[2]);
List<String> labels = AnimeLabelPalette.defaultLabels();
try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) {
AnimeSegmentationResult res = segmenter.segment(input);
ImageIO.write(res.getMaskImage(), "png", out);
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
// 额外保存眼睛分割结果
BufferedImage eyes = segmenter.extractEyes(input);
File eyesOut = new File(out.getParent(), "eyes_" + out.getName());
ImageIO.write(eyes, "png", eyesOut);
System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath());
}
}
}

View File

@@ -1,244 +0,0 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
* Anime2ModelWrapper - 对动漫分割模型的封装
*
* 用法示例:
* Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir"));
* Map<String, Anime2ModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
* new File("input.jpg"),
* Set.of("foreground"), // 动漫分割主要关注前景
* Paths.get("outDir")
* );
* wrapper.close();
*/
public class Anime2ModelWrapper implements AutoCloseable {
private final Anime2Segmenter segmenter;
private final List<String> labels; // index -> name
private final Map<String, Integer> palette; // name -> ARGB
private Anime2ModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
this.segmenter = segmenter;
this.labels = labels;
this.palette = palette;
}
/**
* 创建 Anime2Segmenter 实例
*/
public static Anime2ModelWrapper load(Path modelDir) throws Exception {
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
return new Anime2ModelWrapper(s, labels, palette);
}
public List<String> getLabels() {
return Collections.unmodifiableList(labels);
}
public Map<String, Integer> getPalette() {
return Collections.unmodifiableMap(palette);
}
/**
* 直接返回分割结果
*/
public Anime2SegmentationResult segment(File inputImage) throws Exception {
return segmenter.segment(inputImage);
}
/**
* 把指定 targets标签名集合从输入图片中分割并保存到 outDir
*/
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
if (!Files.exists(outDir)) {
Files.createDirectories(outDir);
}
Anime2SegmentationResult res = segment(inputImage);
BufferedImage original = ImageIO.read(inputImage);
BufferedImage maskImage = res.getMaskImage();
int maskW = maskImage.getWidth();
int maskH = maskImage.getHeight();
// 解析 targets
Set<String> realTargets = parseTargetsSet(targets);
Map<String, ResultFiles> saved = new LinkedHashMap<>();
for (String target : realTargets) {
if (!palette.containsKey(target)) {
System.err.println("Warning: unknown label '" + target + "' - skip.");
continue;
}
int targetColor = palette.get(target);
// 1) 生成透明背景的二值掩码(只保留 target 像素)
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
for (int y = 0; y < maskH; y++) {
for (int x = 0; x < maskW; x++) {
int c = maskImage.getRGB(x, y);
if (c == targetColor) {
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
} else {
partMask.setRGB(x, y, 0x00000000);
}
}
}
// 2) 将 mask 缩放到与原图一致
BufferedImage maskResized = partMask;
if (original.getWidth() != maskW || original.getHeight() != maskH) {
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
Graphics2D g = maskResized.createGraphics();
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
g.dispose();
}
// 3) 生成叠加图
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
Graphics2D g2 = overlay.createGraphics();
g2.drawImage(original, 0, 0, null);
int rgbOnly = (targetColor & 0x00FFFFFF);
int translucent = (0x88 << 24) | rgbOnly;
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
for (int y = 0; y < colorOverlay.getHeight(); y++) {
for (int x = 0; x < colorOverlay.getWidth(); x++) {
int mc = maskResized.getRGB(x, y);
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
colorOverlay.setRGB(x, y, translucent);
} else {
colorOverlay.setRGB(x, y, 0x00000000);
}
}
}
g2.drawImage(colorOverlay, 0, 0, null);
g2.dispose();
// 保存
String safe = safeFileName(target);
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
ImageIO.write(maskResized, "png", maskOut);
ImageIO.write(overlay, "png", overlayOut);
saved.put(target, new ResultFiles(maskOut, overlayOut));
}
return saved;
}
private static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
private static Set<String> parseTargetsSet(Set<String> in) {
if (in == null || in.isEmpty()) return Collections.emptySet();
if (in.size() == 1) {
String only = in.iterator().next();
if ("all".equalsIgnoreCase(only.trim())) {
return Set.of("foreground"); // 动漫分割主要关注前景
}
}
Set<String> out = new LinkedHashSet<>();
for (String s : in) {
if (s != null) out.add(s.trim());
}
return out;
}
/**
* 关闭底层资源
*/
@Override
public void close() {
try {
segmenter.close();
} catch (Exception ignore) {}
}
/**
* 存放结果文件路径
*/
public static class ResultFiles {
private final File maskFile;
private final File overlayFile;
public ResultFiles(File maskFile, File overlayFile) {
this.maskFile = maskFile;
this.overlayFile = overlayFile;
}
public File getMaskFile() {
return maskFile;
}
public File getOverlayFile() {
return overlayFile;
}
}
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
/* ================= convenience 主方法(快速测试) ================= */
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.out.println("用法: Anime2ModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
Path out = Path.of(args[2]);
String targetsArg = args[3];
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
Set<String> targets;
if ("all".equalsIgnoreCase(targetsArg.trim())) {
targets = new LinkedHashSet<>(labels);
} else {
String[] parts = targetsArg.split(",");
targets = new LinkedHashSet<>();
for (String p : parts) {
if (!p.trim().isEmpty()) targets.add(p.trim());
}
}
try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) {
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
m.forEach((k, v) -> {
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
});
}
}
}

View File

@@ -1,36 +1,15 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
* 动漫分割结果容器
*/
public class Anime2SegmentationResult {
// 分割掩码图(每个像素的颜色为对应类别颜色)
private final BufferedImage maskImage;
// 类别索引 -> 类别名称
private final Map<Integer, String> labels;
// 类别名称 -> ARGB 颜色
private final Map<String, Integer> palette;
public class Anime2SegmentationResult extends SegmentationResult {
public Anime2SegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
this.maskImage = maskImage;
this.labels = labels;
this.palette = palette;
}
public BufferedImage getMaskImage() {
return maskImage;
}
public Map<Integer, String> getLabels() {
return labels;
}
public Map<String, Integer> getPalette() {
return palette;
super(maskImage, labels, palette);
}
}

View File

@@ -1,21 +1,17 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import com.chuangzhou.vivid2D.ai.Segmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
@@ -28,94 +24,51 @@ import java.util.*;
* Anime2Segmenter: 专门用于动漫分割模型
* 处理 anime-segmentation 模型的二值分割输出
*/
public class Anime2Segmenter implements AutoCloseable {
public static class SegmentationData {
final int[] indices;
final long[] shape;
public SegmentationData(int[] indices, long[] shape) {
this.indices = indices;
this.shape = shape;
}
}
private final ZooModel<Image, SegmentationData> modelWrapper;
private final Predictor<Image, SegmentationData> predictor;
private final List<String> labels;
private final Map<String, Integer> palette;
public class Anime2Segmenter extends Segmenter {
public Anime2Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
this.labels = new ArrayList<>(labels);
this.palette = Anime2LabelPalette.animeSegmentationPalette();
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
// 调整输入图像尺寸到模型期望的大小 (1024x1024)
Image resized = input.resize(1024, 1024, true);
NDArray array = resized.toNDArray(manager);
// 转换为 CHW 格式并归一化
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
array = array.div(255f);
array = array.expandDims(0); // 添加batch维度
return new NDList(array);
}
@Override
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
if (list == null || list.isEmpty()) {
throw new IllegalStateException("Model did not return any output.");
}
NDArray out = list.get(0);
// 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图
// 应用sigmoid并二值化
NDArray probabilities = out.div(out.neg().exp().add(1));
NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false);
// 移除batch和channel维度
if (binaryMask.getShape().dimension() == 4) {
binaryMask = binaryMask.squeeze(0).squeeze(0);
}
// 转换为Java数组
long[] finalShape = binaryMask.getShape().getShape();
int[] indices = binaryMask.toIntArray();
return new SegmentationData(indices, finalShape);
}
@Override
public Batchifier getBatchifier() {
return null;
}
};
Criteria<Image, SegmentationData> criteria = Criteria.builder()
.setTypes(Image.class, SegmentationData.class)
.optModelPath(modelDir)
.optEngine("PyTorch")
.optTranslator(translator)
.build();
this.modelWrapper = criteria.loadModel();
this.predictor = modelWrapper.newPredictor();
super(modelDir, labels);
}
public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
// 调整输入图像尺寸到模型期望的大小 (1024x1024)
Image resized = input.resize(1024, 1024, true);
NDArray array = resized.toNDArray(manager);
// 转换为 CHW 格式并归一化
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
array = array.div(255f);
array = array.expandDims(0); // 添加batch维度
return new NDList(array);
}
@Override
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
if (list == null || list.isEmpty()) {
throw new IllegalStateException("Model did not return any output.");
}
NDArray out = list.get(0);
// 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图
// 应用sigmoid并二值化
NDArray probabilities = out.div(out.neg().exp().add(1));
NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false);
if (binaryMask.getShape().dimension() == 4) {
binaryMask = binaryMask.squeeze(0).squeeze(0);
}
long[] finalShape = binaryMask.getShape().getShape();
int[] indices = binaryMask.toIntArray();
return new SegmentationData(indices, finalShape);
}
@Override
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
SegmentationData data = predictor.predict(img);
Segmenter.SegmentationData data = predictor.predict(img);
long[] shp = data.shape;
int[] indices = data.indices;
int height, width;
if (shp.length == 2) {
height = (int) shp[0];
@@ -123,14 +76,11 @@ public class Anime2Segmenter implements AutoCloseable {
} else {
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
}
// 创建分割掩码
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
Map<Integer, String> labelsMap = new HashMap<>();
for (int i = 0; i < labels.size(); i++) {
labelsMap.put(i, labels.get(i));
}
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = indices[y * width + x];
@@ -139,8 +89,7 @@ public class Anime2Segmenter implements AutoCloseable {
mask.setRGB(x, y, argb);
}
}
return new Anime2SegmentationResult(mask, labelsMap, palette);
return new SegmentationResult(mask, labelsMap, palette);
}
@Override
@@ -154,22 +103,4 @@ public class Anime2Segmenter implements AutoCloseable {
} catch (Exception ignore) {
}
}
public static void main(String[] args) throws Exception {
if (args.length < 3) {
System.out.println("用法: java Anime2Segmenter <modelDir> <inputImage> <outputMaskPng>");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
File out = new File(args[2]);
List<String> labels = Anime2LabelPalette.defaultLabels();
try (Anime2Segmenter s = new Anime2Segmenter(modelDir, labels)) {
Anime2SegmentationResult res = s.segment(input);
ImageIO.write(res.getMaskImage(), "png", out);
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
}
}
}

View File

@@ -1,5 +1,8 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
@@ -22,16 +25,11 @@ import java.util.List;
* // out contains 每个目标标签对应的 mask+overlay 文件路径
* wrapper.close();
*/
public class Anime2VividModelWrapper implements AutoCloseable {
private final Anime2Segmenter segmenter;
private final List<String> labels; // index -> name
private final Map<String, Integer> palette; // name -> ARGB
public class Anime2VividModelWrapper extends VividModelWrapper<Anime2Segmenter> {
private Anime2VividModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
this.segmenter = segmenter;
this.labels = labels;
this.palette = palette;
super(segmenter, labels, palette);
}
/**
@@ -56,7 +54,7 @@ public class Anime2VividModelWrapper implements AutoCloseable {
/**
* 直接返回分割结果Anime2SegmentationResult
*/
public Anime2SegmentationResult segment(File inputImage) throws Exception {
public SegmentationResult segment(File inputImage) throws Exception {
return segmenter.segment(inputImage);
}
@@ -66,12 +64,12 @@ public class Anime2VividModelWrapper implements AutoCloseable {
* <p>
* 返回值Map<labelName, ResultFiles>ResultFiles 包含 maskFile、overlayFile两个 PNG
*/
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
public Map<String, VividModelWrapper.ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
if (!Files.exists(outDir)) {
Files.createDirectories(outDir);
}
Anime2SegmentationResult res = segment(inputImage);
SegmentationResult res = segment(inputImage);
BufferedImage original = ImageIO.read(inputImage);
BufferedImage maskImage = res.getMaskImage();
@@ -84,7 +82,6 @@ public class Anime2VividModelWrapper implements AutoCloseable {
for (String target : realTargets) {
if (!palette.containsKey(target)) {
// 尝试忽略大小写匹配
String finalTarget = target;
Optional<String> matched = palette.keySet().stream()
.filter(k -> k.equalsIgnoreCase(finalTarget))
@@ -154,109 +151,4 @@ public class Anime2VividModelWrapper implements AutoCloseable {
return saved;
}
private static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
private static Set<String> parseTargetsSet(Set<String> in) {
if (in == null || in.isEmpty()) return Collections.emptySet();
// 若包含单个 "all"
if (in.size() == 1) {
String only = in.iterator().next();
if ("all".equalsIgnoreCase(only.trim())) {
// 由调用方自行取 labels这里返回 sentinel, but caller already checks palette
// For convenience, return a set containing "all" and let caller logic handle it earlier.
return Set.of("all");
}
}
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
Set<String> out = new LinkedHashSet<>();
for (String s : in) {
if (s != null) out.add(s.trim());
}
return out;
}
/**
* 关闭底层资源
*/
@Override
public void close() {
try {
segmenter.close();
} catch (Exception ignore) {}
}
/**
* 存放结果文件路径
*/
public static class ResultFiles {
private final File maskFile;
private final File overlayFile;
public ResultFiles(File maskFile, File overlayFile) {
this.maskFile = maskFile;
this.overlayFile = overlayFile;
}
public File getMaskFile() {
return maskFile;
}
public File getOverlayFile() {
return overlayFile;
}
}
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
/* ================= convenience 主方法(快速测试) ================= */
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.out.println("用法: Anime2VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
Path out = Path.of(args[2]);
String targetsArg = args[3];
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
Set<String> targets;
if ("all".equalsIgnoreCase(targetsArg.trim())) {
targets = new LinkedHashSet<>(labels);
} else {
String[] parts = targetsArg.split(",");
targets = new LinkedHashSet<>();
for (String p : parts) {
if (!p.trim().isEmpty()) targets.add(p.trim());
}
}
try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) {
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
m.forEach((k, v) -> {
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
});
}
}
}

View File

@@ -7,7 +7,7 @@ import java.util.*;
* 颜色值基于 zllrunning/face-parsing.PyTorch 仓库的 test.py 文件
* 标签索引必须与模型输出索引一致0-18
*/
public class LabelPalette {
public class BiSeNetLabelPalette {
/**
* BiSeNet 人脸解析模型的标准标签19个类别索引 0-18

View File

@@ -0,0 +1,15 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
* 分割结果容器
*/
public class BiSeNetSegmentationResult extends SegmentationResult {
public BiSeNetSegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
super(maskImage, labels, palette);
}
}

View File

@@ -0,0 +1,101 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import com.chuangzhou.vivid2D.ai.Segmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
* Segmenter: 加载模型并对图片做语义分割
*
* 说明:
* - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray
* 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。
* - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()
* 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。
*/
public class BiSeNetSegmenter extends Segmenter {
public BiSeNetSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
super(modelDir, labels);
}
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
NDArray array = input.toNDArray(manager);
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
array = array.div(255f);
array = array.expandDims(0);
return new NDList(array);
}
@Override
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
if (list == null || list.isEmpty()) {
throw new IllegalStateException("Model did not return any output.");
}
NDArray out = list.get(0);
NDArray classMap;
// 1. 解析模型输出,得到类别图谱 (classMap)
long[] shape = out.getShape().getShape();
if (shape.length == 4 && shape[1] > 1) {
classMap = out.argMax(1);
} else if (shape.length == 3) {
classMap = (shape[0] == 1) ? out : out.argMax(0);
} else if (shape.length == 2) {
classMap = out;
} else {
throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape));
}
if (classMap.getShape().dimension() == 3) {
classMap = classMap.squeeze(0);
}
// 2. *** 关键步骤 ***
// 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型
// 首先,确保数据类型是 INT32
NDArray int32ClassMap = classMap.toType(DataType.INT32, false);
// 然后,获取形状和 int[] 数组
long[] finalShape = int32ClassMap.getShape().getShape();
int[] indices = int32ClassMap.toIntArray();
// 3. 将 Java 对象封装并返回
return new SegmentationData(indices, finalShape);
}
@Override
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
return super.segment(imgFile);
}
@Override
public void close() {
super.close();
}
}

View File

@@ -1,5 +1,8 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import com.chuangzhou.vivid2D.ai.SegmentationResult;
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
@@ -22,35 +25,29 @@ import java.util.List;
* // out contains 每个目标标签对应的 mask+overlay 文件路径
* wrapper.close();
*/
public class VividModelWrapper implements AutoCloseable {
public class BiSeNetVividModelWrapper extends VividModelWrapper<BiSeNetSegmenter> {
private final Segmenter segmenter;
private final List<String> labels; // index -> name
private final Map<String, Integer> palette; // name -> ARGB
private VividModelWrapper(Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
this.segmenter = segmenter;
this.labels = labels;
this.palette = palette;
private BiSeNetVividModelWrapper(BiSeNetSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
super(segmenter, labels, palette);
}
/**
* 读取 modelDir/synset.txt每行一个标签若不存在则使用 LabelPalette.defaultLabels()
* 并创建 Segmenter 实例
*/
public static VividModelWrapper load(Path modelDir) throws Exception {
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
Segmenter s = new Segmenter(modelDir, labels);
Map<String, Integer> palette = LabelPalette.defaultPalette();
return new VividModelWrapper(s, labels, palette);
public static BiSeNetVividModelWrapper load(Path modelDir) throws Exception {
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(BiSeNetLabelPalette::defaultLabels);
BiSeNetSegmenter s = new BiSeNetSegmenter(modelDir, labels);
Map<String, Integer> palette = BiSeNetLabelPalette.defaultPalette();
return new BiSeNetVividModelWrapper(s, labels, palette);
}
public List<String> getLabels() {
return Collections.unmodifiableList(labels);
return super.getLabels();
}
public Map<String, Integer> getPalette() {
return Collections.unmodifiableMap(palette);
return super.getPalette();
}
/**
@@ -66,25 +63,19 @@ public class VividModelWrapper implements AutoCloseable {
* <p>
* 返回值Map<labelName, ResultFiles>ResultFiles 包含 maskFileoverlayFile两个 PNG
*/
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
public Map<String, VividModelWrapper.ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
if (!Files.exists(outDir)) {
Files.createDirectories(outDir);
}
SegmentationResult res = segment(inputImage);
BufferedImage original = ImageIO.read(inputImage);
BufferedImage maskImage = res.getMaskImage();
int maskW = maskImage.getWidth();
int maskH = maskImage.getHeight();
// 解析 targets
Set<String> realTargets = parseTargetsSet(targets);
Map<String, ResultFiles> saved = new LinkedHashMap<>();
for (String target : realTargets) {
if (!palette.containsKey(target)) {
// 尝试忽略大小写匹配
String finalTarget = target;
Optional<String> matched = palette.keySet().stream()
.filter(k -> k.equalsIgnoreCase(finalTarget))
@@ -95,10 +86,7 @@ public class VividModelWrapper implements AutoCloseable {
continue;
}
}
int targetColor = palette.get(target);
// 1) 生成透明背景的二值掩码只保留 target 像素
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
for (int y = 0; y < maskH; y++) {
for (int x = 0; x < maskW; x++) {
@@ -110,8 +98,6 @@ public class VividModelWrapper implements AutoCloseable {
}
}
}
// 2) mask 缩放到与原图一致如果需要并生成 overlay半透明
BufferedImage maskResized = partMask;
if (original.getWidth() != maskW || original.getHeight() != maskH) {
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
@@ -120,11 +106,9 @@ public class VividModelWrapper implements AutoCloseable {
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
g.dispose();
}
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
Graphics2D g2 = overlay.createGraphics();
g2.drawImage(original, 0, 0, null);
// 半透明颜色alpha = 0x88
int rgbOnly = (targetColor & 0x00FFFFFF);
int translucent = (0x88 << 24) | rgbOnly;
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
@@ -140,44 +124,17 @@ public class VividModelWrapper implements AutoCloseable {
}
g2.drawImage(colorOverlay, 0, 0, null);
g2.dispose();
// 保存
String safe = safeFileName(target);
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
ImageIO.write(maskResized, "png", maskOut);
ImageIO.write(overlay, "png", overlayOut);
saved.put(target, new ResultFiles(maskOut, overlayOut));
}
return saved;
}
private static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
private static Set<String> parseTargetsSet(Set<String> in) {
if (in == null || in.isEmpty()) return Collections.emptySet();
// 若包含单个 "all"
if (in.size() == 1) {
String only = in.iterator().next();
if ("all".equalsIgnoreCase(only.trim())) {
// 由调用方自行取 labels这里返回 sentinel, but caller already checks palette
// For convenience, return a set containing "all" and let caller logic handle it earlier.
return Set.of("all");
}
}
// 直接返回 trim 后的小写不变集合保持用户传入的名字
Set<String> out = new LinkedHashSet<>();
for (String s : in) {
if (s != null) out.add(s.trim());
}
return out;
}
/**
* 关闭底层资源
*/
@@ -187,75 +144,4 @@ public class VividModelWrapper implements AutoCloseable {
segmenter.close();
} catch (Exception ignore) {}
}
/**
* 存放结果文件路径
*/
public static class ResultFiles {
private final File maskFile;
private final File overlayFile;
public ResultFiles(File maskFile, File overlayFile) {
this.maskFile = maskFile;
this.overlayFile = overlayFile;
}
public File getMaskFile() {
return maskFile;
}
public File getOverlayFile() {
return overlayFile;
}
}
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
/* ================= convenience 主方法(快速测试) ================= */
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.out.println("用法: VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
Path out = Path.of(args[2]);
String targetsArg = args[3];
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
Set<String> targets;
if ("all".equalsIgnoreCase(targetsArg.trim())) {
targets = new LinkedHashSet<>(labels);
} else {
String[] parts = targetsArg.split(",");
targets = new LinkedHashSet<>();
for (String p : parts) {
if (!p.trim().isEmpty()) targets.add(p.trim());
}
}
try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) {
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
m.forEach((k, v) -> {
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
});
}
}
}

View File

@@ -1,193 +0,0 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
* Segmenter: 加载模型并对图片做语义分割
*
* 说明:
* - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray
* 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。
* - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()
* 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。
*/
public class Segmenter implements AutoCloseable {
// 内部类用于从Translator安全地传出数据
public static class SegmentationData {
final int[] indices;
final long[] shape;
public SegmentationData(int[] indices, long[] shape) {
this.indices = indices;
this.shape = shape;
}
}
private final ZooModel<Image, SegmentationData> modelWrapper;
private final Predictor<Image, SegmentationData> predictor;
private final List<String> labels;
private final Map<String, Integer> palette;
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
this.labels = new ArrayList<>(labels);
this.palette = LabelPalette.defaultPalette();
// Translator 的输出类型现在是 SegmentationData
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
NDManager manager = ctx.getNDManager();
NDArray array = input.toNDArray(manager);
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
array = array.div(255f);
array = array.expandDims(0);
return new NDList(array);
}
@Override
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
if (list == null || list.isEmpty()) {
throw new IllegalStateException("Model did not return any output.");
}
NDArray out = list.get(0);
NDArray classMap;
// 1. 解析模型输出,得到类别图谱 (classMap)
long[] shape = out.getShape().getShape();
if (shape.length == 4 && shape[1] > 1) {
classMap = out.argMax(1);
} else if (shape.length == 3) {
classMap = (shape[0] == 1) ? out : out.argMax(0);
} else if (shape.length == 2) {
classMap = out;
} else {
throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape));
}
if (classMap.getShape().dimension() == 3) {
classMap = classMap.squeeze(0);
}
// 2. *** 关键步骤 ***
// 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型
// 首先,确保数据类型是 INT32
NDArray int32ClassMap = classMap.toType(DataType.INT32, false);
// 然后,获取形状和 int[] 数组
long[] finalShape = int32ClassMap.getShape().getShape();
int[] indices = int32ClassMap.toIntArray();
// 3. 将 Java 对象封装并返回
return new SegmentationData(indices, finalShape);
}
@Override
public Batchifier getBatchifier() {
return null; // 或者根据需要使用 Batchifier.STACK
}
};
// Criteria 的类型也需要更新
Criteria<Image, SegmentationData> criteria = Criteria.builder()
.setTypes(Image.class, SegmentationData.class)
.optModelPath(modelDir)
.optEngine("PyTorch")
.optTranslator(translator)
.build();
this.modelWrapper = criteria.loadModel();
this.predictor = modelWrapper.newPredictor();
}
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
// predict 方法现在直接返回安全的 Java 对象
SegmentationData data = predictor.predict(img);
long[] shp = data.shape;
int[] indices = data.indices;
int height, width;
if (shp.length == 2) {
height = (int) shp[0];
width = (int) shp[1];
} else {
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
}
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
Map<Integer, String> labelsMap = new HashMap<>();
for (int i = 0; i < labels.size(); i++) {
labelsMap.put(i, labels.get(i));
}
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = indices[y * width + x];
String label = labelsMap.getOrDefault(idx, "unknown");
int argb = palette.getOrDefault(label, 0xFF00FF00);
mask.setRGB(x, y, argb);
}
}
return new SegmentationResult(mask, labelsMap, palette);
}
@Override
public void close() {
try {
predictor.close();
} catch (Exception ignore) {
}
try {
modelWrapper.close();
} catch (Exception ignore) {
}
}
// 小测试主函数(示例)
public static void main(String[] args) throws Exception {
if (args.length < 3) {
System.out.println("用法: java Segmenter <modelDir> <inputImage> <outputMaskPng>");
return;
}
Path modelDir = Path.of(args[0]);
File input = new File(args[1]);
File out = new File(args[2]);
List<String> labels = LabelPalette.defaultLabels();
try (Segmenter s = new Segmenter(modelDir, labels)) {
SegmentationResult res = s.segment(input);
ImageIO.write(res.getMaskImage(), "png", out);
System.out.println("分割掩码已保存到: " + out.getAbsolutePath());
}
}
}

View File

@@ -1,184 +0,0 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.*;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
* SegmenterExample
*
* 使用说明(命令行):
* java -cp <your-classpath> com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \
* <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>
*
* 示例:
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye,face"
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out all
*/
public class SegmenterExample {
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.err.println("用法: SegmenterExample <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>");
System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face");
return;
}
Path modelDir = Path.of(args[0]);
File inputImage = new File(args[1]);
Path outDir = Path.of(args[2]);
String targetsArg = args[3];
if (!Files.exists(modelDir)) {
System.err.println("modelDir 不存在: " + modelDir);
return;
}
if (!inputImage.exists()) {
System.err.println("输入图片不存在: " + inputImage.getAbsolutePath());
return;
}
if (!Files.exists(outDir)) {
Files.createDirectories(outDir);
}
// 读取 synset.txt如果有否则使用默认 LabelPalette
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
// 打开 Segmenter
try (Segmenter segmenter = new Segmenter(modelDir, labels)) {
SegmentationResult res = segmenter.segment(inputImage);
// 原始图片
BufferedImage original = ImageIO.read(inputImage);
// palette: labelName -> ARGB int
Map<String, Integer> palette = res.getPalette();
Map<Integer, String> labelsMap = res.getLabels(); // index -> name
// 解析目标 labels 列表
Set<String> targets = parseTargets(targetsArg, labels);
System.out.println("Will export targets: " + targets);
// maskImage: 每像素是类别颜色ARGB
BufferedImage maskImage = res.getMaskImage();
int w = maskImage.getWidth();
int h = maskImage.getHeight();
// 为快速查 color -> labelName
Map<Integer, String> colorToLabel = new HashMap<>();
for (Map.Entry<String, Integer> e : palette.entrySet()) {
colorToLabel.put(e.getValue(), e.getKey());
}
// 对每个 target 生成单独的 mask 和 overlay
for (String target : targets) {
if (!palette.containsKey(target)) {
System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。");
continue;
}
int targetColor = palette.get(target);
// 1) 生成透明背景的二值掩码(只保留 target 像素)
BufferedImage partMask = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
int c = maskImage.getRGB(x, y);
if (c == targetColor) {
// 保留为不透明(使用原始颜色)
partMask.setRGB(x, y, targetColor);
} else {
// 透明
partMask.setRGB(x, y, 0x00000000);
}
}
}
// 2) 生成 overlay在原图上叠加半透明的 targetColor 区域
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
// 若分辨率不同,先缩放 mask 到原图大小(简单粗暴地按尺寸相同假设)
BufferedImage maskResized = maskImage;
if (original.getWidth() != w || original.getHeight() != h) {
// 简单缩放 mask 到原图尺寸
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
Graphics2D g = maskResized.createGraphics();
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g.drawImage(maskImage, 0, 0, original.getWidth(), original.getHeight(), null);
g.dispose();
}
// 在 overlay 上先画原图
Graphics2D g2 = overlay.createGraphics();
g2.drawImage(original, 0, 0, null);
// 创建半透明颜色(将 targetColor 的 alpha 设为 0x88
int rgbOnly = (targetColor & 0x00FFFFFF);
int translucent = (0x88 << 24) | rgbOnly;
// 创建一个图像,把 mask 中对应像素设置为 translucent否则透明
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
for (int y = 0; y < colorOverlay.getHeight(); y++) {
for (int x = 0; x < colorOverlay.getWidth(); x++) {
int mc = maskResized.getRGB(x, y);
if (mc == targetColor) {
colorOverlay.setRGB(x, y, translucent);
} else {
colorOverlay.setRGB(x, y, 0x00000000);
}
}
}
// 将 colorOverlay 画到 overlay 上
g2.drawImage(colorOverlay, 0, 0, null);
g2.dispose();
// 保存文件
File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile();
File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile();
ImageIO.write(partMask, "png", maskOut);
ImageIO.write(overlay, "png", overlayOut);
System.out.println("Saved mask: " + maskOut.getAbsolutePath());
System.out.println("Saved overlay: " + overlayOut.getAbsolutePath());
}
}
}
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
Path syn = modelDir.resolve("synset.txt");
if (Files.exists(syn)) {
try {
List<String> lines = Files.readAllLines(syn);
List<String> cleaned = new ArrayList<>();
for (String l : lines) {
String s = l.trim();
if (!s.isEmpty()) cleaned.add(s);
}
if (!cleaned.isEmpty()) return Optional.of(cleaned);
} catch (IOException ignore) {}
}
return Optional.empty();
}
private static Set<String> parseTargets(String arg, List<String> availableLabels) {
String a = arg.trim();
if (a.equalsIgnoreCase("all")) {
return new LinkedHashSet<>(availableLabels);
}
String[] parts = a.split(",");
Set<String> out = new LinkedHashSet<>();
for (String p : parts) {
String t = p.trim();
if (!t.isEmpty()) out.add(t);
}
return out;
}
private static String safeFileName(String s) {
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
}
}

View File

@@ -16,7 +16,7 @@ public class AI2Test {
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
// 使用 AnimeModelWrapper 而不是 VividModelWrapper
AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt"));
AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\Anime-Face-Segmentation.pt"));
// 使用 Anime-Face-Segmentation 的 7 个标签
Set<String> animeLabels = Set.of(

View File

@@ -1,7 +1,6 @@
package com.chuangzhou.vivid2D.test;
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
@@ -15,7 +14,7 @@ public class AI3Test {
public static void main(String[] args) throws Exception {
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt"));
Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\anime-segmentation.pt"));
Set<String> faceLabels = Set.of("foreground");

View File

@@ -1,6 +1,6 @@
package com.chuangzhou.vivid2D.test;
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
@@ -14,7 +14,7 @@ public class AITest {
public static void main(String[] args) throws Exception {
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"));
BiSeNetVividModelWrapper wrapper = BiSeNetVividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"));
// 使用 BiSeNet 人脸解析模型的 18 个非背景标签
Set<String> faceLabels = Set.of(