feat(ai): 集成动漫人物分割与面部解析 AI 模型

- 添加 DJL 深度学习框架依赖项以支持 PyTorch 和 ONNX Runtime 引擎
- 实现 Anime2VividModelWrapper 封装类用于动漫人物前景背景分离
- 开发 AnimeModelWrapper 用于精细的动漫面部特征(如头发、眼睛)分割
- 创建配套的标签调色板和结果处理工具类提升可视化效果
- 增加多个测试用例验证不同AI模型的推理及文件输出功能
- 支持通过 synset.txt 自定义模型标签并增强命令行可测试性
This commit is contained in:
tzdwindows 7
2025-10-27 18:39:13 +08:00
parent f2cb74379e
commit a725e7eb23
18 changed files with 2414 additions and 0 deletions

View File

@@ -51,6 +51,16 @@ dependencies {
implementation files('libs/dog api 1.3.jar')
implementation files('libs/DesktopWallpaperSdk-1.0-SNAPSHOT.jar')
// === DJL API ===
implementation platform('ai.djl:bom:0.35.0')
implementation 'ai.djl:api'
implementation 'ai.djl:model-zoo'
implementation 'ai.djl.pytorch:pytorch-model-zoo:0.35.0'
implementation 'ai.djl.pytorch:pytorch-engine'
implementation 'ai.djl:basicdataset'
implementation 'ai.djl.onnxruntime:onnxruntime-engine'
runtimeOnly 'ai.djl.pytorch:pytorch-native-cpu:2.7.1'
runtimeOnly 'ai.djl.onnxruntime:onnxruntime-native-cpu:1.3.0'
// === 核心工具库 ===
implementation 'com.google.code.gson:gson:2.10.1' // 统一版本
implementation 'org.apache.logging.log4j:log4j-api:2.20.0'

View File

@@ -0,0 +1,62 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import java.util.*;
/**
 * Label names and mask colours for the Anime-Face-Segmentation UNet model.
 *
 * <p>The position of a name in {@link #defaultLabels()} is the class index the model is expected
 * to emit (0-6), so the order must stay in sync with the model output.
 *
 * <p>NOTE(review): the original comments said the colours were taken from {@code util.py} of the
 * Anime-Face-Segmentation project and named them as if the triples were in BGR order (hair
 * {@code (255,0,0)} was called "blue"). The literals below are ARGB ({@code 0xAARRGGBB}), so the
 * names in this file describe what the literal actually renders as. Confirm against util.py which
 * channel order is intended.
 */
public class AnimeLabelPalette {

    /**
     * Returns the seven class names in class-index order (index 0 to 6).
     *
     * @return a fixed-size list: background, hair, eye, mouth, face, skin, clothes
     */
    public static List<String> defaultLabels() {
        return Arrays.asList(
                "background", // 0 - cyan    0xFF00FFFF
                "hair",       // 1 - red     0xFFFF0000
                "eye",        // 2 - blue    0xFF0000FF
                "mouth",      // 3 - white   0xFFFFFFFF
                "face",       // 4 - green   0xFF00FF00
                "skin",       // 5 - yellow  0xFFFFFF00
                "clothes"     // 6 - magenta 0xFFFF00FF
        );
    }

    /**
     * Returns a fresh, mutable palette mapping class name to an opaque ARGB colour.
     *
     * <p>Entries are inserted in class-index order and the map preserves that order, so iterating
     * the palette visits labels in the same sequence as {@link #defaultLabels()}.
     *
     * @return class name -> ARGB colour ({@code 0xFFRRGGBB})
     */
    public static Map<String, Integer> defaultPalette() {
        Map<String, Integer> map = new LinkedHashMap<>();
        map.put("background", 0xFF00FFFF); // index 0: cyan
        map.put("hair", 0xFFFF0000);       // index 1: red
        map.put("eye", 0xFF0000FF);        // index 2: blue
        map.put("mouth", 0xFFFFFFFF);      // index 3: white
        map.put("face", 0xFF00FF00);       // index 4: green
        map.put("skin", 0xFFFFFF00);       // index 5: yellow
        map.put("clothes", 0xFFFF00FF);    // index 6: magenta
        return map;
    }

    /**
     * Builds the class-index to class-name lookup derived from {@link #defaultLabels()}.
     *
     * @return a fresh, mutable map whose keys run from 0 to 6 in ascending order
     */
    public static Map<Integer, String> getIndexToLabelMap() {
        List<String> labels = defaultLabels();
        Map<Integer, String> map = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            map.put(i, labels.get(i));
        }
        return map;
    }
}

View File

