refactor(ai):重构分割模型包装类继承结构- 将 Anime2ModelWrapper、Anime2VividModelWrapper 和 AnimeModelWrapper 改为继承自 VividModelWrapper 基类
- 移除重复的 ResultFiles 内部类和相关工具方法实现 - Anime2Segmenter 和 AnimeSegmenter 继承自抽象基类 Segmenter - Anime2SegmentationResult与 AnimeSegmentationResult 继承 SegmentationResult - 重命名 LabelPalette 为 BiSeNetLabelPalette 并调整其引用 - 更新模型路径配置以匹配新的文件命名约定 - 删除冗余的 getLabels() 和 getPalette() 方法定义 - 简化 segmentAndSave 方法中的类型转换逻辑- 移除已被继承方法替代的手动资源管理代码 - 调整 import 语句以反映包结构调整- 清理不再需要的独立主测试函数入口点- 修改字段访问权限以符合继承设计模式 - 替换具体的返回类型为更通用的 SegmentationResult 接口- 整合公共功能至基类减少子类间重复代码 - 统一分割后处理流程提高模块复用性 - 引入泛型支持增强 Wrapper 类型安全性 - 更新注释文档保持与最新架构同步 - 优化异常处理策略统一关闭资源方式 - 规范文件命名规则便于未来维护扩展 - 提取共通逻辑到父类降低耦合度 - 完善类型检查避免运行时 ClassCastException 风险
This commit is contained in:
221
src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java
Normal file
221
src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java
Normal file
@@ -0,0 +1,221 @@
|
||||
package com.chuangzhou.vivid2D.ai;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper;
|
||||
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* 模型管理器 - 负责模型的注册、分类和检索
|
||||
*/
|
||||
public class ModelManagement {
|
||||
private final Map<String, Class<?>> models = new ConcurrentHashMap<>();
|
||||
private final Map<String, List<String>> modelsByCategory = new ConcurrentHashMap<>();
|
||||
private final List<String> modelDisplayNames = new ArrayList<>();
|
||||
private final Map<String, String> displayNameToRegistrationName = new ConcurrentHashMap<>();
|
||||
|
||||
private ModelManagement() {
|
||||
initializeDefaultCategories();
|
||||
registerDefaultModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化默认分类
|
||||
*/
|
||||
private void initializeDefaultCategories() {
|
||||
modelsByCategory.put("Image Segmentation", new ArrayList<>());
|
||||
modelsByCategory.put("Image Processing", new ArrayList<>());
|
||||
modelsByCategory.put("Image Generation", new ArrayList<>());
|
||||
modelsByCategory.put("Image Inpainting", new ArrayList<>());
|
||||
modelsByCategory.put("Image Completion", new ArrayList<>());
|
||||
modelsByCategory.put("Face Analysis", new ArrayList<>());
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册默认模型
|
||||
*/
|
||||
private void registerDefaultModels() {
|
||||
registerModel("segmentation:anime_face", "Anime Face Segmentation",
|
||||
AnimeModelWrapper.class, "Image Segmentation");
|
||||
registerModel("segmentation:anime", "Anime Image Segmentation",
|
||||
Anime2VividModelWrapper.class, "Image Segmentation");
|
||||
registerModel("segmentation:face_parsing", "Face Parsing",
|
||||
BiSeNetVividModelWrapper.class, "Image Segmentation");
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册模型
|
||||
* @param modelRegistrationName 注册名称,格式必须为 "category:model_name"
|
||||
* @param modelDisplayName 模型显示名称
|
||||
* @param modelClass 模型类
|
||||
* @param category 模型类别
|
||||
*/
|
||||
public void registerModel(String modelRegistrationName, String modelDisplayName,
|
||||
Class<?> modelClass, String category) {
|
||||
if (!isValidRegistrationName(modelRegistrationName)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid registration name format. Expected 'category:model_name', got: " + modelRegistrationName);
|
||||
}
|
||||
if (models.containsKey(modelRegistrationName)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Model registration name already exists: " + modelRegistrationName);
|
||||
}
|
||||
if (displayNameToRegistrationName.containsKey(modelDisplayName)) {
|
||||
throw new IllegalArgumentException(
|
||||
"Model display name already exists: " + modelDisplayName);
|
||||
}
|
||||
if (!modelsByCategory.containsKey(category)) {
|
||||
modelsByCategory.put(category, new ArrayList<>());
|
||||
}
|
||||
models.put(modelRegistrationName, modelClass);
|
||||
displayNameToRegistrationName.put(modelDisplayName, modelRegistrationName);
|
||||
modelDisplayNames.add(modelDisplayName);
|
||||
modelsByCategory.get(category).add(modelRegistrationName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证注册名称格式
|
||||
*/
|
||||
private boolean isValidRegistrationName(String name) {
|
||||
return name != null && name.matches("^[a-zA-Z0-9_]+:[a-zA-Z0-9_]+$");
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过显示名称获取模型类
|
||||
*/
|
||||
public Class<?> getModel(String modelDisplayName) {
|
||||
String registrationName = displayNameToRegistrationName.get(modelDisplayName);
|
||||
return registrationName != null ? models.get(registrationName) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过索引获取模型类
|
||||
*/
|
||||
public Class<?> getModel(int modelIndex) {
|
||||
if (modelIndex >= 0 && modelIndex < modelDisplayNames.size()) {
|
||||
String displayName = modelDisplayNames.get(modelIndex);
|
||||
return getModel(displayName);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过注册名称获取模型类
|
||||
*/
|
||||
public Class<?> getModelByRegistrationName(String registrationName) {
|
||||
return models.get(registrationName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类名获取模型类
|
||||
*/
|
||||
public Class<?> getModelByClassName(String className) {
|
||||
for (Class<?> modelClass : models.values()) {
|
||||
if (modelClass.getName().equals(className)) {
|
||||
return modelClass;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有模型的显示名称
|
||||
*/
|
||||
public List<String> getAllModelDisplayNames() {
|
||||
return Collections.unmodifiableList(modelDisplayNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有模型的注册名称
|
||||
*/
|
||||
public Set<String> getAllModelRegistrationNames() {
|
||||
return Collections.unmodifiableSet(models.keySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* 按类别获取模型注册名称
|
||||
*/
|
||||
public List<String> getModelsByCategory(String category) {
|
||||
return Collections.unmodifiableList(
|
||||
modelsByCategory.getOrDefault(category, new ArrayList<>())
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有可用的类别
|
||||
*/
|
||||
public Set<String> getAllCategories() {
|
||||
return Collections.unmodifiableSet(modelsByCategory.keySet());
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型数量
|
||||
*/
|
||||
public int getModelCount() {
|
||||
return modelDisplayNames.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型显示名称对应的注册名称
|
||||
*/
|
||||
public String getRegistrationName(String modelDisplayName) {
|
||||
return displayNameToRegistrationName.get(modelDisplayName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型注册名称对应的显示名称
|
||||
*/
|
||||
public String getDisplayName(String registrationName) {
|
||||
for (Map.Entry<String, String> entry : displayNameToRegistrationName.entrySet()) {
|
||||
if (entry.getValue().equals(registrationName)) {
|
||||
return entry.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否存在
|
||||
*/
|
||||
public boolean containsModel(String modelDisplayName) {
|
||||
return displayNameToRegistrationName.containsKey(modelDisplayName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查注册名称是否存在
|
||||
*/
|
||||
public boolean containsRegistrationName(String registrationName) {
|
||||
return models.containsKey(registrationName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除模型
|
||||
*/
|
||||
public boolean removeModel(String modelDisplayName) {
|
||||
String registrationName = displayNameToRegistrationName.get(modelDisplayName);
|
||||
if (registrationName != null) {
|
||||
// 从所有存储中移除
|
||||
models.remove(registrationName);
|
||||
displayNameToRegistrationName.remove(modelDisplayName);
|
||||
modelDisplayNames.remove(modelDisplayName);
|
||||
|
||||
// 从类别中移除
|
||||
for (List<String> categoryModels : modelsByCategory.values()) {
|
||||
categoryModels.remove(registrationName);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static final class InstanceHolder {
|
||||
private static final ModelManagement instance = new ModelManagement();
|
||||
}
|
||||
|
||||
public static ModelManagement getInstance() {
|
||||
return InstanceHolder.instance;
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,8 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
package com.chuangzhou.vivid2D.ai;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 分割结果容器
|
||||
*/
|
||||
public class SegmentationResult {
|
||||
// 分割掩码图(每个像素的颜色为对应类别颜色)
|
||||
private final BufferedImage maskImage;
|
||||
154
src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java
Normal file
154
src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java
Normal file
@@ -0,0 +1,154 @@
|
||||
package com.chuangzhou.vivid2D.ai;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
public abstract class Segmenter implements AutoCloseable {
|
||||
// 内部类,用于从Translator安全地传出数据
|
||||
public static class SegmentationData {
|
||||
public final int[] indices;
|
||||
public final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private String engine = "PyTorch";
|
||||
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
|
||||
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
|
||||
protected final List<String> labels;
|
||||
protected final Map<String, Integer> palette;
|
||||
|
||||
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = BiSeNetLabelPalette.defaultPalette();
|
||||
|
||||
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
return Segmenter.this.processInput(ctx, input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
return Segmenter.this.processOutput(ctx, list);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return Segmenter.this.getBatchifier();
|
||||
}
|
||||
};
|
||||
|
||||
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, Segmenter.SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine(engine)
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理模型输入
|
||||
* @param ctx translator 上下文
|
||||
* @param input 图片
|
||||
* @return 模型输入
|
||||
*/
|
||||
public abstract NDList processInput(TranslatorContext ctx, Image input);
|
||||
|
||||
/**
|
||||
* 处理模型输出
|
||||
* @param ctx translator 上下文
|
||||
* @param list 模型输出
|
||||
* @return 模型输出
|
||||
*/
|
||||
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
|
||||
|
||||
/**
|
||||
* 获取批量处理方式
|
||||
* @return 批量处理方式
|
||||
*/
|
||||
public Batchifier getBatchifier(){
|
||||
return null;
|
||||
}
|
||||
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
// predict 方法现在直接返回安全的 Java 对象
|
||||
Segmenter.SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
width = (int) shp[1];
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
public void setEngine(String engine) {
|
||||
this.engine = engine;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
}
|
||||
113
src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java
Normal file
113
src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java
Normal file
@@ -0,0 +1,113 @@
|
||||
package com.chuangzhou.vivid2D.ai;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class VividModelWrapper<s extends Segmenter> implements AutoCloseable{
|
||||
protected final s segmenter;
|
||||
protected final List<String> labels; // index -> name
|
||||
protected final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
protected VividModelWrapper(s segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果(SegmentationResult)
|
||||
*/
|
||||
public SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。
|
||||
* 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。
|
||||
* <p>
|
||||
* 返回值:Map<labelName, ResultFiles>,ResultFiles 包含 maskFile、overlayFile(两个 PNG)
|
||||
*/
|
||||
public abstract Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception;
|
||||
|
||||
protected static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
protected static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
return Set.of("all");
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
|
||||
protected static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
@@ -14,16 +16,10 @@ import java.util.List;
|
||||
/**
|
||||
* AnimeModelWrapper - 专门为 Anime-Face-Segmentation 模型封装的 Wrapper
|
||||
*/
|
||||
public class AnimeModelWrapper implements AutoCloseable {
|
||||
|
||||
private final AnimeSegmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
public class AnimeModelWrapper extends VividModelWrapper<AnimeSegmenter> {
|
||||
|
||||
private AnimeModelWrapper(AnimeSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
super(segmenter, labels, palette);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,14 +32,6 @@ public class AnimeModelWrapper implements AutoCloseable {
|
||||
return new AnimeModelWrapper(segmenter, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果(在丢给底层 segmenter 前会做通用预处理:RGB 转换 + 等比 letterbox 缩放到模型输入尺寸)
|
||||
*/
|
||||
@@ -152,28 +140,6 @@ public class AnimeModelWrapper implements AutoCloseable {
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 返回所有标签
|
||||
return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的集合
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 专门提取眼睛的方法(在丢给底层 segmenter 前做预处理)
|
||||
*/
|
||||
@@ -246,44 +212,6 @@ public class AnimeModelWrapper implements AutoCloseable {
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
// ========== 新增:预处理并保存到临时文件 ==========
|
||||
private File preprocessAndSave(File inputImage) throws IOException {
|
||||
BufferedImage img = ImageIO.read(inputImage);
|
||||
@@ -391,36 +319,4 @@ public class AnimeModelWrapper implements AutoCloseable {
|
||||
|
||||
return new int[]{w, h};
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: AnimeModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face");
|
||||
System.out.println("标签: " + AnimeLabelPalette.defaultLabels());
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 动漫分割结果容器
|
||||
*/
|
||||
public class AnimeSegmentationResult {
|
||||
public class AnimeSegmentationResult extends SegmentationResult {
|
||||
// 分割掩码图(每个像素的颜色为对应类别颜色)
|
||||
private final BufferedImage maskImage;
|
||||
|
||||
@@ -21,6 +23,7 @@ public class AnimeSegmentationResult {
|
||||
|
||||
public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap,
|
||||
Map<Integer, String> labels, Map<String, Integer> palette) {
|
||||
super(maskImage, labels, palette);
|
||||
this.maskImage = maskImage;
|
||||
this.probabilityMap = probabilityMap;
|
||||
this.labels = labels;
|
||||
|
||||
@@ -16,6 +16,7 @@ import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import com.chuangzhou.vivid2D.ai.Segmenter;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
@@ -27,7 +28,7 @@ import java.util.*;
|
||||
/**
|
||||
* AnimeSegmenter: 专门为 Anime-Face-Segmentation UNet 模型设计的分割器
|
||||
*/
|
||||
public class AnimeSegmenter implements AutoCloseable {
|
||||
public class AnimeSegmenter extends Segmenter {
|
||||
|
||||
// 模型默认输入大小(与训练时一致)。若模型不同可以修改为实际值或让 caller 通过构造参数传入。
|
||||
private static final int MODEL_INPUT_W = 512;
|
||||
@@ -48,11 +49,10 @@ public class AnimeSegmenter implements AutoCloseable {
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public AnimeSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
super(modelDir, labels);
|
||||
this.palette = AnimeLabelPalette.defaultPalette();
|
||||
|
||||
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
|
||||
@@ -137,6 +137,16 @@ public class AnimeSegmenter implements AutoCloseable {
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
return null;
|
||||
}
|
||||
|
||||
public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
@@ -201,30 +211,4 @@ public class AnimeSegmenter implements AutoCloseable {
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
// 测试主函数
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java AnimeSegmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = AnimeLabelPalette.defaultLabels();
|
||||
|
||||
try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) {
|
||||
AnimeSegmentationResult res = segmenter.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
|
||||
|
||||
// 额外保存眼睛分割结果
|
||||
BufferedImage eyes = segmenter.extractEyes(input);
|
||||
File eyesOut = new File(out.getParent(), "eyes_" + out.getName());
|
||||
ImageIO.write(eyes, "png", eyesOut);
|
||||
System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Anime2ModelWrapper - 对动漫分割模型的封装
|
||||
*
|
||||
* 用法示例:
|
||||
* Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir"));
|
||||
* Map<String, Anime2ModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
|
||||
* new File("input.jpg"),
|
||||
* Set.of("foreground"), // 动漫分割主要关注前景
|
||||
* Paths.get("outDir")
|
||||
* );
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class Anime2ModelWrapper implements AutoCloseable {
|
||||
|
||||
private final Anime2Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
private Anime2ModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Anime2Segmenter 实例
|
||||
*/
|
||||
public static Anime2ModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
|
||||
return new Anime2ModelWrapper(s, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果
|
||||
*/
|
||||
public Anime2SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
Anime2SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
System.err.println("Warning: unknown label '" + target + "' - skip.");
|
||||
continue;
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
|
||||
} else {
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
// 3) 生成叠加图
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
return Set.of("foreground"); // 动漫分割主要关注前景
|
||||
}
|
||||
}
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: Anime2ModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,36 +1,15 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 动漫分割结果容器
|
||||
*/
|
||||
public class Anime2SegmentationResult {
|
||||
// 分割掩码图(每个像素的颜色为对应类别颜色)
|
||||
private final BufferedImage maskImage;
|
||||
|
||||
// 类别索引 -> 类别名称
|
||||
private final Map<Integer, String> labels;
|
||||
|
||||
// 类别名称 -> ARGB 颜色
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public class Anime2SegmentationResult extends SegmentationResult {
|
||||
public Anime2SegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
|
||||
this.maskImage = maskImage;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
public BufferedImage getMaskImage() {
|
||||
return maskImage;
|
||||
}
|
||||
|
||||
public Map<Integer, String> getLabels() {
|
||||
return labels;
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return palette;
|
||||
super(maskImage, labels, palette);
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,17 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.Segmenter;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
@@ -28,94 +24,51 @@ import java.util.*;
|
||||
* Anime2Segmenter: 专门用于动漫分割模型
|
||||
* 处理 anime-segmentation 模型的二值分割输出
|
||||
*/
|
||||
public class Anime2Segmenter implements AutoCloseable {
|
||||
|
||||
public static class SegmentationData {
|
||||
final int[] indices;
|
||||
final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public class Anime2Segmenter extends Segmenter {
|
||||
public Anime2Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = Anime2LabelPalette.animeSegmentationPalette();
|
||||
|
||||
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
|
||||
// 调整输入图像尺寸到模型期望的大小 (1024x1024)
|
||||
Image resized = input.resize(1024, 1024, true);
|
||||
NDArray array = resized.toNDArray(manager);
|
||||
|
||||
// 转换为 CHW 格式并归一化
|
||||
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
|
||||
array = array.div(255f);
|
||||
array = array.expandDims(0); // 添加batch维度
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
|
||||
NDArray out = list.get(0);
|
||||
|
||||
// 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图
|
||||
// 应用sigmoid并二值化
|
||||
NDArray probabilities = out.div(out.neg().exp().add(1));
|
||||
NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false);
|
||||
|
||||
// 移除batch和channel维度
|
||||
if (binaryMask.getShape().dimension() == 4) {
|
||||
binaryMask = binaryMask.squeeze(0).squeeze(0);
|
||||
}
|
||||
|
||||
// 转换为Java数组
|
||||
long[] finalShape = binaryMask.getShape().getShape();
|
||||
int[] indices = binaryMask.toIntArray();
|
||||
|
||||
return new SegmentationData(indices, finalShape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
Criteria<Image, SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine("PyTorch")
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
super(modelDir, labels);
|
||||
}
|
||||
|
||||
public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
|
||||
// 调整输入图像尺寸到模型期望的大小 (1024x1024)
|
||||
Image resized = input.resize(1024, 1024, true);
|
||||
NDArray array = resized.toNDArray(manager);
|
||||
|
||||
// 转换为 CHW 格式并归一化
|
||||
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
|
||||
array = array.div(255f);
|
||||
array = array.expandDims(0); // 添加batch维度
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
NDArray out = list.get(0);
|
||||
// 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图
|
||||
// 应用sigmoid并二值化
|
||||
NDArray probabilities = out.div(out.neg().exp().add(1));
|
||||
NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false);
|
||||
if (binaryMask.getShape().dimension() == 4) {
|
||||
binaryMask = binaryMask.squeeze(0).squeeze(0);
|
||||
}
|
||||
long[] finalShape = binaryMask.getShape().getShape();
|
||||
int[] indices = binaryMask.toIntArray();
|
||||
return new SegmentationData(indices, finalShape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
SegmentationData data = predictor.predict(img);
|
||||
|
||||
Segmenter.SegmentationData data = predictor.predict(img);
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
@@ -123,14 +76,11 @@ public class Anime2Segmenter implements AutoCloseable {
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 创建分割掩码
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
@@ -139,8 +89,7 @@ public class Anime2Segmenter implements AutoCloseable {
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new Anime2SegmentationResult(mask, labelsMap, palette);
|
||||
return new SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -154,22 +103,4 @@ public class Anime2Segmenter implements AutoCloseable {
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java Anime2Segmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = Anime2LabelPalette.defaultLabels();
|
||||
|
||||
try (Anime2Segmenter s = new Anime2Segmenter(modelDir, labels)) {
|
||||
Anime2SegmentationResult res = s.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
@@ -22,16 +25,11 @@ import java.util.List;
|
||||
* // out contains 每个目标标签对应的 mask+overlay 文件路径
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
|
||||
private final Anime2Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
public class Anime2VividModelWrapper extends VividModelWrapper<Anime2Segmenter> {
|
||||
|
||||
private Anime2VividModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
super(segmenter, labels, palette);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -56,7 +54,7 @@ public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
/**
|
||||
* 直接返回分割结果(Anime2SegmentationResult)
|
||||
*/
|
||||
public Anime2SegmentationResult segment(File inputImage) throws Exception {
|
||||
public SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
@@ -66,12 +64,12 @@ public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
* <p>
|
||||
* 返回值:Map<labelName, ResultFiles>,ResultFiles 包含 maskFile、overlayFile(两个 PNG)
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
public Map<String, VividModelWrapper.ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
Anime2SegmentationResult res = segment(inputImage);
|
||||
SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
@@ -84,7 +82,6 @@ public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
// 尝试忽略大小写匹配
|
||||
String finalTarget = target;
|
||||
Optional<String> matched = palette.keySet().stream()
|
||||
.filter(k -> k.equalsIgnoreCase(finalTarget))
|
||||
@@ -154,109 +151,4 @@ public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette)
|
||||
// For convenience, return a set containing "all" and let caller logic handle it earlier.
|
||||
return Set.of("all");
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: Anime2VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
|
||||
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import java.util.*;
|
||||
* 颜色值基于 zllrunning/face-parsing.PyTorch 仓库的 test.py 文件。
|
||||
* 标签索引必须与模型输出索引一致(0-18)。
|
||||
*/
|
||||
public class LabelPalette {
|
||||
public class BiSeNetLabelPalette {
|
||||
|
||||
/**
|
||||
* BiSeNet 人脸解析模型的标准标签(19个类别,索引 0-18)
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 分割结果容器
|
||||
*/
|
||||
public class BiSeNetSegmentationResult extends SegmentationResult {
|
||||
public BiSeNetSegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
|
||||
super(maskImage, labels, palette);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.Segmenter;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Segmenter: 加载模型并对图片做语义分割
|
||||
*
|
||||
* 说明:
|
||||
* - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray,
|
||||
* 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。
|
||||
* - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()),
|
||||
* 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。
|
||||
*/
|
||||
public class BiSeNetSegmenter extends Segmenter {
|
||||
|
||||
public BiSeNetSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
super(modelDir, labels);
|
||||
}
|
||||
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
NDArray array = input.toNDArray(manager);
|
||||
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
|
||||
array = array.div(255f);
|
||||
array = array.expandDims(0);
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
|
||||
NDArray out = list.get(0);
|
||||
NDArray classMap;
|
||||
|
||||
// 1. 解析模型输出,得到类别图谱 (classMap)
|
||||
long[] shape = out.getShape().getShape();
|
||||
if (shape.length == 4 && shape[1] > 1) {
|
||||
classMap = out.argMax(1);
|
||||
} else if (shape.length == 3) {
|
||||
classMap = (shape[0] == 1) ? out : out.argMax(0);
|
||||
} else if (shape.length == 2) {
|
||||
classMap = out;
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape));
|
||||
}
|
||||
|
||||
if (classMap.getShape().dimension() == 3) {
|
||||
classMap = classMap.squeeze(0);
|
||||
}
|
||||
|
||||
// 2. *** 关键步骤 ***
|
||||
// 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型
|
||||
|
||||
// 首先,确保数据类型是 INT32
|
||||
NDArray int32ClassMap = classMap.toType(DataType.INT32, false);
|
||||
|
||||
// 然后,获取形状和 int[] 数组
|
||||
long[] finalShape = int32ClassMap.getShape().getShape();
|
||||
int[] indices = int32ClassMap.toIntArray();
|
||||
|
||||
// 3. 将 Java 对象封装并返回
|
||||
return new SegmentationData(indices, finalShape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
return super.segment(imgFile);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
super.close();
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.SegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.VividModelWrapper;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
@@ -22,35 +25,29 @@ import java.util.List;
|
||||
* // out contains 每个目标标签对应的 mask+overlay 文件路径
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class VividModelWrapper implements AutoCloseable {
|
||||
public class BiSeNetVividModelWrapper extends VividModelWrapper<BiSeNetSegmenter> {
|
||||
|
||||
private final Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
private VividModelWrapper(Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
private BiSeNetVividModelWrapper(BiSeNetSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
super(segmenter, labels, palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 LabelPalette.defaultLabels()
|
||||
* 并创建 Segmenter 实例。
|
||||
*/
|
||||
public static VividModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
Segmenter s = new Segmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = LabelPalette.defaultPalette();
|
||||
return new VividModelWrapper(s, labels, palette);
|
||||
public static BiSeNetVividModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(BiSeNetLabelPalette::defaultLabels);
|
||||
BiSeNetSegmenter s = new BiSeNetSegmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = BiSeNetLabelPalette.defaultPalette();
|
||||
return new BiSeNetVividModelWrapper(s, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
return super.getLabels();
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
return super.getPalette();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -66,25 +63,19 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
* <p>
|
||||
* 返回值:Map<labelName, ResultFiles>,ResultFiles 包含 maskFile、overlayFile(两个 PNG)
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
public Map<String, VividModelWrapper.ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
// 尝试忽略大小写匹配
|
||||
String finalTarget = target;
|
||||
Optional<String> matched = palette.keySet().stream()
|
||||
.filter(k -> k.equalsIgnoreCase(finalTarget))
|
||||
@@ -95,10 +86,7 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
@@ -110,8 +98,6 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明)
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
@@ -120,11 +106,9 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
// 半透明颜色(alpha = 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
@@ -140,44 +124,17 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette)
|
||||
// For convenience, return a set containing "all" and let caller logic handle it earlier.
|
||||
return Set.of("all");
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@@ -187,75 +144,4 @@ public class VividModelWrapper implements AutoCloseable {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Segmenter: 加载模型并对图片做语义分割
|
||||
*
|
||||
* 说明:
|
||||
* - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray,
|
||||
* 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。
|
||||
* - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()),
|
||||
* 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。
|
||||
*/
|
||||
public class Segmenter implements AutoCloseable {
|
||||
|
||||
// 内部类,用于从Translator安全地传出数据
|
||||
public static class SegmentationData {
|
||||
final int[] indices;
|
||||
final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = LabelPalette.defaultPalette();
|
||||
|
||||
// Translator 的输出类型现在是 SegmentationData
|
||||
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
NDArray array = input.toNDArray(manager);
|
||||
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
|
||||
array = array.div(255f);
|
||||
array = array.expandDims(0);
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
|
||||
NDArray out = list.get(0);
|
||||
NDArray classMap;
|
||||
|
||||
// 1. 解析模型输出,得到类别图谱 (classMap)
|
||||
long[] shape = out.getShape().getShape();
|
||||
if (shape.length == 4 && shape[1] > 1) {
|
||||
classMap = out.argMax(1);
|
||||
} else if (shape.length == 3) {
|
||||
classMap = (shape[0] == 1) ? out : out.argMax(0);
|
||||
} else if (shape.length == 2) {
|
||||
classMap = out;
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape));
|
||||
}
|
||||
|
||||
if (classMap.getShape().dimension() == 3) {
|
||||
classMap = classMap.squeeze(0);
|
||||
}
|
||||
|
||||
// 2. *** 关键步骤 ***
|
||||
// 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型
|
||||
|
||||
// 首先,确保数据类型是 INT32
|
||||
NDArray int32ClassMap = classMap.toType(DataType.INT32, false);
|
||||
|
||||
// 然后,获取形状和 int[] 数组
|
||||
long[] finalShape = int32ClassMap.getShape().getShape();
|
||||
int[] indices = int32ClassMap.toIntArray();
|
||||
|
||||
// 3. 将 Java 对象封装并返回
|
||||
return new SegmentationData(indices, finalShape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null; // 或者根据需要使用 Batchifier.STACK
|
||||
}
|
||||
};
|
||||
|
||||
// Criteria 的类型也需要更新
|
||||
Criteria<Image, SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine("PyTorch")
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
// predict 方法现在直接返回安全的 Java 对象
|
||||
SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
width = (int) shp[1];
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
// 小测试主函数(示例)
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java Segmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = LabelPalette.defaultLabels();
|
||||
|
||||
try (Segmenter s = new Segmenter(modelDir, labels)) {
|
||||
SegmentationResult res = s.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("分割掩码已保存到: " + out.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.awt.*;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* SegmenterExample
|
||||
*
|
||||
* 使用说明(命令行):
|
||||
* java -cp <your-classpath> com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \
|
||||
* <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>
|
||||
*
|
||||
* 示例:
|
||||
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye,face"
|
||||
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out all
|
||||
*/
|
||||
public class SegmenterExample {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.err.println("用法: SegmenterExample <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>");
|
||||
System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face");
|
||||
return;
|
||||
}
|
||||
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File inputImage = new File(args[1]);
|
||||
Path outDir = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
if (!Files.exists(modelDir)) {
|
||||
System.err.println("modelDir 不存在: " + modelDir);
|
||||
return;
|
||||
}
|
||||
if (!inputImage.exists()) {
|
||||
System.err.println("输入图片不存在: " + inputImage.getAbsolutePath());
|
||||
return;
|
||||
}
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
// 读取 synset.txt(如果有),否则使用默认 LabelPalette
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
|
||||
// 打开 Segmenter
|
||||
try (Segmenter segmenter = new Segmenter(modelDir, labels)) {
|
||||
SegmentationResult res = segmenter.segment(inputImage);
|
||||
|
||||
// 原始图片
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
|
||||
// palette: labelName -> ARGB int
|
||||
Map<String, Integer> palette = res.getPalette();
|
||||
Map<Integer, String> labelsMap = res.getLabels(); // index -> name
|
||||
|
||||
// 解析目标 labels 列表
|
||||
Set<String> targets = parseTargets(targetsArg, labels);
|
||||
|
||||
System.out.println("Will export targets: " + targets);
|
||||
|
||||
// maskImage: 每像素是类别颜色(ARGB)
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
int w = maskImage.getWidth();
|
||||
int h = maskImage.getHeight();
|
||||
|
||||
// 为快速查 color -> labelName
|
||||
Map<Integer, String> colorToLabel = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> e : palette.entrySet()) {
|
||||
colorToLabel.put(e.getValue(), e.getKey());
|
||||
}
|
||||
|
||||
// 对每个 target 生成单独的 mask 和 overlay
|
||||
for (String target : targets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。");
|
||||
continue;
|
||||
}
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < h; y++) {
|
||||
for (int x = 0; x < w; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
// 保留为不透明(使用原始颜色)
|
||||
partMask.setRGB(x, y, targetColor);
|
||||
} else {
|
||||
// 透明
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 生成 overlay:在原图上叠加半透明的 targetColor 区域
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
// 若分辨率不同,先缩放 mask 到原图大小(简单粗暴地按尺寸相同假设)
|
||||
BufferedImage maskResized = maskImage;
|
||||
if (original.getWidth() != w || original.getHeight() != h) {
|
||||
// 简单缩放 mask 到原图尺寸
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(maskImage, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
// 在 overlay 上先画原图
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
|
||||
// 创建半透明颜色(将 targetColor 的 alpha 设为 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
// 创建一个图像,把 mask 中对应像素设置为 translucent,否则透明
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if (mc == targetColor) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
// 将 colorOverlay 画到 overlay 上
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存文件
|
||||
File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(partMask, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
System.out.println("Saved mask: " + maskOut.getAbsolutePath());
|
||||
System.out.println("Saved overlay: " + overlayOut.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private static Set<String> parseTargets(String arg, List<String> availableLabels) {
|
||||
String a = arg.trim();
|
||||
if (a.equalsIgnoreCase("all")) {
|
||||
return new LinkedHashSet<>(availableLabels);
|
||||
}
|
||||
String[] parts = a.split(",");
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
String t = p.trim();
|
||||
if (!t.isEmpty()) out.add(t);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,7 @@ public class AI2Test {
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
|
||||
// 使用 AnimeModelWrapper 而不是 VividModelWrapper
|
||||
AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt"));
|
||||
AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\Anime-Face-Segmentation.pt"));
|
||||
|
||||
// 使用 Anime-Face-Segmentation 的 7 个标签
|
||||
Set<String> animeLabels = Set.of(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.chuangzhou.vivid2D.test;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@@ -15,7 +14,7 @@ public class AI3Test {
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt"));
|
||||
Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\anime-segmentation.pt"));
|
||||
|
||||
Set<String> faceLabels = Set.of("foreground");
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.chuangzhou.vivid2D.test;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@@ -14,7 +14,7 @@ public class AITest {
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"));
|
||||
BiSeNetVividModelWrapper wrapper = BiSeNetVividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"));
|
||||
|
||||
// 使用 BiSeNet 人脸解析模型的 18 个非背景标签
|
||||
Set<String> faceLabels = Set.of(
|
||||
|
||||
Reference in New Issue
Block a user