@@ -0,0 +1,426 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
 * High-level wrapper around {@link AnimeSegmenter} for the Anime-Face-Segmentation model.
 *
 * <p>On top of the raw segmenter it adds input preprocessing (RGB conversion plus aspect-preserving
 * letterbox scaling), per-label mask and overlay PNG export, and a small command-line entry point.
 *
 * <p>NOTE(review): labels read from {@code synset.txt} are passed to the segmenter and exposed via
 * {@link #getLabels()}, but the palette is always {@link AnimeLabelPalette#defaultPalette()}, so
 * only the default label names can be exported.
 */
public class AnimeModelWrapper implements AutoCloseable {

    /** Alpha value (0x88) of the coloured layer painted over the original image in overlays. */
    private static final int OVERLAY_ALPHA = 0x88;

    /** Square input edge assumed when the segmenter exposes no input-size accessor. */
    private static final int DEFAULT_INPUT_SIZE = 512;

    private final AnimeSegmenter segmenter;
    private final List<String> labels;          // index -> name
    private final Map<String, Integer> palette; // name -> ARGB

    private AnimeModelWrapper(AnimeSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
        this.segmenter = segmenter;
        this.labels = labels;
        this.palette = palette;
    }

    /**
     * Loads the model found at {@code modelDir}.
     *
     * @param modelDir model location handed to {@link AnimeSegmenter}; a {@code synset.txt} inside
     *                 it, if present and non-empty, supplies the label list
     * @return a wrapper that owns the created segmenter and must be closed by the caller
     * @throws Exception if the segmenter cannot be created
     */
    public static AnimeModelWrapper load(Path modelDir) throws Exception {
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(AnimeLabelPalette::defaultLabels);
        AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels);
        Map<String, Integer> palette = AnimeLabelPalette.defaultPalette();
        return new AnimeModelWrapper(segmenter, labels, palette);
    }

    /** @return a read-only view of the label list (index -> name) */
    public List<String> getLabels() {
        return Collections.unmodifiableList(labels);
    }

    /** @return a read-only view of the palette (name -> ARGB) */
    public Map<String, Integer> getPalette() {
        return Collections.unmodifiableMap(palette);
    }

    /**
     * Segments an image after preprocessing it into a temporary letterboxed PNG.
     *
     * <p>The returned mask is in the coordinate frame of the letterboxed model input, not of the
     * original image; it includes the padding added by the preprocessing step.
     *
     * @param inputImage image file to segment
     * @return the segmenter's result for the preprocessed image
     * @throws Exception if the image cannot be read or inference fails
     */
    public AnimeSegmentationResult segment(File inputImage) throws Exception {
        File pre = preprocessAndSave(inputImage);
        try {
            return segmenter.segment(pre);
        } finally {
            deleteQuietly(pre);
        }
    }

    /**
     * Segments {@code inputImage} and writes, for every requested label, a
     * {@code <label>_mask.png} and a {@code <label>_overlay.png} into {@code outDir}.
     *
     * <p>Both files have the size of the original image. The letterbox padding introduced by
     * preprocessing is cropped away before the mask is scaled back, so mask and overlay line up
     * with the original picture for non-square inputs as well.
     *
     * @param inputImage image file to segment
     * @param targets    label names to export (matched case-insensitively against the palette); a
     *                   set holding only {@code "all"} selects every default label; null or empty
     *                   exports nothing
     * @param outDir     output directory, created if missing
     * @return exported files keyed by palette label name, in request order; unknown labels are
     *         skipped with a warning on stderr
     * @throws Exception if the image cannot be read, inference fails, or a file cannot be written
     */
    public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
        ensureDirectory(outDir);
        AnimeSegmentationResult res = segment(inputImage);
        BufferedImage original = readImage(inputImage);
        BufferedImage maskImage = res.getMaskImage();

        Map<String, ResultFiles> saved = new LinkedHashMap<>();
        for (String requested : parseTargetsSet(targets)) {
            String target = resolveLabel(requested);
            if (target == null) {
                System.err.println("Warning: unknown label '" + requested + "' - skip.");
                continue;
            }
            int targetColor = palette.get(target);

            // 1) keep only the pixels of this label, everything else transparent
            BufferedImage partMask = isolateColor(maskImage, targetColor);
            // 2) undo the letterbox and bring the mask to the original size
            BufferedImage maskResized = fitMaskToOriginal(partMask, original.getWidth(), original.getHeight());
            // 3) original image with a translucent colour layer on the label area
            BufferedImage overlay = renderOverlay(original, maskResized, targetColor);

            String safe = safeFileName(target);
            File maskOut = outDir.resolve(safe + "_mask.png").toFile();
            File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
            ImageIO.write(maskResized, "png", maskOut);
            ImageIO.write(overlay, "png", overlayOut);
            saved.put(target, new ResultFiles(maskOut, overlayOut));
        }
        return saved;
    }

    /** Returns the palette key matching {@code requested} (exact, then ignoring case), or null. */
    private String resolveLabel(String requested) {
        if (palette.containsKey(requested)) {
            return requested;
        }
        for (String key : palette.keySet()) {
            if (key.equalsIgnoreCase(requested)) {
                return key;
            }
        }
        return null;
    }

    /** Copies pixels equal to {@code color} as opaque pixels; all others become transparent. */
    private static BufferedImage isolateColor(BufferedImage mask, int color) {
        int width = mask.getWidth();
        int height = mask.getHeight();
        int opaque = color | 0xFF000000;
        BufferedImage out = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                out.setRGB(x, y, mask.getRGB(x, y) == color ? opaque : 0x00000000);
            }
        }
        return out;
    }

    /**
     * Maps a mask given in the letterboxed model frame back onto the original image frame.
     *
     * <p>The content rectangle is recomputed with the same formula {@link #preprocessAndSave(File)}
     * uses (uniform scale, centred), then only that rectangle is scaled to the original size.
     */
    private static BufferedImage fitMaskToOriginal(BufferedImage mask, int origW, int origH) {
        int maskW = mask.getWidth();
        int maskH = mask.getHeight();
        if (origW == maskW && origH == maskH) {
            return mask;
        }
        double scale = Math.min((double) maskW / origW, (double) maskH / origH);
        int contentW = Math.min(maskW, Math.max(1, (int) Math.round(origW * scale)));
        int contentH = Math.min(maskH, Math.max(1, (int) Math.round(origH * scale)));
        int offX = (maskW - contentW) / 2;
        int offY = (maskH - contentH) / 2;

        BufferedImage out = new BufferedImage(origW, origH, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = out.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(mask, 0, 0, origW, origH, offX, offY, offX + contentW, offY + contentH, null);
        g.dispose();
        return out;
    }

    /**
     * Draws {@code original} and, on top of it, a translucent layer of {@code color} wherever
     * {@code mask} has a non-transparent pixel whose RGB equals the RGB of {@code color}.
     * {@code mask} must have the same size as {@code original}.
     */
    private static BufferedImage renderOverlay(BufferedImage original, BufferedImage mask, int color) {
        int width = original.getWidth();
        int height = original.getHeight();
        int rgbOnly = color & 0x00FFFFFF;
        int translucent = (OVERLAY_ALPHA << 24) | rgbOnly;

        BufferedImage tint = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int mc = mask.getRGB(x, y);
                boolean hit = (mc & 0x00FFFFFF) == rgbOnly && (mc >>> 24) != 0;
                tint.setRGB(x, y, hit ? translucent : 0x00000000);
            }
        }

        BufferedImage overlay = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = overlay.createGraphics();
        g.drawImage(original, 0, 0, null);
        g.drawImage(tint, 0, 0, null);
        g.dispose();
        return overlay;
    }

    /** Replaces every character outside {@code [a-zA-Z0-9_.-]} with an underscore. */
    private static String safeFileName(String s) {
        return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
    }

    /**
     * Normalises the requested targets: null/empty yields an empty set, a single {@code "all"}
     * yields every default label, otherwise the trimmed non-null entries in iteration order.
     */
    private static Set<String> parseTargetsSet(Set<String> in) {
        if (in == null || in.isEmpty()) {
            return Collections.emptySet();
        }
        if (in.size() == 1) {
            String only = in.iterator().next();
            if (only != null && "all".equalsIgnoreCase(only.trim())) {
                return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
            }
        }
        Set<String> out = new LinkedHashSet<>();
        for (String s : in) {
            if (s != null) {
                out.add(s.trim());
            }
        }
        return out;
    }

    /**
     * Extracts the eye region and writes {@code eyes_mask.png} and {@code eyes_overlay.png} into
     * {@code outDir}.
     *
     * <p>{@code eyes_mask.png} is the segmenter's eye mask as returned (model input frame). The
     * overlay has the size of the original image; the letterbox padding is cropped before the
     * mask is scaled onto it.
     *
     * @param inputImage image file to process
     * @param outDir     output directory, created if missing
     * @return the two written files
     * @throws Exception if the image cannot be read, inference fails, or a file cannot be written
     */
    public ResultFiles extractEyes(File inputImage, Path outDir) throws Exception {
        ensureDirectory(outDir);

        BufferedImage eyes;
        File pre = preprocessAndSave(inputImage);
        try {
            eyes = segmenter.extractEyes(pre);
        } finally {
            deleteQuietly(pre);
        }

        File eyesMask = outDir.resolve("eyes_mask.png").toFile();
        ImageIO.write(eyes, "png", eyesMask);

        BufferedImage original = readImage(inputImage);
        BufferedImage eyesResized = fitMaskToOriginal(eyes, original.getWidth(), original.getHeight());
        // Conspicuous fallback colour in case the palette has no "eye" entry.
        int eyeColor = palette.getOrDefault("eye", 0xFF00FF);
        BufferedImage overlay = renderOverlay(original, eyesResized, eyeColor);

        File eyesOverlay = outDir.resolve("eyes_overlay.png").toFile();
        ImageIO.write(overlay, "png", eyesOverlay);
        return new ResultFiles(eyesMask, eyesOverlay);
    }

    /** Closes the underlying segmenter; failures while closing are ignored (best effort). */
    @Override
    public void close() {
        try {
            segmenter.close();
        } catch (Exception ignored) {
            // best-effort shutdown: nothing useful can be done if closing fails
        }
    }

    /** Pair of files written for one label: the mask PNG and the overlay PNG. */
    public static class ResultFiles {
        private final File maskFile;
        private final File overlayFile;

        public ResultFiles(File maskFile, File overlayFile) {
            this.maskFile = maskFile;
            this.overlayFile = overlayFile;
        }

        public File getMaskFile() {
            return maskFile;
        }

        public File getOverlayFile() {
            return overlayFile;
        }
    }

    /* ================= helpers ================= */

    /** Creates {@code dir} (including parents) when it does not exist yet. */
    private static void ensureDirectory(Path dir) throws IOException {
        if (!Files.exists(dir)) {
            Files.createDirectories(dir);
        }
    }

    /** Deletes a temporary file, ignoring any failure (the file lives in the temp directory). */
    private static void deleteQuietly(File file) {
        if (file == null) {
            return;
        }
        try {
            Files.deleteIfExists(file.toPath());
        } catch (Exception ignored) {
            // best-effort cleanup of a temp file
        }
    }

    /** Decodes an image file, failing with an IOException when no reader can decode it. */
    private static BufferedImage readImage(File file) throws IOException {
        BufferedImage img = ImageIO.read(file);
        if (img == null) {
            throw new IOException("无法读取图片: " + file);
        }
        return img;
    }

    /**
     * Reads the non-blank, trimmed lines of {@code modelDir/synset.txt}.
     *
     * @return the labels, or empty when the file is missing, unreadable, or has no content
     */
    private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
        Path syn = modelDir.resolve("synset.txt");
        if (!Files.exists(syn)) {
            return Optional.empty();
        }
        try {
            List<String> cleaned = new ArrayList<>();
            for (String line : Files.readAllLines(syn)) {
                String s = line.trim();
                if (!s.isEmpty()) {
                    cleaned.add(s);
                }
            }
            if (!cleaned.isEmpty()) {
                return Optional.of(cleaned);
            }
        } catch (IOException e) {
            // Falling back to the default labels is intended, but say so instead of hiding it.
            System.err.println("Warning: cannot read " + syn + " (" + e.getMessage() + ") - using default labels.");
        }
        return Optional.empty();
    }

    /**
     * Converts the input to plain RGB, letterboxes it (uniform scale, centred, white padding) to
     * the model input size and stores it as a temporary PNG.
     *
     * @return the temporary file; the caller is responsible for deleting it
     * @throws IOException if the image cannot be decoded or the temp file cannot be written
     */
    private File preprocessAndSave(File inputImage) throws IOException {
        BufferedImage img = readImage(inputImage);

        // Flatten to three-channel RGB (drops any alpha channel).
        BufferedImage rgb = new BufferedImage(img.getWidth(), img.getHeight(), BufferedImage.TYPE_INT_RGB);
        Graphics2D g = rgb.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, null);
        g.dispose();

        int[] size = getModelInputSize();
        int targetW = size[0];
        int targetH = size[1];

        // NOTE: fitMaskToOriginal() mirrors this geometry; keep both in sync.
        double scale = Math.min((double) targetW / rgb.getWidth(), (double) targetH / rgb.getHeight());
        int newW = Math.max(1, (int) Math.round(rgb.getWidth() * scale));
        int newH = Math.max(1, (int) Math.round(rgb.getHeight() * scale));
        int x = (targetW - newW) / 2;
        int y = (targetH - newH) / 2;

        BufferedImage resized = new BufferedImage(targetW, targetH, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2 = resized.createGraphics();
        g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g2.setColor(Color.WHITE);
        g2.fillRect(0, 0, targetW, targetH);
        g2.drawImage(rgb, x, y, newW, newH, null);
        g2.dispose();

        // PNG is lossless, so no compression artefacts reach the model.
        File tmp = Files.createTempFile("anime_pre_", ".png").toFile();
        try {
            ImageIO.write(resized, "png", tmp);
        } catch (IOException | RuntimeException e) {
            deleteQuietly(tmp); // do not leave the temp file behind when writing failed
            throw e;
        }
        return tmp;
    }

    /**
     * Determines the model input size as {@code {width, height}}.
     *
     * <p>Probes the segmenter reflectively, in this order: {@code getInputWidth()} and
     * {@code getInputHeight()}; {@code getInputSize()} returning an {@code int[]} or an object
     * with {@code getWidth()}/{@code getHeight()}; declared fields {@code inputWidth} and
     * {@code inputHeight}. Any reflection failure, or no positive pair, yields 512x512.
     */
    private int[] getModelInputSize() {
        try {
            Class<?> cls = segmenter.getClass();

            try {
                Method mw = cls.getMethod("getInputWidth");
                Method mh = cls.getMethod("getInputHeight");
                int[] size = toSize(mw.invoke(segmenter), mh.invoke(segmenter));
                if (size != null) {
                    return size;
                }
            } catch (NoSuchMethodException noGetters) {
                // accessor pair not present: try the next convention
            }

            try {
                Method ms = cls.getMethod("getInputSize");
                Object os = ms.invoke(segmenter);
                if (os instanceof int[] && ((int[]) os).length >= 2) {
                    int[] arr = (int[]) os;
                    if (arr[0] > 0 && arr[1] > 0) {
                        return new int[]{arr[0], arr[1]};
                    }
                } else if (os != null) {
                    // e.g. java.awt.Dimension
                    try {
                        Method gw = os.getClass().getMethod("getWidth");
                        Method gh = os.getClass().getMethod("getHeight");
                        int[] size = toSize(gw.invoke(os), gh.invoke(os));
                        if (size != null) {
                            return size;
                        }
                    } catch (Exception notDimensionLike) {
                        // returned object has no usable width/height accessors
                    }
                }
            } catch (NoSuchMethodException noSizeGetter) {
                // accessor not present: try the next convention
            }

            try {
                java.lang.reflect.Field fw = cls.getDeclaredField("inputWidth");
                java.lang.reflect.Field fh = cls.getDeclaredField("inputHeight");
                fw.setAccessible(true);
                fh.setAccessible(true);
                int[] size = toSize(fw.get(segmenter), fh.get(segmenter));
                if (size != null) {
                    return size;
                }
            } catch (Exception noFields) {
                // fields missing or inaccessible: use the default
            }
        } catch (Exception reflectionFailure) {
            // any other reflection problem falls back to the default size
        }
        return new int[]{DEFAULT_INPUT_SIZE, DEFAULT_INPUT_SIZE};
    }

    /** Returns {@code {w, h}} when both arguments are positive numbers, otherwise null. */
    private static int[] toSize(Object w, Object h) {
        if (w instanceof Number && h instanceof Number) {
            int iw = ((Number) w).intValue();
            int ih = ((Number) h).intValue();
            if (iw > 0 && ih > 0) {
                return new int[]{iw, ih};
            }
        }
        return null;
    }

    /* ================= convenience entry point for quick manual testing ================= */

    /**
     * Command-line entry point: {@code <modelDir> <inputImage> <outDir> <targetsCommaOrAll>}.
     * Prints usage when fewer than four arguments are given.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.out.println("用法: AnimeModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
            System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face");
            System.out.println("标签: " + AnimeLabelPalette.defaultLabels());
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        Path out = Path.of(args[2]);
        String targetsArg = args[3];

        Set<String> targets;
        if ("all".equalsIgnoreCase(targetsArg.trim())) {
            targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
        } else {
            targets = new LinkedHashSet<>();
            for (String p : targetsArg.split(",")) {
                if (!p.trim().isEmpty()) {
                    targets.add(p.trim());
                }
            }
        }

        try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) {
            Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
            m.forEach((k, v) -> {
                System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
            });
        }
    }
}

View File

@@ -0,0 +1,61 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
 * Immutable holder for the outcome of one anime face segmentation run.
 *
 * <p>The arrays and maps are stored and returned by reference; no defensive copies are made.
 */
public class AnimeSegmentationResult {
    /** Mask image in which every pixel carries the colour of its class. */
    private final BufferedImage maskImage;
    /** Per-pixel class scores, indexed as {@code [y][x][classIndex]}; may be null. */
    private final float[][][] probabilityMap;
    /** Class index -> class name. */
    private final Map<Integer, String> labels;
    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /**
     * @param maskImage      colour-coded mask image
     * @param probabilityMap per-pixel class scores indexed {@code [y][x][classIndex]}, or null
     * @param labels         class index -> class name
     * @param palette        class name -> ARGB colour
     */
    public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap,
                                   Map<Integer, String> labels, Map<String, Integer> palette) {
        this.maskImage = maskImage;
        this.probabilityMap = probabilityMap;
        this.labels = labels;
        this.palette = palette;
    }

    public BufferedImage getMaskImage() {
        return maskImage;
    }

    public float[][][] getProbabilityMap() {
        return probabilityMap;
    }

    public Map<Integer, String> getLabels() {
        return labels;
    }

    public Map<String, Integer> getPalette() {
        return palette;
    }

    /**
     * Extracts the score plane of a single class.
     *
     * @param classIndex index of the class along the last dimension of the probability map
     * @return a new {@code [height][width]} array; an empty array when the map has no rows;
     *         {@code null} when no probability map is available
     * @throws ArrayIndexOutOfBoundsException if {@code classIndex} is outside the class dimension
     */
    public float[][] getClassProbability(int classIndex) {
        if (probabilityMap == null) {
            return null;
        }
        int height = probabilityMap.length;
        if (height == 0) {
            // No rows: there is no first row to take the width from.
            return new float[0][0];
        }
        // The map is treated as rectangular; the width is taken from the first row.
        int width = probabilityMap[0].length;
        float[][] result = new float[height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                result[y][x] = probabilityMap[y][x][classIndex];
            }
        }
        return result;
    }
}

View File

@@ -0,0 +1,230 @@
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
 * Segmenter for the Anime-Face-Segmentation UNet model, run through DJL's PyTorch engine.
 *
 * <p>NOTE(review): the label list passed to the constructor is stored but not used for colouring;
 * {@link #segment(File)} always maps class indices through
 * {@link AnimeLabelPalette#getIndexToLabelMap()}. Changing that would recolour masks for callers
 * that ship a custom {@code synset.txt}, so it is left as is here.
 */
public class AnimeSegmenter implements AutoCloseable {
    // Input size the model is fed with. Adjust if the model was trained at another resolution.
    private static final int MODEL_INPUT_W = 512;
    private static final int MODEL_INPUT_H = 512;

    /** Mask colour for class indices without a label/palette entry (opaque green). */
    private static final int UNKNOWN_COLOR = 0xFF00FF00;

    /** Plain Java data carried out of the translator, so no NDArray outlives the prediction. */
    public static class SegmentationData {
        final int[] indices;        // class index per pixel, row-major, length H * W
        final float[][][] probMap;  // raw per-class model output, [H][W][C]
        final long[] shape;         // shape of the class map: [H, W]

        public SegmentationData(int[] indices, float[][][] probMap, long[] shape) {
            this.indices = indices;
            this.probMap = probMap;
            this.shape = shape;
        }
    }

    /** Converts images to a [1,3,H,W] float tensor and model output to {@link SegmentationData}. */
    private static final class UNetTranslator implements Translator<Image, SegmentationData> {
        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            NDManager manager = ctx.getNDManager();
            // Skip the resize when the image already has the model size, to avoid resampling twice.
            boolean alreadySized = input.getWidth() == MODEL_INPUT_W && input.getHeight() == MODEL_INPUT_H;
            Image toUse = alreadySized ? input : input.resize(MODEL_INPUT_W, MODEL_INPUT_H, true);

            NDArray array = toUse.toNDArray(manager)
                    .transpose(2, 0, 1)              // HWC -> CHW
                    .toType(DataType.FLOAT32, false)
                    .div(255f)                       // scale to [0,1]
                    .expandDims(0);                  // add batch dimension -> [1,3,H,W]
            return new NDList(array);
        }

        @Override
        public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
            if (list == null || list.isEmpty()) {
                throw new IllegalStateException("Model did not return any output.");
            }
            NDArray output = list.get(0); // treated as [1, C, H, W]
            Shape outShape = output.getShape();
            if (outShape.dimension() != 4) {
                throw new IllegalStateException("Unexpected output shape: " + outShape);
            }

            NDArray squeezed = output.squeeze(0); // [C,H,W]
            // Winning class per pixel -> [H,W]
            NDArray classMap = squeezed.argMax(0).toType(DataType.INT32, false);
            // Per-class values as emitted by the model (no softmax is applied here) -> [H,W,C]
            NDArray probabilities = squeezed.transpose(1, 2, 0).toType(DataType.FLOAT32, false);

            long[] shape = classMap.getShape().getShape();
            int[] indices = classMap.toIntArray();

            long[] probShape = probabilities.getShape().getShape();
            int height = (int) probShape[0];
            int width = (int) probShape[1];
            int classes = (int) probShape[2];
            float[] flat = probabilities.toFloatArray();
            float[][][] probMap = new float[height][width][classes];
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    // Row-major [H][W][C]: the C values of one pixel are contiguous.
                    System.arraycopy(flat, (y * width + x) * classes, probMap[y][x], 0, classes);
                }
            }
            return new SegmentationData(indices, probMap, shape);
        }

        @Override
        public Batchifier getBatchifier() {
            // The batch dimension is added manually in processInput.
            return null;
        }
    }

    private final ZooModel<Image, SegmentationData> modelWrapper;
    private final Predictor<Image, SegmentationData> predictor;
    private final List<String> labels;
    private final Map<String, Integer> palette;

    /**
     * Loads the model and creates a predictor for it.
     *
     * @param modelDir model location passed to DJL as the model path
     * @param labels   label names (copied); see the class note about their current use
     * @throws IOException             if the model files cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException  if no model matches the criteria
     */
    public AnimeSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
        this.labels = new ArrayList<>(labels);
        this.palette = AnimeLabelPalette.defaultPalette();

        Criteria<Image, SegmentationData> criteria = Criteria.builder()
                .setTypes(Image.class, SegmentationData.class)
                .optModelPath(modelDir)
                .optEngine("PyTorch")
                .optTranslator(new UNetTranslator())
                .build();

        ZooModel<Image, SegmentationData> loaded = criteria.loadModel();
        Predictor<Image, SegmentationData> created;
        try {
            created = loaded.newPredictor();
        } catch (RuntimeException e) {
            // Do not leak the already loaded model when the predictor cannot be created.
            loaded.close();
            throw e;
        }
        this.modelWrapper = loaded;
        this.predictor = created;
    }

    /** @return width in pixels of the image fed to the model */
    public int getInputWidth() {
        return MODEL_INPUT_W;
    }

    /** @return height in pixels of the image fed to the model */
    public int getInputHeight() {
        return MODEL_INPUT_H;
    }

    /**
     * Runs the model on an image file and renders the class map as a colour mask.
     *
     * @param imgFile image to segment
     * @return mask image (size of the model's class map), raw per-class output, labels and palette
     * @throws TranslateException if inference fails
     * @throws IOException        if the image cannot be loaded
     */
    public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException {
        Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
        SegmentationData data = predictor.predict(img);

        int height = (int) data.shape[0];
        int width = (int) data.shape[1];
        int[] indices = data.indices;

        // Resolve index -> colour once instead of two map lookups per pixel.
        Map<Integer, String> labelsMap = AnimeLabelPalette.getIndexToLabelMap();
        int unknownColor = palette.getOrDefault("unknown", UNKNOWN_COLOR);
        int[] colorByIndex = new int[labelsMap.size()];
        for (int i = 0; i < colorByIndex.length; i++) {
            colorByIndex[i] = palette.getOrDefault(labelsMap.getOrDefault(i, "unknown"), UNKNOWN_COLOR);
        }

        BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int idx = indices[y * width + x];
                boolean known = idx >= 0 && idx < colorByIndex.length;
                mask.setRGB(x, y, known ? colorByIndex[idx] : unknownColor);
            }
        }
        return new AnimeSegmentationResult(mask, data.probMap, labelsMap, palette);
    }

    /**
     * Segments the image and keeps only the pixels classified as "eye".
     *
     * @param imgFile image to segment
     * @return ARGB image of the mask size: eye pixels in the eye colour, all others transparent
     * @throws TranslateException if inference fails
     * @throws IOException        if the image cannot be loaded
     */
    public BufferedImage extractEyes(File imgFile) throws TranslateException, IOException {
        BufferedImage mask = segment(imgFile).getMaskImage();
        int width = mask.getWidth();
        int height = mask.getHeight();
        int eyeColor = palette.get("eye");

        BufferedImage eyeMask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                eyeMask.setRGB(x, y, mask.getRGB(x, y) == eyeColor ? eyeColor : 0x00000000);
            }
        }
        return eyeMask;
    }

    /** Closes predictor and model; failures of either close are ignored (best effort). */
    @Override
    public void close() {
        try {
            predictor.close();
        } catch (Exception ignored) {
            // keep going so the model is still closed
        }
        try {
            modelWrapper.close();
        } catch (Exception ignored) {
            // nothing useful can be done if closing fails
        }
    }

    /**
     * Manual test entry point: {@code <modelDir> <inputImage> <outputMaskPng>}. Writes the full
     * mask to the given file and the eye mask next to it with an {@code eyes_} prefix.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.out.println("用法: java AnimeSegmenter <modelDir> <inputImage> <outputMaskPng>");
            System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        File out = new File(args[2]);
        List<String> labels = AnimeLabelPalette.defaultLabels();

        try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) {
            AnimeSegmentationResult res = segmenter.segment(input);
            ImageIO.write(res.getMaskImage(), "png", out);
            System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());

            BufferedImage eyes = segmenter.extractEyes(input);
            File eyesOut = new File(out.getParent(), "eyes_" + out.getName());
            ImageIO.write(eyes, "png", eyesOut);
            System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath());
        }
    }
}

View File

@@ -0,0 +1,46 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import java.util.*;
/**
* 动漫分割模型的标签和颜色调色板。
* 这是一个二分类模型:背景和前景(动漫人物)
*/
public class Anime2LabelPalette {
/**
* 动漫分割模型的标准标签2个类别
*/
public static List<String> defaultLabels() {
return Arrays.asList(
"background", // 0
"foreground" // 1
);
}
/**
* 返回动漫分割模型的调色板
*/
public static Map<String, Integer> defaultPalette() {
Map<String, Integer> map = new HashMap<>();
// 索引 0: background - 黑色
map.put("background", 0xFF000000);
// 索引 1: foreground - 白色
map.put("foreground", 0xFFFFFFFF);
return map;
}
/**
* 专门为动漫分割模型设计的调色板(可视化更友好)
*/
public static Map<String, Integer> animeSegmentationPalette() {
Map<String, Integer> map = new HashMap<>();
// 背景 - 透明
map.put("background", 0x00000000);
// 前景 - 红色(用于可视化)
map.put("foreground", 0xFFFF0000);
return map;
}
}

View File

@@ -0,0 +1,244 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
 * Wrapper around {@link Anime2Segmenter} for the two-class anime segmentation model.
 *
 * <p>Usage example:
 * <pre>
 * try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir"))) {
 *     Map&lt;String, Anime2ModelWrapper.ResultFiles&gt; out = wrapper.segmentAndSave(
 *             new File("input.jpg"),
 *             Set.of("foreground"),   // the model is mainly about the foreground
 *             Paths.get("outDir"));
 * }
 * </pre>
 */
public class Anime2ModelWrapper implements AutoCloseable {

    /** Alpha value (0x88) of the coloured layer painted over the original image in overlays. */
    private static final int OVERLAY_ALPHA = 0x88;

    private final Anime2Segmenter segmenter;
    private final List<String> labels;          // index -> name
    private final Map<String, Integer> palette; // name -> ARGB

    private Anime2ModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
        this.segmenter = segmenter;
        this.labels = labels;
        this.palette = palette;
    }

    /**
     * Loads the model found at {@code modelDir}.
     *
     * @param modelDir model location handed to {@link Anime2Segmenter}; a {@code synset.txt}
     *                 inside it, if present and non-empty, supplies the label list
     * @return a wrapper that owns the created segmenter and must be closed by the caller
     * @throws Exception if the segmenter cannot be created
     */
    public static Anime2ModelWrapper load(Path modelDir) throws Exception {
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
        Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
        Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
        return new Anime2ModelWrapper(s, labels, palette);
    }

    /** @return a read-only view of the label list (index -> name) */
    public List<String> getLabels() {
        return Collections.unmodifiableList(labels);
    }

    /** @return a read-only view of the palette (name -> ARGB) */
    public Map<String, Integer> getPalette() {
        return Collections.unmodifiableMap(palette);
    }

    /**
     * Runs the segmenter on the image without any additional processing.
     *
     * @param inputImage image file to segment
     * @return the segmenter's result
     * @throws Exception if inference fails
     */
    public Anime2SegmentationResult segment(File inputImage) throws Exception {
        return segmenter.segment(inputImage);
    }

    /**
     * Segments {@code inputImage} and writes, for every requested label, a
     * {@code <label>_mask.png} and a {@code <label>_overlay.png} into {@code outDir}, both at the
     * size of the original image.
     *
     * @param inputImage image file to segment
     * @param targets    palette label names to export; a set holding only {@code "all"} selects
     *                   {@code "foreground"}; null or empty exports nothing
     * @param outDir     output directory, created if missing
     * @return exported files keyed by label, in request order; unknown labels are skipped with a
     *         warning on stderr
     * @throws Exception if the image cannot be read, inference fails, or a file cannot be written
     */
    public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
        if (!Files.exists(outDir)) {
            Files.createDirectories(outDir);
        }
        Anime2SegmentationResult res = segment(inputImage);
        BufferedImage original = ImageIO.read(inputImage);
        if (original == null) {
            // ImageIO.read returns null when no registered reader can decode the file.
            throw new IOException("Cannot read image: " + inputImage);
        }
        BufferedImage maskImage = res.getMaskImage();

        Map<String, ResultFiles> saved = new LinkedHashMap<>();
        for (String target : parseTargetsSet(targets)) {
            if (!palette.containsKey(target)) {
                System.err.println("Warning: unknown label '" + target + "' - skip.");
                continue;
            }
            int targetColor = palette.get(target);

            // 1) keep only the pixels of this label, everything else transparent
            BufferedImage partMask = isolateColor(maskImage, targetColor);
            // 2) scale the mask to the original image size
            BufferedImage maskResized = scaleTo(partMask, original.getWidth(), original.getHeight());
            // 3) original image with a translucent colour layer on the label area
            BufferedImage overlay = renderOverlay(original, maskResized, targetColor);

            String safe = safeFileName(target);
            File maskOut = outDir.resolve(safe + "_mask.png").toFile();
            File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
            ImageIO.write(maskResized, "png", maskOut);
            ImageIO.write(overlay, "png", overlayOut);
            saved.put(target, new ResultFiles(maskOut, overlayOut));
        }
        return saved;
    }

    /** Copies pixels equal to {@code color} as opaque pixels; all others become transparent. */
    private static BufferedImage isolateColor(BufferedImage mask, int color) {
        int width = mask.getWidth();
        int height = mask.getHeight();
        int opaque = color | 0xFF000000;
        BufferedImage out = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                out.setRGB(x, y, mask.getRGB(x, y) == color ? opaque : 0x00000000);
            }
        }
        return out;
    }

    /** Returns {@code src} unchanged when it already has the given size, else a bilinear rescale. */
    private static BufferedImage scaleTo(BufferedImage src, int width, int height) {
        if (src.getWidth() == width && src.getHeight() == height) {
            return src;
        }
        BufferedImage out = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = out.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, width, height, null);
        g.dispose();
        return out;
    }

    /**
     * Draws {@code original} and, on top of it, a translucent layer of {@code color} wherever
     * {@code mask} has a non-transparent pixel whose RGB equals the RGB of {@code color}.
     * {@code mask} must have the same size as {@code original}.
     */
    private static BufferedImage renderOverlay(BufferedImage original, BufferedImage mask, int color) {
        int width = original.getWidth();
        int height = original.getHeight();
        int rgbOnly = color & 0x00FFFFFF;
        int translucent = (OVERLAY_ALPHA << 24) | rgbOnly;

        BufferedImage tint = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int mc = mask.getRGB(x, y);
                boolean hit = (mc & 0x00FFFFFF) == rgbOnly && (mc >>> 24) != 0;
                tint.setRGB(x, y, hit ? translucent : 0x00000000);
            }
        }

        BufferedImage overlay = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = overlay.createGraphics();
        g.drawImage(original, 0, 0, null);
        g.drawImage(tint, 0, 0, null);
        g.dispose();
        return overlay;
    }

    /** Replaces every character outside {@code [a-zA-Z0-9_.-]} with an underscore. */
    private static String safeFileName(String s) {
        return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
    }

    /**
     * Normalises the requested targets: null/empty yields an empty set, a single {@code "all"}
     * yields just {@code "foreground"}, otherwise the trimmed non-null entries in iteration order.
     */
    private static Set<String> parseTargetsSet(Set<String> in) {
        if (in == null || in.isEmpty()) {
            return Collections.emptySet();
        }
        if (in.size() == 1) {
            String only = in.iterator().next();
            if (only != null && "all".equalsIgnoreCase(only.trim())) {
                return Set.of("foreground"); // the model is mainly about the foreground
            }
        }
        Set<String> out = new LinkedHashSet<>();
        for (String s : in) {
            if (s != null) {
                out.add(s.trim());
            }
        }
        return out;
    }

    /** Closes the underlying segmenter; failures while closing are ignored (best effort). */
    @Override
    public void close() {
        try {
            segmenter.close();
        } catch (Exception ignored) {
            // best-effort shutdown: nothing useful can be done if closing fails
        }
    }

    /** Pair of files written for one label: the mask PNG and the overlay PNG. */
    public static class ResultFiles {
        private final File maskFile;
        private final File overlayFile;

        public ResultFiles(File maskFile, File overlayFile) {
            this.maskFile = maskFile;
            this.overlayFile = overlayFile;
        }

        public File getMaskFile() {
            return maskFile;
        }

        public File getOverlayFile() {
            return overlayFile;
        }
    }

    /**
     * Reads the non-blank, trimmed lines of {@code modelDir/synset.txt}.
     *
     * @return the labels, or empty when the file is missing, unreadable, or has no content
     */
    private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
        Path syn = modelDir.resolve("synset.txt");
        if (!Files.exists(syn)) {
            return Optional.empty();
        }
        try {
            List<String> cleaned = new ArrayList<>();
            for (String line : Files.readAllLines(syn)) {
                String s = line.trim();
                if (!s.isEmpty()) {
                    cleaned.add(s);
                }
            }
            if (!cleaned.isEmpty()) {
                return Optional.of(cleaned);
            }
        } catch (IOException e) {
            // Falling back to the default labels is intended, but say so instead of hiding it.
            System.err.println("Warning: cannot read " + syn + " (" + e.getMessage() + ") - using default labels.");
        }
        return Optional.empty();
    }

    /**
     * Command-line entry point: {@code <modelDir> <inputImage> <outDir> <targetsCommaOrAll>}.
     * Here {@code all} expands to every label from synset.txt (or the defaults). Prints usage when
     * fewer than four arguments are given.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.out.println("用法: Anime2ModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
            System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        Path out = Path.of(args[2]);
        String targetsArg = args[3];
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);

        Set<String> targets;
        if ("all".equalsIgnoreCase(targetsArg.trim())) {
            targets = new LinkedHashSet<>(labels);
        } else {
            targets = new LinkedHashSet<>();
            for (String p : targetsArg.split(",")) {
                if (!p.trim().isEmpty()) {
                    targets.add(p.trim());
                }
            }
        }

        try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) {
            Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
            m.forEach((k, v) -> {
                System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
            });
        }
    }
}

View File

@@ -0,0 +1,36 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
 * Immutable holder for the outcome of one run of the two-class anime segmentation model.
 *
 * <p>All three values are stored and handed back by reference; no copies are made.
 */
public class Anime2SegmentationResult {

    /** Mask image in which every pixel carries the colour of its class. */
    private final BufferedImage maskImage;

    /** Class index -> class name. */
    private final Map<Integer, String> labels;

    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /**
     * @param maskImage colour-coded mask image
     * @param labels    class index -> class name
     * @param palette   class name -> ARGB colour
     */
    public Anime2SegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
        this.palette = palette;
        this.labels = labels;
        this.maskImage = maskImage;
    }

    /** @return the colour-coded mask image passed to the constructor */
    public BufferedImage getMaskImage() {
        return this.maskImage;
    }

    /** @return the index-to-name map passed to the constructor */
    public Map<Integer, String> getLabels() {
        return this.labels;
    }

    /** @return the name-to-colour map passed to the constructor */
    public Map<String, Integer> getPalette() {
        return this.palette;
    }
}

View File

@@ -0,0 +1,175 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
 * Anime2Segmenter: runs the anime-segmentation model through DJL (PyTorch engine) and renders
 * its single-channel output as a two-class colour mask.
 */
public class Anime2Segmenter implements AutoCloseable {

    /**
     * Plain-Java copy of the model output, produced inside the translator so that nothing handed
     * to callers references an {@link NDArray}.
     */
    public static class SegmentationData {
        /** Class index per pixel, addressed as {@code y * width + x} by {@link #segment(File)}. */
        final int[] indices;
        /** Shape reported by the class-map NDArray the indices were copied from. */
        final long[] shape;

        public SegmentationData(int[] indices, long[] shape) {
            this.shape = shape;
            this.indices = indices;
        }
    }

    private final ZooModel<Image, SegmentationData> modelWrapper;
    private final Predictor<Image, SegmentationData> predictor;
    private final List<String> labels;
    private final Map<String, Integer> palette;

    /**
     * Loads the model found in {@code modelDir} and prepares a predictor for it.
     *
     * @param modelDir directory handed to DJL as the model path
     * @param labels   class names by index; copied, so later changes to the argument are not seen
     * @throws IOException             propagated from DJL model loading
     * @throws MalformedModelException propagated from DJL model loading
     * @throws ModelNotFoundException  propagated from DJL model loading
     */
    public Anime2Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
        this.labels = new ArrayList<>(labels);
        this.palette = Anime2LabelPalette.animeSegmentationPalette();

        Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                // Every input is resized to a fixed 1024x1024 before being fed to the model.
                NDArray pixels = input.resize(1024, 1024, true).toNDArray(ctx.getNDManager());
                // Move the channel axis first, convert to float and scale by 1/255.
                NDArray channelFirst = pixels.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
                NDArray scaled = channelFirst.div(255f);
                // Prepend the batch axis.
                return new NDList(scaled.expandDims(0));
            }

            @Override
            public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
                if (list == null || list.isEmpty()) {
                    throw new IllegalStateException("Model did not return any output.");
                }
                NDArray raw = list.get(0);
                // NOTE(review): the original comment calls this "sigmoid", but the expression is
                // x / (1 + e^-x), not 1 / (1 + e^-x). Whether the model emits logits or
                // probabilities is not visible here - confirm before changing the threshold.
                NDArray denominator = raw.neg().exp().add(1);
                NDArray scores = raw.div(denominator);
                NDArray mask = scores.gt(0.5).toType(DataType.INT32, false);
                // A 4-D result is reduced by dropping its two leading axes.
                if (mask.getShape().dimension() == 4) {
                    mask = mask.squeeze(0).squeeze(0);
                }
                // Copy into Java-owned values while the NDArray is still usable.
                long[] maskShape = mask.getShape().getShape();
                int[] flat = mask.toIntArray();
                return new SegmentationData(flat, maskShape);
            }

            @Override
            public Batchifier getBatchifier() {
                return null;
            }
        };

        Criteria<Image, SegmentationData> criteria = Criteria.builder()
                .setTypes(Image.class, SegmentationData.class)
                .optModelPath(modelDir)
                .optEngine("PyTorch")
                .optTranslator(translator)
                .build();
        this.modelWrapper = criteria.loadModel();
        this.predictor = this.modelWrapper.newPredictor();
    }

    /**
     * Segments one image file.
     *
     * @param imgFile image to segment
     * @return mask image (at the resolution of the model output), the index -> name map and the
     *         palette used to colour the mask
     * @throws TranslateException if inference fails
     * @throws IOException        if the image cannot be loaded
     * @throws RuntimeException   if the translator did not deliver a 2-D class map
     */
    public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException {
        Image source = ImageFactory.getInstance().fromFile(imgFile.toPath());
        SegmentationData prediction = predictor.predict(source);

        long[] dims = prediction.shape;
        if (dims.length != 2) {
            throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(dims));
        }
        int height = (int) dims[0];
        int width = (int) dims[1];

        BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);

        Map<Integer, String> labelsMap = new HashMap<>();
        int position = 0;
        for (String name : labels) {
            labelsMap.put(position++, name);
        }

        int[] classes = prediction.indices;
        for (int row = 0; row < height; row++) {
            int rowStart = row * width;
            for (int col = 0; col < width; col++) {
                String label = labelsMap.getOrDefault(classes[rowStart + col], "unknown");
                // Labels missing from the palette are painted opaque red.
                mask.setRGB(col, row, palette.getOrDefault(label, 0xFFFF0000));
            }
        }
        return new Anime2SegmentationResult(mask, labelsMap, palette);
    }

    /** Closes the predictor and then the model; failures of either are ignored. */
    @Override
    public void close() {
        try {
            predictor.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
        try {
            modelWrapper.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
    }

    /**
     * Manual test entry point.
     *
     * @param args {@code <modelDir> <inputImage> <outputMaskPng>}
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.out.println("用法: java Anime2Segmenter <modelDir> <inputImage> <outputMaskPng>");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        File out = new File(args[2]);
        List<String> defaultLabels = Anime2LabelPalette.defaultLabels();
        try (Anime2Segmenter segmenter = new Anime2Segmenter(modelDir, defaultLabels)) {
            Anime2SegmentationResult result = segmenter.segment(input);
            ImageIO.write(result.getMaskImage(), "png", out);
            System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
        }
    }
}

View File

@@ -0,0 +1,262 @@
package com.chuangzhou.vivid2D.ai.anime_segmentation;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
 * Anime2VividModelWrapper - convenience facade over {@link Anime2Segmenter}.
 *
 * <p>Usage:
 * <pre>{@code
 * try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("/path/to/modelDir"))) {
 *     Map<String, Anime2VividModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
 *             new File("input.jpg"),
 *             Set.of("foreground"),
 *             Paths.get("outDir"));
 *     // out maps every exported label to its mask and overlay files
 * }
 * }</pre>
 */
public class Anime2VividModelWrapper implements AutoCloseable {
    private final Anime2Segmenter segmenter;
    private final List<String> labels; // index -> name
    private final Map<String, Integer> palette; // name -> ARGB

    private Anime2VividModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
        this.segmenter = segmenter;
        this.labels = labels;
        this.palette = palette;
    }

    /**
     * Creates a wrapper for the model in {@code modelDir}.
     *
     * <p>Labels are read from {@code modelDir/synset.txt} (one label per line); when that file is
     * missing, empty or unreadable, {@code Anime2LabelPalette.defaultLabels()} is used instead.
     *
     * @param modelDir directory containing the model
     * @return a ready-to-use wrapper; the caller must close it
     * @throws Exception if the underlying segmenter cannot be created
     */
    public static Anime2VividModelWrapper load(Path modelDir) throws Exception {
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
        Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
        Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
        return new Anime2VividModelWrapper(s, labels, palette);
    }

    /** @return read-only view of the label list (index -> name) */
    public List<String> getLabels() {
        return Collections.unmodifiableList(labels);
    }

    /** @return read-only view of the palette (name -> ARGB) */
    public Map<String, Integer> getPalette() {
        return Collections.unmodifiableMap(palette);
    }

    /**
     * Runs segmentation and returns the raw result.
     *
     * @param inputImage image file to segment
     */
    public Anime2SegmentationResult segment(File inputImage) throws Exception {
        return segmenter.segment(inputImage);
    }

    /**
     * Segments {@code inputImage} and writes two PNG files per requested label into
     * {@code outDir}: {@code <label>_mask.png} (label pixels on a transparent background,
     * scaled to the input size) and {@code <label>_overlay.png} (the input with the label area
     * tinted half-transparently).
     *
     * <p>If {@code targets} consists of the single element {@code "all"} (case-insensitive),
     * every label of this wrapper is exported. Names are matched against the palette exactly
     * first, then case-insensitively; names that match nothing are reported on stderr and
     * skipped.
     *
     * @param inputImage image file to segment
     * @param targets    label names to export, or the single element {@code "all"}
     * @param outDir     output directory; created when missing
     * @return exported label name -> written files, in export order
     * @throws IOException if the input file cannot be decoded as an image
     * @throws Exception   if segmentation or file writing fails
     */
    public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
        if (!Files.exists(outDir)) {
            Files.createDirectories(outDir);
        }
        Anime2SegmentationResult res = segment(inputImage);
        BufferedImage original = ImageIO.read(inputImage);
        if (original == null) {
            // ImageIO.read signals "no suitable reader" with null instead of an exception.
            throw new IOException("Unsupported or unreadable image: " + inputImage);
        }
        BufferedImage maskImage = res.getMaskImage();

        Map<String, ResultFiles> saved = new LinkedHashMap<>();
        for (String requested : resolveTargets(targets)) {
            Optional<String> key = resolvePaletteKey(requested);
            if (!key.isPresent()) {
                System.err.println("Warning: unknown label '" + requested + "' - skip.");
                continue;
            }
            String target = key.get();
            int targetColor = palette.get(target);

            BufferedImage partMask = extractLabelMask(maskImage, targetColor);
            BufferedImage maskResized = scaleTo(partMask, original.getWidth(), original.getHeight());
            BufferedImage overlay = renderOverlay(original, maskResized, targetColor);

            String safe = safeFileName(target);
            File maskOut = outDir.resolve(safe + "_mask.png").toFile();
            File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
            ImageIO.write(maskResized, "png", maskOut);
            ImageIO.write(overlay, "png", overlayOut);
            saved.put(target, new ResultFiles(maskOut, overlayOut));
        }
        return saved;
    }

    /** Expands the "all" sentinel to every label of this wrapper; other sets pass through. */
    private Set<String> resolveTargets(Set<String> targets) {
        Set<String> parsed = parseTargetsSet(targets);
        if (parsed.size() == 1 && "all".equalsIgnoreCase(parsed.iterator().next())) {
            return new LinkedHashSet<>(labels);
        }
        return parsed;
    }

    /** Finds the palette key for a requested name: exact match first, then ignoring case. */
    private Optional<String> resolvePaletteKey(String requested) {
        if (palette.containsKey(requested)) {
            return Optional.of(requested);
        }
        return palette.keySet().stream()
                .filter(k -> k.equalsIgnoreCase(requested))
                .findFirst();
    }

    /** Copies only the pixels equal to {@code targetColor}; everything else stays transparent. */
    private static BufferedImage extractLabelMask(BufferedImage classMask, int targetColor) {
        int w = classMask.getWidth();
        int h = classMask.getHeight();
        // A fresh TYPE_INT_ARGB image is fully transparent, so only hits need to be written.
        BufferedImage part = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                if (classMask.getRGB(x, y) == targetColor) {
                    part.setRGB(x, y, targetColor | 0xFF000000); // force opaque
                }
            }
        }
        return part;
    }

    /** Returns {@code src} itself when it already has the requested size, else a bilinear-scaled copy. */
    private static BufferedImage scaleTo(BufferedImage src, int width, int height) {
        if (src.getWidth() == width && src.getHeight() == height) {
            return src;
        }
        BufferedImage scaled = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = scaled.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(src, 0, 0, width, height, null);
        } finally {
            g.dispose();
        }
        return scaled;
    }

    /** Draws {@code original} and tints (alpha 0x88) every pixel the mask marks with the target colour. */
    private static BufferedImage renderOverlay(BufferedImage original, BufferedImage mask, int targetColor) {
        int width = original.getWidth();
        int height = original.getHeight();
        int rgbOnly = targetColor & 0x00FFFFFF;
        int translucent = (0x88 << 24) | rgbOnly;

        BufferedImage tint = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int mc = mask.getRGB(x, y);
                // Same RGB as the target and not fully transparent.
                if ((mc & 0x00FFFFFF) == rgbOnly && (mc >>> 24) != 0) {
                    tint.setRGB(x, y, translucent);
                }
            }
        }

        BufferedImage overlay = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g2 = overlay.createGraphics();
        try {
            g2.drawImage(original, 0, 0, null);
            g2.drawImage(tint, 0, 0, null);
        } finally {
            g2.dispose();
        }
        return overlay;
    }

    /** Replaces every character outside {@code [a-zA-Z0-9_.-]} with an underscore. */
    private static String safeFileName(String s) {
        return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
    }

    /**
     * Normalises the caller-supplied target set: null/empty input gives an empty set, a single
     * "all" (case-insensitive) gives the sentinel set {@code {"all"}}, anything else gives the
     * trimmed names with their case preserved. Null elements are ignored.
     */
    private static Set<String> parseTargetsSet(Set<String> in) {
        if (in == null || in.isEmpty()) {
            return Collections.emptySet();
        }
        if (in.size() == 1) {
            String only = in.iterator().next();
            if (only != null && "all".equalsIgnoreCase(only.trim())) {
                return Set.of("all");
            }
        }
        Set<String> out = new LinkedHashSet<>();
        for (String s : in) {
            if (s != null) {
                out.add(s.trim());
            }
        }
        return out;
    }

    /** Releases the underlying segmenter; failures are ignored. */
    @Override
    public void close() {
        try {
            segmenter.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
    }

    /** Paths of the two files written for one label. */
    public static class ResultFiles {
        private final File maskFile;
        private final File overlayFile;

        public ResultFiles(File maskFile, File overlayFile) {
            this.maskFile = maskFile;
            this.overlayFile = overlayFile;
        }

        public File getMaskFile() {
            return maskFile;
        }

        public File getOverlayFile() {
            return overlayFile;
        }
    }

    /* ================= helper: read synset.txt from modelDir ================= */

    /**
     * Reads {@code modelDir/synset.txt}, dropping blank lines and trimming the rest.
     *
     * @return the labels, or empty when the file is missing, unreadable or has no usable line
     */
    private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
        Path syn = modelDir.resolve("synset.txt");
        if (Files.exists(syn)) {
            try {
                List<String> cleaned = new ArrayList<>();
                for (String l : Files.readAllLines(syn)) {
                    String s = l.trim();
                    if (!s.isEmpty()) {
                        cleaned.add(s);
                    }
                }
                if (!cleaned.isEmpty()) {
                    return Optional.of(cleaned);
                }
            } catch (IOException e) {
                // Still best-effort, but tell the user that the default labels are being used.
                System.err.println("Warning: cannot read " + syn + " (" + e + ") - using default labels.");
            }
        }
        return Optional.empty();
    }

    /* ================= convenience entry point (quick manual test) ================= */

    /**
     * @param args {@code <modelDir> <inputImage> <outDir> <targetsCommaOrAll>}
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.out.println("用法: Anime2VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
            System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
            System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        Path out = Path.of(args[2]);
        String targetsArg = args[3];
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
        Set<String> targets;
        if ("all".equalsIgnoreCase(targetsArg.trim())) {
            targets = new LinkedHashSet<>(labels);
        } else {
            targets = new LinkedHashSet<>();
            for (String p : targetsArg.split(",")) {
                if (!p.trim().isEmpty()) {
                    targets.add(p.trim());
                }
            }
        }
        try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) {
            Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
            m.forEach((k, v) -> System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())));
        }
    }
}

View File

@@ -0,0 +1,89 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import java.util.*;
/**
 * Default label names and colours used for the BiSeNet face-parsing model.
 *
 * <p>The position of a name in {@link #defaultLabels()} is the class index it is used for
 * (0-18). NOTE(review): the original comments state that the order matches the model output and
 * that the colours come from the {@code part_colors} array in {@code test.py} of
 * zllrunning/face-parsing.PyTorch - neither can be verified from this file; confirm against the
 * exported model.
 */
public class LabelPalette {

    /**
     * Returns the 19 default label names, ordered by class index.
     *
     * @return a new fixed-size list on every call
     */
    public static List<String> defaultLabels() {
        String[] byIndex = {
                "background",    // 0
                "skin",          // 1
                "nose",          // 2
                "eye_left",      // 3
                "eye_right",     // 4
                "eyebrow_left",  // 5
                "eyebrow_right", // 6
                "ear_left",      // 7
                "ear_right",     // 8
                "mouth",         // 9
                "lip_upper",     // 10
                "lip_lower",     // 11
                "hair",          // 12
                "hat",           // 13
                "earring",       // 14
                "necklace",      // 15
                "clothes",       // 16
                "facial_hair",   // 17
                "neck"           // 18
        };
        return Arrays.asList(byIndex);
    }

    /**
     * Returns the colour of every default label as an ARGB int ({@code 0xFFRRGGBB}).
     *
     * @return a new mutable map (label name -> ARGB) on every call
     */
    public static Map<String, Integer> defaultPalette() {
        // names[i] is painted with argb[i]; the two arrays are kept in class-index order.
        String[] names = {
                "background", "skin", "nose", "eye_left", "eye_right",
                "eyebrow_left", "eyebrow_right", "ear_left", "ear_right", "mouth",
                "lip_upper", "lip_lower", "hair", "hat", "earring",
                "necklace", "clothes", "facial_hair", "neck"
        };
        int[] argb = {
                0xFF000000, // background    black
                0xFFFF0000, // skin          [255,   0,   0]
                0xFFFF5500, // nose          [255,  85,   0]
                0xFFFFAA00, // eye_left      [255, 170,   0]
                0xFFFF0055, // eye_right     [255,   0,  85]
                0xFFFF00AA, // eyebrow_left  [255,   0, 170]
                0xFF00FF00, // eyebrow_right [  0, 255,   0]
                0xFF55FF00, // ear_left      [ 85, 255,   0]
                0xFFAAFF00, // ear_right     [170, 255,   0]
                0xFF00FF55, // mouth         [  0, 255,  85]
                0xFF00FFAA, // lip_upper     [  0, 255, 170]
                0xFF0000FF, // lip_lower     [  0,   0, 255]
                0xFF5500FF, // hair          [ 85,   0, 255]
                0xFFAA00FF, // hat           [170,   0, 255]
                0xFF0055FF, // earring       [  0,  85, 255]
                0xFF00AAFF, // necklace      [  0, 170, 255]
                0xFFFFFF00, // clothes       [255, 255,   0]
                0xFFFF5555, // facial_hair   [255,  85,  85]
                0xFFFFAAAA  // neck          [255, 170, 170]
        };
        Map<String, Integer> colours = new HashMap<>();
        for (int i = 0; i < names.length; i++) {
            colours.put(names[i], argb[i]);
        }
        return colours;
    }
}

View File

@@ -0,0 +1,36 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import java.awt.image.BufferedImage;
import java.util.Map;
/**
 * Immutable holder for the outcome of one segmentation run.
 *
 * <p>The references passed to the constructor are stored and returned as-is (no defensive
 * copies are made).
 */
public class SegmentationResult {

    /** Class index -> class name. */
    private final Map<Integer, String> labels;

    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /** Segmentation mask; each pixel carries the colour of its class. */
    private final BufferedImage maskImage;

    /**
     * @param maskImage colour-coded segmentation mask
     * @param labels    class index -> class name
     * @param palette   class name -> ARGB colour
     */
    public SegmentationResult(BufferedImage maskImage, Map<Integer, String> labels, Map<String, Integer> palette) {
        this.palette = palette;
        this.labels = labels;
        this.maskImage = maskImage;
    }

    /** @return class name -> ARGB colour, exactly the map given to the constructor */
    public Map<String, Integer> getPalette() {
        return this.palette;
    }

    /** @return class index -> class name, exactly the map given to the constructor */
    public Map<Integer, String> getLabels() {
        return this.labels;
    }

    /** @return the colour-coded mask image given to the constructor */
    public BufferedImage getMaskImage() {
        return this.maskImage;
    }
}

View File

@@ -0,0 +1,193 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
 * Segmenter: loads a semantic-segmentation model through DJL (PyTorch engine) and turns an
 * image into a per-pixel colour mask.
 *
 * <p>The translator reduces the model output to an (H, W) class-index map and copies it into a
 * plain {@code int[]} ({@link SegmentationData}) before returning, so callers never hold an
 * {@link NDArray} after the prediction call has finished.
 */
public class Segmenter implements AutoCloseable {

    /**
     * Colour painted for pixels whose class has no palette entry. Fully transparent, so it can
     * never equal a palette colour; the previous opaque green (0xFF00FF00) was identical to the
     * colour of "eyebrow_right" and made unknown pixels indistinguishable from that class.
     */
    private static final int UNKNOWN_COLOR = 0x00000000;

    /** Plain-Java carrier used to hand the class map out of the translator. */
    public static class SegmentationData {
        /** Class index per pixel, addressed as {@code y * width + x}. */
        final int[] indices;
        /** Shape of the class map the indices were copied from. */
        final long[] shape;

        public SegmentationData(int[] indices, long[] shape) {
            this.indices = indices;
            this.shape = shape;
        }
    }

    private final ZooModel<Image, SegmentationData> modelWrapper;
    private final Predictor<Image, SegmentationData> predictor;
    private final List<String> labels;
    private final Map<String, Integer> palette;

    /**
     * Loads the model found in {@code modelDir} and prepares a predictor for it.
     *
     * @param modelDir directory handed to DJL as the model path
     * @param labels   class names by index; copied, so later changes to the argument are not seen
     * @throws IOException             propagated from DJL model loading
     * @throws MalformedModelException propagated from DJL model loading
     * @throws ModelNotFoundException  propagated from DJL model loading
     */
    public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
        this.labels = new ArrayList<>(labels);
        this.palette = LabelPalette.defaultPalette();

        Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                NDManager manager = ctx.getNDManager();
                NDArray array = input.toNDArray(manager);
                // Channel axis first, float, scaled by 1/255, plus a leading batch axis.
                array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
                array = array.div(255f);
                array = array.expandDims(0);
                return new NDList(array);
            }

            @Override
            public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
                if (list == null || list.isEmpty()) {
                    throw new IllegalStateException("Model did not return any output.");
                }
                NDArray out = list.get(0);
                NDArray classMap;
                // 1. Reduce the output to a class map, depending on its rank.
                long[] shape = out.getShape().getShape();
                if (shape.length == 4 && shape[1] > 1) {
                    // 4-D with several channels: pick the strongest channel per pixel.
                    classMap = out.argMax(1);
                } else if (shape.length == 3) {
                    // 3-D: a leading axis of size 1 is taken as an already-reduced map,
                    // otherwise the leading axis is treated as the channel axis.
                    classMap = (shape[0] == 1) ? out : out.argMax(0);
                } else if (shape.length == 2) {
                    classMap = out;
                } else {
                    throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape));
                }
                if (classMap.getShape().dimension() == 3) {
                    classMap = classMap.squeeze(0);
                }
                // 2. Convert to Java-owned values while the NDArray is still usable.
                NDArray int32ClassMap = classMap.toType(DataType.INT32, false);
                long[] finalShape = int32ClassMap.getShape().getShape();
                int[] indices = int32ClassMap.toIntArray();
                // 3. Hand the plain-Java copy back to the caller.
                return new SegmentationData(indices, finalShape);
            }

            @Override
            public Batchifier getBatchifier() {
                return null; // no batching
            }
        };

        Criteria<Image, SegmentationData> criteria = Criteria.builder()
                .setTypes(Image.class, SegmentationData.class)
                .optModelPath(modelDir)
                .optEngine("PyTorch")
                .optTranslator(translator)
                .build();
        this.modelWrapper = criteria.loadModel();
        this.predictor = modelWrapper.newPredictor();
    }

    /**
     * Segments one image file.
     *
     * <p>Pixels whose class index has no label, or whose label has no palette entry, are left
     * fully transparent in the mask.
     *
     * @param imgFile image to segment
     * @return mask image (at the resolution of the model output), the index -> name map and the
     *         palette used to colour the mask
     * @throws TranslateException    if inference fails
     * @throws IOException           if the image cannot be loaded
     * @throws RuntimeException      if the translator did not deliver a 2-D class map
     * @throws IllegalStateException if the class map holds fewer or more entries than its shape
     */
    public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
        Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
        SegmentationData data = predictor.predict(img);
        long[] shp = data.shape;
        int[] indices = data.indices;
        if (shp.length != 2) {
            throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
        }
        int height = (int) shp[0];
        int width = (int) shp[1];
        if (indices.length != (long) height * width) {
            throw new IllegalStateException("Class map holds " + indices.length
                    + " entries but its shape is " + Arrays.toString(shp));
        }

        Map<Integer, String> labelsMap = new HashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            labelsMap.put(i, labels.get(i));
        }

        BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int idx = indices[y * width + x];
                String label = labelsMap.getOrDefault(idx, "unknown");
                mask.setRGB(x, y, palette.getOrDefault(label, UNKNOWN_COLOR));
            }
        }
        return new SegmentationResult(mask, labelsMap, palette);
    }

    /** Closes the predictor and then the model; failures of either are ignored. */
    @Override
    public void close() {
        try {
            predictor.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
        try {
            modelWrapper.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
    }

    /**
     * Small manual test entry point.
     *
     * @param args {@code <modelDir> <inputImage> <outputMaskPng>}
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.out.println("用法: java Segmenter <modelDir> <inputImage> <outputMaskPng>");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File input = new File(args[1]);
        File out = new File(args[2]);
        List<String> labels = LabelPalette.defaultLabels();
        try (Segmenter s = new Segmenter(modelDir, labels)) {
            SegmentationResult res = s.segment(input);
            ImageIO.write(res.getMaskImage(), "png", out);
            System.out.println("分割掩码已保存到: " + out.getAbsolutePath());
        }
    }
}

View File

@@ -0,0 +1,184 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.*;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
 * SegmenterExample - command-line demo that exports one mask and one overlay PNG per label.
 *
 * <p>Usage:
 * <pre>
 *   java -cp &lt;your-classpath&gt; com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \
 *       &lt;modelDir&gt; &lt;inputImage&gt; &lt;outputDir&gt; &lt;targetsCommaOrAll&gt;
 * </pre>
 *
 * <p>Examples (target names must be keys of the palette returned by the segmenter):
 * <pre>
 *   java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye_left,hair"
 *   java ... SegmenterExample /models/face_bisent /images/in.jpg /out all
 * </pre>
 */
public class SegmenterExample {

    /**
     * @param args {@code <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>}
     * @throws Exception if the model cannot be loaded, inference fails or files cannot be written
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 4) {
            System.err.println("用法: SegmenterExample <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>");
            System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face");
            return;
        }
        Path modelDir = Path.of(args[0]);
        File inputImage = new File(args[1]);
        Path outDir = Path.of(args[2]);
        String targetsArg = args[3];
        if (!Files.exists(modelDir)) {
            System.err.println("modelDir 不存在: " + modelDir);
            return;
        }
        if (!inputImage.exists()) {
            System.err.println("输入图片不存在: " + inputImage.getAbsolutePath());
            return;
        }
        if (!Files.exists(outDir)) {
            Files.createDirectories(outDir);
        }
        // Labels come from synset.txt when present, otherwise from LabelPalette.
        List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);

        try (Segmenter segmenter = new Segmenter(modelDir, labels)) {
            SegmentationResult res = segmenter.segment(inputImage);

            BufferedImage original = ImageIO.read(inputImage);
            if (original == null) {
                // ImageIO.read signals "no suitable reader" with null instead of an exception.
                throw new IOException("Unsupported or unreadable image: " + inputImage);
            }

            Map<String, Integer> palette = res.getPalette(); // label name -> ARGB
            Set<String> targets = parseTargets(targetsArg, labels);
            System.out.println("Will export targets: " + targets);

            // Every pixel of maskImage carries the colour of its class.
            BufferedImage maskImage = res.getMaskImage();
            // The scaled class map is the same for all targets, so it is built once.
            BufferedImage maskAtOriginalSize = scaleNearest(maskImage, original.getWidth(), original.getHeight());

            for (String target : targets) {
                Integer targetColor = palette.get(target);
                if (targetColor == null) {
                    System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。");
                    continue;
                }
                BufferedImage partMask = extractLabelMask(maskImage, targetColor);
                BufferedImage overlay = renderOverlay(original, maskAtOriginalSize, targetColor);

                File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile();
                File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile();
                ImageIO.write(partMask, "png", maskOut);
                ImageIO.write(overlay, "png", overlayOut);
                System.out.println("Saved mask: " + maskOut.getAbsolutePath());
                System.out.println("Saved overlay: " + overlayOut.getAbsolutePath());
            }
        }
    }

    /** Copies only the pixels equal to {@code targetColor}; everything else stays transparent. */
    private static BufferedImage extractLabelMask(BufferedImage classMask, int targetColor) {
        int w = classMask.getWidth();
        int h = classMask.getHeight();
        // A fresh TYPE_INT_ARGB image is fully transparent, so only hits need to be written.
        BufferedImage part = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                if (classMask.getRGB(x, y) == targetColor) {
                    part.setRGB(x, y, targetColor);
                }
            }
        }
        return part;
    }

    /**
     * Scales a colour-coded class map without inventing new colours.
     *
     * <p>Nearest-neighbour is required here: interpolating between two class colours produces a
     * blend that matches neither class (or, worse, a third one) in the exact-colour test used by
     * {@link #renderOverlay}.
     *
     * @return {@code src} itself when it already has the requested size, else a scaled copy
     */
    private static BufferedImage scaleNearest(BufferedImage src, int width, int height) {
        if (src.getWidth() == width && src.getHeight() == height) {
            return src;
        }
        BufferedImage scaled = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g = scaled.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_NEAREST_NEIGHBOR);
            g.drawImage(src, 0, 0, width, height, null);
        } finally {
            g.dispose();
        }
        return scaled;
    }

    /** Draws {@code original} and tints (alpha 0x88) every pixel whose class colour is {@code targetColor}. */
    private static BufferedImage renderOverlay(BufferedImage original, BufferedImage classMask, int targetColor) {
        int width = original.getWidth();
        int height = original.getHeight();
        int translucent = (0x88 << 24) | (targetColor & 0x00FFFFFF);

        BufferedImage tint = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                if (classMask.getRGB(x, y) == targetColor) {
                    tint.setRGB(x, y, translucent);
                }
            }
        }

        BufferedImage overlay = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g2 = overlay.createGraphics();
        try {
            g2.drawImage(original, 0, 0, null);
            g2.drawImage(tint, 0, 0, null);
        } finally {
            g2.dispose();
        }
        return overlay;
    }

    /**
     * Reads {@code modelDir/synset.txt}, dropping blank lines and trimming the rest.
     *
     * @return the labels, or empty when the file is missing, unreadable or has no usable line
     */
    private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
        Path syn = modelDir.resolve("synset.txt");
        if (Files.exists(syn)) {
            try {
                List<String> cleaned = new ArrayList<>();
                for (String l : Files.readAllLines(syn)) {
                    String s = l.trim();
                    if (!s.isEmpty()) {
                        cleaned.add(s);
                    }
                }
                if (!cleaned.isEmpty()) {
                    return Optional.of(cleaned);
                }
            } catch (IOException e) {
                // Still best-effort, but tell the user that the default labels are being used.
                System.err.println("Warning: cannot read " + syn + " (" + e + ") - using default labels.");
            }
        }
        return Optional.empty();
    }

    /**
     * Turns the command-line target argument into a set of label names.
     *
     * @param arg             {@code all} (case-insensitive) or a comma-separated list
     * @param availableLabels labels returned for {@code all}
     * @return label names in encounter order; empty entries are dropped
     */
    private static Set<String> parseTargets(String arg, List<String> availableLabels) {
        String a = arg.trim();
        if (a.equalsIgnoreCase("all")) {
            return new LinkedHashSet<>(availableLabels);
        }
        Set<String> out = new LinkedHashSet<>();
        for (String p : a.split(",")) {
            String t = p.trim();
            if (!t.isEmpty()) {
                out.add(t);
            }
        }
        return out;
    }

    /** Replaces every character outside {@code [a-zA-Z0-9_.-]} with an underscore. */
    private static String safeFileName(String s) {
        return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
    }
}

View File

@@ -0,0 +1,261 @@
package com.chuangzhou.vivid2D.ai.face_parsing;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
/**
* VividModelWrapper - 对之前 Segmenter / SegmenterExample 的封装
*
* 用法示例:
* VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("/path/to/modelDir"));
* Map<String, VividModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
* new File("input.jpg"),
* Set.of("eye","face"), // 或 Set.of(all labels...);若想全部传 "all" 可以用 helper parseTargets
* Paths.get("outDir")
* );
* // out contains 每个目标标签对应的 mask+overlay 文件路径
* wrapper.close();
*/
public class VividModelWrapper implements AutoCloseable {
private final Segmenter segmenter;
private final List<String> labels; // index -> name
private final Map<String, Integer> palette; // name -> ARGB
/**
 * Instances are created through {@link #load(Path)}; the arguments are stored as given.
 *
 * @param segmenter underlying segmenter, closed by {@link #close()}
 * @param labels    label names by class index
 * @param palette   label name -> ARGB colour
 */
private VividModelWrapper(Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
    this.palette = palette;
    this.labels = labels;
    this.segmenter = segmenter;
}
/**
 * Creates a wrapper for the model in {@code modelDir}.
 *
 * <p>Labels are taken from {@code modelDir/synset.txt} (one label per line) when that helper
 * yields a value, otherwise from {@code LabelPalette.defaultLabels()}.
 *
 * @param modelDir directory containing the model
 * @return a ready-to-use wrapper; the caller must close it
 * @throws Exception if the underlying {@link Segmenter} cannot be created
 */
public static VividModelWrapper load(Path modelDir) throws Exception {
    Optional<List<String>> fromSynset = loadLabelsFromSynset(modelDir);
    List<String> labelNames = fromSynset.isPresent() ? fromSynset.get() : LabelPalette.defaultLabels();
    Segmenter created = new Segmenter(modelDir, labelNames);
    return new VividModelWrapper(created, labelNames, LabelPalette.defaultPalette());
}
/** @return read-only view of the label list (index -> name) */
public List<String> getLabels() {
    final List<String> readOnlyView = Collections.unmodifiableList(this.labels);
    return readOnlyView;
}
/** @return read-only view of the palette (label name -> ARGB colour) */
public Map<String, Integer> getPalette() {
    final Map<String, Integer> readOnlyView = Collections.unmodifiableMap(this.palette);
    return readOnlyView;
}
/**
 * Runs segmentation on {@code inputImage} by delegating to the wrapped segmenter.
 *
 * @param inputImage image file to segment
 * @return the segmenter's result, unchanged
 * @throws Exception whatever the wrapped segmenter throws
 */
public SegmentationResult segment(File inputImage) throws Exception {
    final SegmentationResult result = this.segmenter.segment(inputImage);
    return result;
}
/**
 * Segments {@code inputImage} and writes two PNG files per requested label into {@code outDir}:
 * {@code <label>_mask.png} (label pixels on a transparent background, scaled to the input size)
 * and {@code <label>_overlay.png} (the input with the label area tinted half-transparently).
 *
 * <p>If {@code targets} consists of the single element {@code "all"} (case-insensitive), every
 * label of this wrapper is exported. Names are matched against the palette exactly first, then
 * case-insensitively; names that match nothing are reported on stderr and skipped.
 *
 * @param inputImage image file to segment
 * @param targets    label names to export, or the single element {@code "all"}
 * @param outDir     output directory; created when missing
 * @return exported label name -> written files, in export order
 * @throws IOException if the input file cannot be decoded as an image
 * @throws Exception   if segmentation or file writing fails
 */
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
    if (!Files.exists(outDir)) {
        Files.createDirectories(outDir);
    }
    SegmentationResult res = segment(inputImage);
    BufferedImage original = ImageIO.read(inputImage);
    if (original == null) {
        // ImageIO.read signals "no suitable reader" with null instead of an exception.
        throw new IOException("Unsupported or unreadable image: " + inputImage);
    }
    BufferedImage maskImage = res.getMaskImage();
    int maskW = maskImage.getWidth();
    int maskH = maskImage.getHeight();
    int outW = original.getWidth();
    int outH = original.getHeight();

    // parseTargetsSet only returns the "all" sentinel; expand it here, where the labels are
    // known. Without this the sentinel is looked up in the palette and skipped as unknown.
    Set<String> realTargets = parseTargetsSet(targets);
    if (realTargets.size() == 1 && "all".equalsIgnoreCase(realTargets.iterator().next())) {
        realTargets = new LinkedHashSet<>(labels);
    }

    Map<String, ResultFiles> saved = new LinkedHashMap<>();
    for (String requested : realTargets) {
        String target = requested;
        if (!palette.containsKey(target)) {
            // Fall back to a case-insensitive match.
            Optional<String> matched = palette.keySet().stream()
                    .filter(k -> k.equalsIgnoreCase(requested))
                    .findFirst();
            if (!matched.isPresent()) {
                System.err.println("Warning: unknown label '" + requested + "' - skip.");
                continue;
            }
            target = matched.get();
        }
        int targetColor = palette.get(target);
        int targetRgb = targetColor & 0x00FFFFFF;

        // 1) Binary mask on a transparent background: keep only the target pixels.
        //    A fresh TYPE_INT_ARGB image is fully transparent, so only hits are written.
        BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < maskH; y++) {
            for (int x = 0; x < maskW; x++) {
                if (maskImage.getRGB(x, y) == targetColor) {
                    partMask.setRGB(x, y, targetColor | 0xFF000000); // force opaque
                }
            }
        }

        // 2) Bring the mask to the size of the input image when they differ.
        BufferedImage maskResized = partMask;
        if (outW != maskW || outH != maskH) {
            maskResized = new BufferedImage(outW, outH, BufferedImage.TYPE_INT_ARGB);
            Graphics2D g = maskResized.createGraphics();
            try {
                g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
                g.drawImage(partMask, 0, 0, outW, outH, null);
            } finally {
                g.dispose();
            }
        }

        // 3) Overlay: the input image with the target area tinted at alpha 0x88.
        int translucent = (0x88 << 24) | targetRgb;
        BufferedImage colorOverlay = new BufferedImage(outW, outH, BufferedImage.TYPE_INT_ARGB);
        for (int y = 0; y < outH; y++) {
            for (int x = 0; x < outW; x++) {
                int mc = maskResized.getRGB(x, y);
                // Same RGB as the target and not fully transparent.
                if ((mc & 0x00FFFFFF) == targetRgb && (mc >>> 24) != 0) {
                    colorOverlay.setRGB(x, y, translucent);
                }
            }
        }
        BufferedImage overlay = new BufferedImage(outW, outH, BufferedImage.TYPE_INT_ARGB);
        Graphics2D g2 = overlay.createGraphics();
        try {
            g2.drawImage(original, 0, 0, null);
            g2.drawImage(colorOverlay, 0, 0, null);
        } finally {
            g2.dispose();
        }

        // 4) Save both files.
        String safe = safeFileName(target);
        File maskOut = outDir.resolve(safe + "_mask.png").toFile();
        File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
        ImageIO.write(maskResized, "png", maskOut);
        ImageIO.write(overlay, "png", overlayOut);
        saved.put(target, new ResultFiles(maskOut, overlayOut));
    }
    return saved;
}
/** Replaces every character outside {@code [a-zA-Z0-9_.-]} with an underscore. */
private static String safeFileName(String s) {
    final String unsafeChars = "[^a-zA-Z0-9_\\-\\.]";
    return s.replaceAll(unsafeChars, "_");
}
/**
 * Normalises a caller-supplied target set.
 *
 * <ul>
 *   <li>{@code null} or empty input gives an empty set;</li>
 *   <li>a single element equal to "all" (case-insensitive, surrounding whitespace ignored) gives
 *       the sentinel set {@code {"all"}} - this method has no access to the label list, so it
 *       cannot expand the sentinel itself;</li>
 *   <li>anything else gives the trimmed names in encounter order, case preserved.</li>
 * </ul>
 *
 * Null elements are ignored in every case, including when the set holds nothing but a single
 * null (which previously caused a NullPointerException).
 *
 * @param in target names as passed by the caller, may be {@code null}
 * @return normalised names, never {@code null}
 */
private static Set<String> parseTargetsSet(Set<String> in) {
    if (in == null || in.isEmpty()) {
        return Collections.emptySet();
    }
    if (in.size() == 1) {
        String only = in.iterator().next();
        if (only != null && "all".equalsIgnoreCase(only.trim())) {
            return Set.of("all");
        }
    }
    Set<String> out = new LinkedHashSet<>();
    for (String s : in) {
        if (s != null) {
            out.add(s.trim());
        }
    }
    return out;
}
/**
 * Releases the underlying segmenter.
 * Any exception raised while closing is deliberately discarded, so this method never throws.
 */
@Override
public void close() {
    try {
        segmenter.close();
    } catch (Exception ignored) {
        // Best-effort shutdown: a failure while closing is intentionally not reported.
    }
}
/**
 * Immutable pair of output files produced for one label:
 * the mask image and the overlay image.
 */
public static class ResultFiles {

    /** File holding the label's mask image. */
    private final File maskFile;

    /** File holding the overlay image for the label. */
    private final File overlayFile;

    /**
     * @param maskFile    file the mask image is stored in
     * @param overlayFile file the overlay image is stored in
     */
    public ResultFiles(File maskFile, File overlayFile) {
        this.maskFile = maskFile;
        this.overlayFile = overlayFile;
    }

    /** @return the mask image file */
    public File getMaskFile() {
        return this.maskFile;
    }

    /** @return the overlay image file */
    public File getOverlayFile() {
        return this.overlayFile;
    }
}
/* ================= helper: read synset.txt from modelDir ================= */
/**
 * Reads label names from {@code synset.txt} inside the given directory.
 * Each line is trimmed and blank lines are dropped.
 *
 * @param modelDir directory expected to contain {@code synset.txt}
 * @return the labels in file order, or an empty Optional when the file is missing,
 *         cannot be read, or contains no non-blank line
 */
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
    Path synsetFile = modelDir.resolve("synset.txt");
    if (!Files.exists(synsetFile)) {
        return Optional.empty();
    }
    List<String> labels = new ArrayList<>();
    try {
        for (String raw : Files.readAllLines(synsetFile)) {
            String label = raw.trim();
            if (label.isEmpty()) {
                continue;
            }
            labels.add(label);
        }
    } catch (IOException ignored) {
        // An unreadable file is treated the same as a missing one.
        return Optional.empty();
    }
    return labels.isEmpty() ? Optional.empty() : Optional.of(labels);
}
/* ================= convenience main method (quick manual test) ================= */
/**
 * Command-line entry point for a quick manual run.
 * Expects four arguments: model directory, input image, output directory, and either
 * a comma-separated list of target labels or the word "all". With fewer arguments it
 * prints the usage text and returns.
 *
 * @param args command-line arguments as described above
 * @throws Exception if loading the model or running the segmentation fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 4) {
        System.out.println("用法: VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
        System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face");
        return;
    }

    final Path modelDir = Path.of(args[0]);
    final File inputImage = new File(args[1]);
    final Path outputDir = Path.of(args[2]);
    final String requested = args[3];

    // Labels come from synset.txt when present, otherwise from the built-in defaults.
    final List<String> knownLabels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);

    final Set<String> targets = new LinkedHashSet<>();
    if ("all".equalsIgnoreCase(requested.trim())) {
        targets.addAll(knownLabels);
    } else {
        for (String token : requested.split(",")) {
            String name = token.trim();
            if (!name.isEmpty()) {
                targets.add(name);
            }
        }
    }

    try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) {
        Map<String, ResultFiles> written = wrapper.segmentAndSave(inputImage, targets, outputDir);
        written.forEach((label, files) -> System.out.println(String.format(
                "Label=%s, mask=%s, overlay=%s",
                label,
                files.getMaskFile().getAbsolutePath(),
                files.getOverlayFile().getAbsolutePath())));
    }
}
}

View File

@@ -0,0 +1,38 @@
package com.chuangzhou.vivid2D.test;
import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Set;
/**
 * Manual run of the anime face segmentation model: requests the background plus the
 * facial parts hair, eye, mouth, face, skin and clothes for one hard-coded image.
 */
public class AI2Test {

    /** Location of the Anime-Face-Segmentation model file on the developer machine. */
    private static final String MODEL_FILE =
            "C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt";

    /** Image that is segmented. */
    private static final String INPUT_IMAGE =
            "C:\\Users\\Administrator\\Desktop\\b_215609167a3a20ac2075487bd532bbff.jpg";

    /** Directory handed to the wrapper as output location. */
    private static final String OUTPUT_DIR = "C:\\models\\out";

    public static void main(String[] args) throws Exception {
        // Force UTF-8 console streams so non-ASCII output is printed correctly.
        System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
        System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));

        // This test uses AnimeModelWrapper rather than VividModelWrapper.
        // NOTE(review): the wrapper is never closed; if AnimeModelWrapper is AutoCloseable,
        // wrap it in try-with-resources.
        AnimeModelWrapper model = AnimeModelWrapper.load(Paths.get(MODEL_FILE));

        // The seven labels of the Anime-Face-Segmentation model.
        Set<String> requestedLabels = Set.of(
                "background",
                "hair",
                "eye",
                "mouth",
                "face",
                "skin",
                "clothes"
        );

        model.segmentAndSave(Paths.get(INPUT_IMAGE).toFile(), requestedLabels, Paths.get(OUTPUT_DIR));
    }
}

View File

@@ -0,0 +1,28 @@
package com.chuangzhou.vivid2D.test;
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Set;
/**
 * Manual run of the model that separates the character (foreground) from the background,
 * using one hard-coded image.
 */
public class AI3Test {

    /** Location of the traced anime-segmentation model file on the developer machine. */
    private static final String MODEL_FILE =
            "C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt";

    /** Image that is segmented. */
    private static final String INPUT_IMAGE =
            "C:\\Users\\Administrator\\Desktop\\b_7a8349adece17d1e4bebd20cb2387cf6.jpg";

    /** Directory handed to the wrapper as output location. */
    private static final String OUTPUT_DIR = "C:\\models\\out";

    public static void main(String[] args) throws Exception {
        // Force UTF-8 console streams so non-ASCII output is printed correctly.
        System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
        System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));

        // NOTE(review): the wrapper is never closed; if Anime2VividModelWrapper is
        // AutoCloseable, wrap it in try-with-resources.
        Anime2VividModelWrapper model = Anime2VividModelWrapper.load(Paths.get(MODEL_FILE));

        // Only the foreground label is requested.
        Set<String> requestedLabels = Set.of("foreground");

        model.segmentAndSave(Paths.get(INPUT_IMAGE).toFile(), requestedLabels, Paths.get(OUTPUT_DIR));
    }
}

View File

@@ -0,0 +1,33 @@
package com.chuangzhou.vivid2D.test;
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Set;
/**
 * Manual run of the face parsing model: requests the 18 non-background labels
 * for one hard-coded image.
 */
public class AITest {
    public static void main(String[] args) throws Exception {
        // Force UTF-8 console streams so non-ASCII output is printed correctly.
        System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
        System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));

        // The 18 non-background labels of the BiSeNet face parsing model.
        Set<String> faceLabels = Set.of(
                "skin", "nose", "eye_left", "eye_right", "eyebrow_left",
                "eyebrow_right", "ear_left", "ear_right", "mouth", "lip_upper",
                "lip_lower", "hair", "hat", "earring", "necklace", "clothes",
                "facial_hair", "neck"
        );

        // try-with-resources closes the wrapper even when segmentation throws;
        // previously the wrapper was never closed.
        try (VividModelWrapper wrapper =
                     VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"))) {
            wrapper.segmentAndSave(
                    Paths.get("C:\\Users\\Administrator\\Desktop\\b_f4881214f0d18b6cf848b6736f554821.png").toFile(),
                    faceLabels,
                    Paths.get("C:\\models\\out")
            );
        }
    }
}