feat(ai): 集成动漫人物分割与面部解析AI模型- 添加 DJL 深度学习框架依赖项以支持 PyTorch 和 ONNX Runtime 引擎
- 实现 Anime2VividModelWrapper 封装类用于动漫人物前景背景分离 - 开发 AnimeModelWrapper 用于精细的动漫面部特征(如头发、眼睛)分割 - 创建配套的标签调色板和结果处理工具类提升可视化效果 - 增加多个测试用例验证不同 AI 模型的推理及文件输出功能 - 支持通过 synset.txt 自定义模型标签并增强命令行可测试性
This commit is contained in:
10
build.gradle
10
build.gradle
@@ -51,6 +51,16 @@ dependencies {
|
||||
implementation files('libs/dog api 1.3.jar')
|
||||
implementation files('libs/DesktopWallpaperSdk-1.0-SNAPSHOT.jar')
|
||||
|
||||
// === DJL API ===
|
||||
implementation platform('ai.djl:bom:0.35.0')
|
||||
implementation 'ai.djl:api'
|
||||
implementation 'ai.djl:model-zoo'
|
||||
implementation 'ai.djl.pytorch:pytorch-model-zoo:0.35.0'
|
||||
implementation 'ai.djl.pytorch:pytorch-engine'
|
||||
implementation 'ai.djl:basicdataset'
|
||||
implementation 'ai.djl.onnxruntime:onnxruntime-engine'
|
||||
runtimeOnly 'ai.djl.pytorch:pytorch-native-cpu:2.7.1'
|
||||
runtimeOnly 'ai.djl.onnxruntime:onnxruntime-native-cpu:1.3.0'
|
||||
// === 核心工具库 ===
|
||||
implementation 'com.google.code.gson:gson:2.10.1' // 统一版本
|
||||
implementation 'org.apache.logging.log4j:log4j-api:2.20.0'
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
 * Class names and display colours for the Anime-Face-Segmentation UNet model.
 *
 * <p>A label's position in {@link #defaultLabels()} is the class index the model is
 * expected to emit for it (0-6). Colours are stored as {@code 0xAARRGGBB} integers,
 * fully opaque.
 *
 * <p>NOTE(review): the channel tuples were copied position-for-position from the
 * palette in the upstream project's {@code util.py}. If that palette is in BGR order
 * the red and blue channels here are swapped relative to upstream - confirm.
 */
public class AnimeLabelPalette {

    /**
     * Returns the seven standard class names ordered by class index.
     *
     * @return a fixed-size list: background, hair, eye, mouth, face, skin, clothes
     */
    public static List<String> defaultLabels() {
        String[] names = {
            "background", // 0
            "hair",       // 1
            "eye",        // 2
            "mouth",      // 3
            "face",       // 4
            "skin",       // 5
            "clothes"     // 6
        };
        return Arrays.asList(names);
    }

    /**
     * Returns a new, mutable map from class name to its opaque ARGB colour.
     *
     * @return class name -> {@code 0xFFRRGGBB} colour
     */
    public static Map<String, Integer> defaultPalette() {
        String[] names = {"background", "hair", "eye", "mouth", "face", "skin", "clothes"};
        int[] colours = {
            0xFF00FFFF, // background: cyan    (R=0,   G=255, B=255)
            0xFFFF0000, // hair:       red     (R=255, G=0,   B=0)
            0xFF0000FF, // eye:        blue    (R=0,   G=0,   B=255)
            0xFFFFFFFF, // mouth:      white   (R=255, G=255, B=255)
            0xFF00FF00, // face:       green   (R=0,   G=255, B=0)
            0xFFFFFF00, // skin:       yellow  (R=255, G=255, B=0)
            0xFFFF00FF  // clothes:    magenta (R=255, G=0,   B=255)
        };

        Map<String, Integer> palette = new HashMap<>();
        for (int i = 0; i < names.length; i++) {
            palette.put(names[i], colours[i]);
        }
        return palette;
    }

    /**
     * Returns a new, mutable map from class index to class name, built from
     * {@link #defaultLabels()}.
     *
     * @return class index -> class name
     */
    public static Map<Integer, String> getIndexToLabelMap() {
        Map<Integer, String> byIndex = new HashMap<>();
        int index = 0;
        for (String name : defaultLabels()) {
            byIndex.put(index, name);
            index++;
        }
        return byIndex;
    }
}
|
||||
@@ -0,0 +1,426 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AnimeModelWrapper - 专门为 Anime-Face-Segmentation 模型封装的 Wrapper
|
||||
*/
|
||||
public class AnimeModelWrapper implements AutoCloseable {
|
||||
|
||||
private final AnimeSegmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
private AnimeModelWrapper(AnimeSegmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载模型
|
||||
*/
|
||||
public static AnimeModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(AnimeLabelPalette::defaultLabels);
|
||||
AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = AnimeLabelPalette.defaultPalette();
|
||||
return new AnimeModelWrapper(segmenter, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果(在丢给底层 segmenter 前会做通用预处理:RGB 转换 + 等比 letterbox 缩放到模型输入尺寸)
|
||||
*/
|
||||
public AnimeSegmentationResult segment(File inputImage) throws Exception {
|
||||
File pre = null;
|
||||
try {
|
||||
pre = preprocessAndSave(inputImage);
|
||||
// 将预处理后的临时文件丢给底层 segmenter
|
||||
return segmenter.segment(pre);
|
||||
} finally {
|
||||
if (pre != null && pre.exists()) {
|
||||
try { Files.deleteIfExists(pre.toPath()); } catch (Exception ignore) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 分割并保存结果
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
AnimeSegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
// 尝试忽略大小写匹配
|
||||
String finalTarget = target;
|
||||
Optional<String> matched = palette.keySet().stream()
|
||||
.filter(k -> k.equalsIgnoreCase(finalTarget))
|
||||
.findFirst();
|
||||
if (matched.isPresent()) target = matched.get();
|
||||
else {
|
||||
System.err.println("Warning: unknown label '" + target + "' - skip.");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
|
||||
} else {
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明)
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
// 半透明颜色(alpha = 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 返回所有标签
|
||||
return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的集合
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 专门提取眼睛的方法(在丢给底层 segmenter 前做预处理)
|
||||
*/
|
||||
public ResultFiles extractEyes(File inputImage, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
File pre = null;
|
||||
BufferedImage eyes;
|
||||
try {
|
||||
pre = preprocessAndSave(inputImage);
|
||||
eyes = segmenter.extractEyes(pre);
|
||||
} finally {
|
||||
if (pre != null && pre.exists()) {
|
||||
try { Files.deleteIfExists(pre.toPath()); } catch (Exception ignore) {}
|
||||
}
|
||||
}
|
||||
|
||||
File eyesMask = outDir.resolve("eyes_mask.png").toFile();
|
||||
ImageIO.write(eyes, "png", eyesMask);
|
||||
|
||||
// 创建眼睛的 overlay(原有逻辑,保持不变)
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
|
||||
// 缩放眼睛掩码到原图尺寸
|
||||
BufferedImage eyesResized = eyes;
|
||||
if (original.getWidth() != eyes.getWidth() || original.getHeight() != eyes.getHeight()) {
|
||||
eyesResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = eyesResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(eyes, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
int eyeColor = palette.getOrDefault("eye", 0xFF00FF); // 若没有 eye,给个显眼默认色
|
||||
int rgbOnly = (eyeColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = eyesResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (eyeColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
File eyesOverlay = outDir.resolve("eyes_overlay.png").toFile();
|
||||
ImageIO.write(overlay, "png", eyesOverlay);
|
||||
|
||||
return new ResultFiles(eyesMask, eyesOverlay);
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
// ========== 新增:预处理并保存到临时文件 ==========
|
||||
private File preprocessAndSave(File inputImage) throws IOException {
|
||||
BufferedImage img = ImageIO.read(inputImage);
|
||||
if (img == null) throw new IOException("无法读取图片: " + inputImage);
|
||||
|
||||
// 转成标准 RGB(去掉 alpha / 保证三通道)
|
||||
BufferedImage rgb = new BufferedImage(img.getWidth(), img.getHeight(), BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g = rgb.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(img, 0, 0, null);
|
||||
g.dispose();
|
||||
|
||||
// 获取模型输入尺寸(尝试反射读取,找不到则使用默认 512x512)
|
||||
int[] size = getModelInputSize();
|
||||
int targetW = size[0], targetH = size[1];
|
||||
|
||||
// 等比缩放并居中填充(letterbox),背景用白色
|
||||
double scale = Math.min((double) targetW / rgb.getWidth(), (double) targetH / rgb.getHeight());
|
||||
int newW = Math.max(1, (int) Math.round(rgb.getWidth() * scale));
|
||||
int newH = Math.max(1, (int) Math.round(rgb.getHeight() * scale));
|
||||
|
||||
BufferedImage resized = new BufferedImage(targetW, targetH, BufferedImage.TYPE_INT_RGB);
|
||||
Graphics2D g2 = resized.createGraphics();
|
||||
g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g2.setColor(Color.WHITE);
|
||||
g2.fillRect(0, 0, targetW, targetH);
|
||||
int x = (targetW - newW) / 2;
|
||||
int y = (targetH - newH) / 2;
|
||||
g2.drawImage(rgb, x, y, newW, newH, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存为临时 PNG 文件(确保无压缩失真)
|
||||
File tmp = Files.createTempFile("anime_pre_", ".png").toFile();
|
||||
ImageIO.write(resized, "png", tmp);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// ========== 新增:尝试通过反射从 segmenter 上读取模型输入尺寸 ==========
|
||||
private int[] getModelInputSize() {
|
||||
// 默认值
|
||||
int defaultSize = 512;
|
||||
int w = defaultSize, h = defaultSize;
|
||||
|
||||
try {
|
||||
Class<?> cls = segmenter.getClass();
|
||||
|
||||
// 尝试方法 getInputWidth/getInputHeight
|
||||
try {
|
||||
Method mw = cls.getMethod("getInputWidth");
|
||||
Method mh = cls.getMethod("getInputHeight");
|
||||
Object ow = mw.invoke(segmenter);
|
||||
Object oh = mh.invoke(segmenter);
|
||||
if (ow instanceof Number && oh instanceof Number) {
|
||||
int iw = ((Number) ow).intValue();
|
||||
int ih = ((Number) oh).intValue();
|
||||
if (iw > 0 && ih > 0) {
|
||||
return new int[]{iw, ih};
|
||||
}
|
||||
}
|
||||
} catch (NoSuchMethodException ignored) {}
|
||||
|
||||
// 尝试方法 getInputSize 返回 int[] 或 Dimension
|
||||
try {
|
||||
Method ms = cls.getMethod("getInputSize");
|
||||
Object os = ms.invoke(segmenter);
|
||||
if (os instanceof int[] && ((int[]) os).length >= 2) {
|
||||
int iw = ((int[]) os)[0];
|
||||
int ih = ((int[]) os)[1];
|
||||
if (iw > 0 && ih > 0) return new int[]{iw, ih};
|
||||
} else if (os != null) {
|
||||
// 处理 java.awt.Dimension
|
||||
try {
|
||||
Method gw = os.getClass().getMethod("getWidth");
|
||||
Method gh = os.getClass().getMethod("getHeight");
|
||||
Object ow2 = gw.invoke(os);
|
||||
Object oh2 = gh.invoke(os);
|
||||
if (ow2 instanceof Number && oh2 instanceof Number) {
|
||||
int iw = ((Number) ow2).intValue();
|
||||
int ih = ((Number) oh2).intValue();
|
||||
if (iw > 0 && ih > 0) return new int[]{iw, ih};
|
||||
}
|
||||
} catch (Exception ignored2) {}
|
||||
}
|
||||
} catch (NoSuchMethodException ignored) {}
|
||||
|
||||
// 尝试字段 inputWidth/inputHeight
|
||||
try {
|
||||
try {
|
||||
java.lang.reflect.Field fw = cls.getDeclaredField("inputWidth");
|
||||
java.lang.reflect.Field fh = cls.getDeclaredField("inputHeight");
|
||||
fw.setAccessible(true); fh.setAccessible(true);
|
||||
Object ow = fw.get(segmenter);
|
||||
Object oh = fh.get(segmenter);
|
||||
if (ow instanceof Number && oh instanceof Number) {
|
||||
int iw = ((Number) ow).intValue();
|
||||
int ih = ((Number) oh).intValue();
|
||||
if (iw > 0 && ih > 0) return new int[]{iw, ih};
|
||||
}
|
||||
} catch (NoSuchFieldException ignoredField) {}
|
||||
} catch (Exception ignored) {}
|
||||
|
||||
} catch (Exception ignored) {
|
||||
// 任何反射异常都回退到默认值
|
||||
}
|
||||
|
||||
return new int[]{w, h};
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: AnimeModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face");
|
||||
System.out.println("标签: " + AnimeLabelPalette.defaultLabels());
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels());
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Immutable-reference holder for the outcome of one anime segmentation run.
 *
 * <p>The arrays and maps passed to the constructor are stored and returned as-is
 * (no defensive copies).
 */
public class AnimeSegmentationResult {

    /** Mask image in which every pixel carries the colour of its class. */
    private final BufferedImage maskImage;

    /** Per-pixel class scores, indexed {@code [y][x][classIndex]}. */
    private final float[][][] probabilityMap;

    /** Class index -> class name. */
    private final Map<Integer, String> labels;

    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /**
     * @param maskImage      colour-coded mask
     * @param probabilityMap per-pixel class scores {@code [y][x][classIndex]}; may be null
     * @param labels         class index -> class name
     * @param palette        class name -> ARGB colour
     */
    public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap,
                                   Map<Integer, String> labels, Map<String, Integer> palette) {
        this.maskImage = maskImage;
        this.probabilityMap = probabilityMap;
        this.labels = labels;
        this.palette = palette;
    }

    public BufferedImage getMaskImage() {
        return maskImage;
    }

    public float[][][] getProbabilityMap() {
        return probabilityMap;
    }

    public Map<Integer, String> getLabels() {
        return labels;
    }

    public Map<String, Integer> getPalette() {
        return palette;
    }

    /**
     * Extracts the score plane of one class from the probability map.
     *
     * @param classIndex index into the last dimension of the probability map
     * @return a new {@code [height][width]} array; {@code null} when no probability map
     *         is present; an empty array when the map has no rows
     * @throws ArrayIndexOutOfBoundsException if {@code classIndex} is outside the class
     *         dimension of a pixel
     */
    public float[][] getClassProbability(int classIndex) {
        if (probabilityMap == null) {
            return null;
        }
        int height = probabilityMap.length;
        if (height == 0) {
            // Nothing to index: probabilityMap[0] would be out of bounds.
            return new float[0][0];
        }
        // The width is taken from the first row; rows are assumed to be equally long.
        int width = probabilityMap[0].length;
        float[][] plane = new float[height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                plane[y][x] = probabilityMap[y][x][classIndex];
            }
        }
        return plane;
    }
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_face_segmentation;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* AnimeSegmenter: 专门为 Anime-Face-Segmentation UNet 模型设计的分割器
|
||||
*/
|
||||
public class AnimeSegmenter implements AutoCloseable {
|
||||
|
||||
// 模型默认输入大小(与训练时一致)。若模型不同可以修改为实际值或让 caller 通过构造参数传入。
|
||||
private static final int MODEL_INPUT_W = 512;
|
||||
private static final int MODEL_INPUT_H = 512;
|
||||
|
||||
// 内部类,用于从Translator安全地传出数据
|
||||
public static class SegmentationData {
|
||||
final int[] indices; // 类别索引 [H * W]
|
||||
final float[][][] probMap; // 概率图 [H][W][C]
|
||||
final long[] shape; // 形状 [H, W]
|
||||
|
||||
public SegmentationData(int[] indices, float[][][] probMap, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.probMap = probMap;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public AnimeSegmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = AnimeLabelPalette.defaultPalette();
|
||||
|
||||
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
|
||||
// 如果图片已经是模型输入大小则不再 resize(避免重复缩放导致失真)
|
||||
Image toUse = input;
|
||||
if (!(input.getWidth() == MODEL_INPUT_W && input.getHeight() == MODEL_INPUT_H)) {
|
||||
toUse = input.resize(MODEL_INPUT_W, MODEL_INPUT_H, true);
|
||||
}
|
||||
|
||||
// 转换为NDArray并预处理
|
||||
NDArray array = toUse.toNDArray(manager);
|
||||
// DJL 返回 HWC 格式数组,转换为 CHW,并标准化到 [0,1]
|
||||
array = array.transpose(2, 0, 1) // HWC -> CHW
|
||||
.toType(DataType.FLOAT32, false)
|
||||
.div(255f) // 归一化到[0,1]
|
||||
.expandDims(0); // 添加batch维度 [1,3,H,W]
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
|
||||
NDArray output = list.get(0); // 期望形状 [1,C,H,W] 或 [1,C,W,H](以训练时一致为准)
|
||||
|
||||
// 确保维度:把 output 视作 [1, C, H, W]
|
||||
Shape outShape = output.getShape();
|
||||
if (outShape.dimension() != 4) {
|
||||
throw new IllegalStateException("Unexpected output shape: " + outShape);
|
||||
}
|
||||
|
||||
// 1. 获取类别索引(argmax) -> [H, W]
|
||||
NDArray squeezed = output.squeeze(0); // [C,H,W]
|
||||
NDArray classMap = squeezed.argMax(0).toType(DataType.INT32, false); // argMax over channel维度
|
||||
|
||||
// 2. 获取概率图(softmax 输出或模型已经输出概率),转换为 [H,W,C]
|
||||
NDArray probabilities = squeezed.transpose(1, 2, 0) // [H,W,C]
|
||||
.toType(DataType.FLOAT32, false);
|
||||
|
||||
// 3. 转换为Java数组
|
||||
long[] shape = classMap.getShape().getShape(); // [H, W]
|
||||
int[] indices = classMap.toIntArray();
|
||||
long[] probShape = probabilities.getShape().getShape(); // [H, W, C]
|
||||
int height = (int) probShape[0];
|
||||
int width = (int) probShape[1];
|
||||
int classes = (int) probShape[2];
|
||||
float[] flatProbMap = probabilities.toFloatArray();
|
||||
float[][][] probMap = new float[height][width][classes];
|
||||
for (int i = 0; i < height; i++) {
|
||||
for (int j = 0; j < width; j++) {
|
||||
for (int k = 0; k < classes; k++) {
|
||||
int index = i * width * classes + j * classes + k;
|
||||
probMap[i][j][k] = flatProbMap[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new SegmentationData(indices, probMap, shape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
Criteria<Image, SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine("PyTorch")
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
// 预测并获取分割数据
|
||||
SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
float[][][] probMap = data.probMap;
|
||||
|
||||
int height = (int) shp[0];
|
||||
int width = (int) shp[1];
|
||||
|
||||
// 创建掩码图像
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = AnimeLabelPalette.getIndexToLabelMap();
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFF00FF00); // 默认绿色
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new AnimeSegmentationResult(mask, probMap, labelsMap, palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 专门针对眼睛的分割方法
|
||||
*/
|
||||
public BufferedImage extractEyes(File imgFile) throws TranslateException, IOException {
|
||||
AnimeSegmentationResult result = segment(imgFile);
|
||||
BufferedImage mask = result.getMaskImage();
|
||||
BufferedImage eyeMask = new BufferedImage(mask.getWidth(), mask.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
|
||||
int eyeColor = palette.get("eye");
|
||||
|
||||
for (int y = 0; y < mask.getHeight(); y++) {
|
||||
for (int x = 0; x < mask.getWidth(); x++) {
|
||||
int rgb = mask.getRGB(x, y);
|
||||
if (rgb == eyeColor) {
|
||||
eyeMask.setRGB(x, y, eyeColor);
|
||||
} else {
|
||||
eyeMask.setRGB(x, y, 0x00000000); // 透明
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return eyeMask;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
// 测试主函数
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java AnimeSegmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = AnimeLabelPalette.defaultLabels();
|
||||
|
||||
try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) {
|
||||
AnimeSegmentationResult res = segmenter.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
|
||||
|
||||
// 额外保存眼睛分割结果
|
||||
BufferedImage eyes = segmenter.extractEyes(input);
|
||||
File eyesOut = new File(out.getParent(), "eyes_" + out.getName());
|
||||
ImageIO.write(eyes, "png", eyesOut);
|
||||
System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
 * Class names and colour palettes for the two-class anime segmentation model
 * (background versus foreground character).
 *
 * <p>Colours are {@code 0xAARRGGBB} integers.
 */
public class Anime2LabelPalette {

    /**
     * Returns the two class names ordered by class index.
     *
     * @return a fixed-size list: background (0), foreground (1)
     */
    public static List<String> defaultLabels() {
        String[] names = {"background", "foreground"};
        return Arrays.asList(names);
    }

    /**
     * Returns a new, mutable black/white palette: opaque black background and
     * opaque white foreground.
     */
    public static Map<String, Integer> defaultPalette() {
        return paletteOf(0xFF000000, 0xFFFFFFFF);
    }

    /**
     * Returns a new, mutable palette meant for visualisation: fully transparent
     * background and opaque red foreground.
     */
    public static Map<String, Integer> animeSegmentationPalette() {
        return paletteOf(0x00000000, 0xFFFF0000);
    }

    /** Builds a two-entry palette from the given background and foreground colours. */
    private static Map<String, Integer> paletteOf(int background, int foreground) {
        Map<String, Integer> palette = new HashMap<>();
        palette.put("background", background);
        palette.put("foreground", foreground);
        return palette;
    }
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Anime2ModelWrapper - 对动漫分割模型的封装
|
||||
*
|
||||
* 用法示例:
|
||||
* Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir"));
|
||||
* Map<String, Anime2ModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
|
||||
* new File("input.jpg"),
|
||||
* Set.of("foreground"), // 动漫分割主要关注前景
|
||||
* Paths.get("outDir")
|
||||
* );
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class Anime2ModelWrapper implements AutoCloseable {
|
||||
|
||||
private final Anime2Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
// Instances are created through load(Path); the arguments are stored as given.
private Anime2ModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
    this.palette = palette;
    this.labels = labels;
    this.segmenter = segmenter;
}
|
||||
|
||||
/**
|
||||
* 创建 Anime2Segmenter 实例
|
||||
*/
|
||||
public static Anime2ModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
|
||||
return new Anime2ModelWrapper(s, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果
|
||||
*/
|
||||
public Anime2SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
Anime2SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
System.err.println("Warning: unknown label '" + target + "' - skip.");
|
||||
continue;
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
|
||||
} else {
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
// 3) 生成叠加图
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
return Set.of("foreground"); // 动漫分割主要关注前景
|
||||
}
|
||||
}
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: Anime2ModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Container for the outcome of one anime segmentation run.
 *
 * <p>The instance stores the references it is given; no copies are made.
 */
public class Anime2SegmentationResult {

    /** Segmentation mask; each pixel carries the colour of its class. */
    private final BufferedImage maskImage;

    /** Class index -> class name. */
    private final Map<Integer, String> labels;

    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /**
     * @param maskImage mask image whose pixels are coloured per class
     * @param labels    class index -> class name
     * @param palette   class name -> ARGB colour
     */
    public Anime2SegmentationResult(BufferedImage maskImage,
                                    Map<Integer, String> labels,
                                    Map<String, Integer> palette) {
        this.palette = palette;
        this.labels = labels;
        this.maskImage = maskImage;
    }

    /** @return the mask image passed to the constructor */
    public BufferedImage getMaskImage() {
        return this.maskImage;
    }

    /** @return the index -> name map passed to the constructor */
    public Map<Integer, String> getLabels() {
        return this.labels;
    }

    /** @return the name -> ARGB map passed to the constructor */
    public Map<String, Integer> getPalette() {
        return this.palette;
    }
}
|
||||
@@ -0,0 +1,175 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Anime2Segmenter: 专门用于动漫分割模型
|
||||
* 处理 anime-segmentation 模型的二值分割输出
|
||||
*/
|
||||
public class Anime2Segmenter implements AutoCloseable {
|
||||
|
||||
public static class SegmentationData {
|
||||
final int[] indices;
|
||||
final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
public Anime2Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = Anime2LabelPalette.animeSegmentationPalette();
|
||||
|
||||
Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
NDManager manager = ctx.getNDManager();
|
||||
|
||||
// 调整输入图像尺寸到模型期望的大小 (1024x1024)
|
||||
Image resized = input.resize(1024, 1024, true);
|
||||
NDArray array = resized.toNDArray(manager);
|
||||
|
||||
// 转换为 CHW 格式并归一化
|
||||
array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false);
|
||||
array = array.div(255f);
|
||||
array = array.expandDims(0); // 添加batch维度
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
if (list == null || list.isEmpty()) {
|
||||
throw new IllegalStateException("Model did not return any output.");
|
||||
}
|
||||
|
||||
NDArray out = list.get(0);
|
||||
|
||||
// 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图
|
||||
// 应用sigmoid并二值化
|
||||
NDArray probabilities = out.div(out.neg().exp().add(1));
|
||||
NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false);
|
||||
|
||||
// 移除batch和channel维度
|
||||
if (binaryMask.getShape().dimension() == 4) {
|
||||
binaryMask = binaryMask.squeeze(0).squeeze(0);
|
||||
}
|
||||
|
||||
// 转换为Java数组
|
||||
long[] finalShape = binaryMask.getShape().getShape();
|
||||
int[] indices = binaryMask.toIntArray();
|
||||
|
||||
return new SegmentationData(indices, finalShape);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
Criteria<Image, SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine("PyTorch")
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
width = (int) shp[1];
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 创建分割掩码
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFFFF0000);
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new Anime2SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java Anime2Segmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = Anime2LabelPalette.defaultLabels();
|
||||
|
||||
try (Anime2Segmenter s = new Anime2Segmenter(modelDir, labels)) {
|
||||
Anime2SegmentationResult res = s.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package com.chuangzhou.vivid2D.ai.anime_segmentation;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Anime2VividModelWrapper - 对之前 Anime2Segmenter 的封装,提供更便捷的API
|
||||
*
|
||||
* 用法示例:
|
||||
* Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("/path/to/modelDir"));
|
||||
* Map<String, Anime2VividModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
|
||||
* new File("input.jpg"),
|
||||
* Set.of("foreground"), // 动漫分割主要关注前景
|
||||
* Paths.get("outDir")
|
||||
* );
|
||||
* // out contains 每个目标标签对应的 mask+overlay 文件路径
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class Anime2VividModelWrapper implements AutoCloseable {
|
||||
|
||||
private final Anime2Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
private Anime2VividModelWrapper(Anime2Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 Anime2LabelPalette.defaultLabels()
|
||||
* 并创建 Anime2Segmenter 实例。
|
||||
*/
|
||||
public static Anime2VividModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Anime2Segmenter s = new Anime2Segmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = Anime2LabelPalette.animeSegmentationPalette();
|
||||
return new Anime2VividModelWrapper(s, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果(Anime2SegmentationResult)
|
||||
*/
|
||||
public Anime2SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。
|
||||
* 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。
|
||||
* <p>
|
||||
* 返回值:Map<labelName, ResultFiles>,ResultFiles 包含 maskFile、overlayFile(两个 PNG)
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
Anime2SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
// 尝试忽略大小写匹配
|
||||
String finalTarget = target;
|
||||
Optional<String> matched = palette.keySet().stream()
|
||||
.filter(k -> k.equalsIgnoreCase(finalTarget))
|
||||
.findFirst();
|
||||
if (matched.isPresent()) target = matched.get();
|
||||
else {
|
||||
System.err.println("Warning: unknown label '" + target + "' - skip.");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
|
||||
} else {
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明)
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
// 半透明颜色(alpha = 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette)
|
||||
// For convenience, return a set containing "all" and let caller logic handle it earlier.
|
||||
return Set.of("all");
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: Anime2VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground");
|
||||
System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
 * Default labels and colour palette for the BiSeNet face-parsing model.
 *
 * <p>According to the original author's notes the colours follow the {@code part_colors} array in
 * {@code test.py} of the zllrunning/face-parsing.PyTorch repository, and the label positions must
 * match the model's output indices (0-18).
 * NOTE(review): neither the label order nor the colour values can be checked from this file -
 * confirm both against the upstream repository and the exported model.
 */
public class LabelPalette {

    /**
     * Default labels of the face-parsing model: 19 classes, list position = class index (0-18).
     *
     * @return a fixed-size list of the 19 label names
     */
    public static List<String> defaultLabels() {
        return Arrays.asList(
                "background",    // 0
                "skin",          // 1
                "nose",          // 2
                "eye_left",      // 3
                "eye_right",     // 4
                "eyebrow_left",  // 5
                "eyebrow_right", // 6
                "ear_left",      // 7
                "ear_right",     // 8
                "mouth",         // 9
                "lip_upper",     // 10
                "lip_lower",     // 11
                "hair",          // 12
                "hat",           // 13
                "earring",       // 14
                "necklace",      // 15
                "clothes",       // 16
                "facial_hair",   // 17
                "neck"           // 18
        );
    }

    /**
     * Default palette: label name -> ARGB colour ({@code 0xFFRRGGBB}, fully opaque).
     *
     * <p>Index 0 (background) is black; indices 1-18 use the colours listed below, which the
     * original author took to be {@code part_colors[i - 1]} for index {@code i}.
     *
     * @return a new mutable map with one entry per default label
     */
    public static Map<String, Integer> defaultPalette() {
        // Parallel tables, in class-index order: names[i] is painted with colours[i].
        String[] names = {
                "background", "skin", "nose", "eye_left", "eye_right",
                "eyebrow_left", "eyebrow_right", "ear_left", "ear_right", "mouth",
                "lip_upper", "lip_lower", "hair", "hat", "earring",
                "necklace", "clothes", "facial_hair", "neck"
        };
        int[] colours = {
                0xFF000000, // background    -> black
                0xFFFF0000, // skin          -> [255,   0,   0]
                0xFFFF5500, // nose          -> [255,  85,   0]
                0xFFFFAA00, // eye_left      -> [255, 170,   0]
                0xFFFF0055, // eye_right     -> [255,   0,  85]
                0xFFFF00AA, // eyebrow_left  -> [255,   0, 170]
                0xFF00FF00, // eyebrow_right -> [  0, 255,   0]
                0xFF55FF00, // ear_left      -> [ 85, 255,   0]
                0xFFAAFF00, // ear_right     -> [170, 255,   0]
                0xFF00FF55, // mouth         -> [  0, 255,  85]
                0xFF00FFAA, // lip_upper     -> [  0, 255, 170]
                0xFF0000FF, // lip_lower     -> [  0,   0, 255]
                0xFF5500FF, // hair          -> [ 85,   0, 255]
                0xFFAA00FF, // hat           -> [170,   0, 255]
                0xFF0055FF, // earring       -> [  0,  85, 255]
                0xFF00AAFF, // necklace      -> [  0, 170, 255]
                0xFFFFFF00, // clothes       -> [255, 255,   0]
                0xFFFF5555, // facial_hair   -> [255,  85,  85]
                0xFFFFAAAA  // neck          -> [255, 170, 170]
        };

        Map<String, Integer> palette = new HashMap<>();
        for (int i = 0; i < names.length; i++) {
            palette.put(names[i], colours[i]);
        }
        return palette;
    }
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Container for the outcome of one segmentation run.
 *
 * <p>The instance stores the references it is given; no copies are made.
 */
public class SegmentationResult {

    /** Segmentation mask; each pixel carries the colour of its class. */
    private final BufferedImage maskImage;

    /** Class index -> class name. */
    private final Map<Integer, String> labels;

    /** Class name -> ARGB colour. */
    private final Map<String, Integer> palette;

    /**
     * @param maskImage mask image whose pixels are coloured per class
     * @param labels    class index -> class name
     * @param palette   class name -> ARGB colour
     */
    public SegmentationResult(BufferedImage maskImage,
                              Map<Integer, String> labels,
                              Map<String, Integer> palette) {
        this.palette = palette;
        this.labels = labels;
        this.maskImage = maskImage;
    }

    /** @return the mask image passed to the constructor */
    public BufferedImage getMaskImage() {
        return this.maskImage;
    }

    /** @return the index -> name map passed to the constructor */
    public Map<Integer, String> getLabels() {
        return this.labels;
    }

    /** @return the name -> ARGB map passed to the constructor */
    public Map<String, Integer> getPalette() {
        return this.palette;
    }
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
/**
 * Segmenter: loads a model and performs semantic segmentation on an image.
 *
 * Notes:
 * - Translator.processOutput reduces the model output to an (H, W) map of class indices
 *   inside the translator itself.
 * - While that NDArray is still valid, the class map is converted to a Java int[]
 *   (via toIntArray()) together with its shape and returned as a SegmentationData, so later
 *   processing does not depend on native resources that may already have been released.
 */
|
||||
public class Segmenter implements AutoCloseable {
|
||||
|
||||
// 内部类,用于从Translator安全地传出数据
|
||||
public static class SegmentationData {
|
||||
final int[] indices;
|
||||
final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private final ZooModel<Image, SegmentationData> modelWrapper;
|
||||
private final Predictor<Image, SegmentationData> predictor;
|
||||
private final List<String> labels;
|
||||
private final Map<String, Integer> palette;
|
||||
|
||||
/**
 * Loads the model from {@code modelDir} with the PyTorch engine and prepares a predictor.
 *
 * <p>The translator feeds the image as a 1xCxHxW float tensor scaled to [0,1] and converts the
 * model output into a {@link SegmentationData} holding plain Java arrays.
 *
 * @param modelDir directory containing the model
 * @param labels   label names ordered by class index (copied)
 * @throws IOException             if the model files cannot be read
 * @throws MalformedModelException if the model is invalid
 * @throws ModelNotFoundException  if no model is found in {@code modelDir}
 */
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
    this.labels = new ArrayList<>(labels);
    this.palette = LabelPalette.defaultPalette();

    // The translator's output type is SegmentationData (not an NDArray).
    Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {

        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            // HWC -> CHW, float, scale to [0,1], then add the batch dimension
            NDArray tensor = input.toNDArray(ctx.getNDManager())
                    .transpose(2, 0, 1)
                    .toType(DataType.FLOAT32, false)
                    .div(255f)
                    .expandDims(0);
            return new NDList(tensor);
        }

        @Override
        public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
            if (list == null || list.isEmpty()) {
                throw new IllegalStateException("Model did not return any output.");
            }

            NDArray raw = list.get(0);
            long[] dims = raw.getShape().getShape();

            // 1. Reduce the raw output to a class map.
            //    Accepted layouts: 4-D with more than one channel, any 3-D, any 2-D.
            boolean multiChannel4d = dims.length == 4 && dims[1] > 1;
            if (!multiChannel4d && dims.length != 3 && dims.length != 2) {
                throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(dims));
            }

            NDArray classMap;
            if (multiChannel4d) {
                classMap = raw.argMax(1);      // arg-max over the channel axis
            } else if (dims.length == 3 && dims[0] != 1) {
                classMap = raw.argMax(0);      // 3-D with several planes: arg-max over the first axis
            } else {
                classMap = raw;                // already a class map (2-D, or 3-D with one plane)
            }

            if (classMap.getShape().dimension() == 3) {
                classMap = classMap.squeeze(0);
            }

            // 2. Key step: copy into Java primitives while the NDArray is still valid.
            NDArray asInt32 = classMap.toType(DataType.INT32, false);
            int[] flatIndices = asInt32.toIntArray();
            long[] resultShape = asInt32.getShape().getShape();

            // 3. Hand back plain Java objects only.
            return new SegmentationData(flatIndices, resultShape);
        }

        @Override
        public Batchifier getBatchifier() {
            return null; // no batching; Batchifier.STACK could be used if needed
        }
    };

    // Criteria types must match the translator's input/output types.
    Criteria<Image, SegmentationData> criteria = Criteria.builder()
            .setTypes(Image.class, SegmentationData.class)
            .optModelPath(modelDir)
            .optEngine("PyTorch")
            .optTranslator(translator)
            .build();

    this.modelWrapper = criteria.loadModel();
    this.predictor = this.modelWrapper.newPredictor();
}
|
||||
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
// predict 方法现在直接返回安全的 Java 对象
|
||||
SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
width = (int) shp[1];
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
|
||||
// 小测试主函数(示例)
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 3) {
|
||||
System.out.println("用法: java Segmenter <modelDir> <inputImage> <outputMaskPng>");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
File out = new File(args[2]);
|
||||
|
||||
List<String> labels = LabelPalette.defaultLabels();
|
||||
|
||||
try (Segmenter s = new Segmenter(modelDir, labels)) {
|
||||
SegmentationResult res = s.segment(input);
|
||||
ImageIO.write(res.getMaskImage(), "png", out);
|
||||
System.out.println("分割掩码已保存到: " + out.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.awt.*;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* SegmenterExample
|
||||
*
|
||||
* 使用说明(命令行):
|
||||
* java -cp <your-classpath> com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \
|
||||
* <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>
|
||||
*
|
||||
* 示例:
|
||||
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye,face"
|
||||
* java ... SegmenterExample /models/face_bisent /images/in.jpg /out all
|
||||
*/
|
||||
public class SegmenterExample {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.err.println("用法: SegmenterExample <modelDir> <inputImage> <outputDir> <targetsCommaOrAll>");
|
||||
System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face");
|
||||
return;
|
||||
}
|
||||
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File inputImage = new File(args[1]);
|
||||
Path outDir = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
if (!Files.exists(modelDir)) {
|
||||
System.err.println("modelDir 不存在: " + modelDir);
|
||||
return;
|
||||
}
|
||||
if (!inputImage.exists()) {
|
||||
System.err.println("输入图片不存在: " + inputImage.getAbsolutePath());
|
||||
return;
|
||||
}
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
// 读取 synset.txt(如果有),否则使用默认 LabelPalette
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
|
||||
// 打开 Segmenter
|
||||
try (Segmenter segmenter = new Segmenter(modelDir, labels)) {
|
||||
SegmentationResult res = segmenter.segment(inputImage);
|
||||
|
||||
// 原始图片
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
|
||||
// palette: labelName -> ARGB int
|
||||
Map<String, Integer> palette = res.getPalette();
|
||||
Map<Integer, String> labelsMap = res.getLabels(); // index -> name
|
||||
|
||||
// 解析目标 labels 列表
|
||||
Set<String> targets = parseTargets(targetsArg, labels);
|
||||
|
||||
System.out.println("Will export targets: " + targets);
|
||||
|
||||
// maskImage: 每像素是类别颜色(ARGB)
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
int w = maskImage.getWidth();
|
||||
int h = maskImage.getHeight();
|
||||
|
||||
// 为快速查 color -> labelName
|
||||
Map<Integer, String> colorToLabel = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> e : palette.entrySet()) {
|
||||
colorToLabel.put(e.getValue(), e.getKey());
|
||||
}
|
||||
|
||||
// 对每个 target 生成单独的 mask 和 overlay
|
||||
for (String target : targets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。");
|
||||
continue;
|
||||
}
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < h; y++) {
|
||||
for (int x = 0; x < w; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
// 保留为不透明(使用原始颜色)
|
||||
partMask.setRGB(x, y, targetColor);
|
||||
} else {
|
||||
// 透明
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 生成 overlay:在原图上叠加半透明的 targetColor 区域
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
// 若分辨率不同,先缩放 mask 到原图大小(简单粗暴地按尺寸相同假设)
|
||||
BufferedImage maskResized = maskImage;
|
||||
if (original.getWidth() != w || original.getHeight() != h) {
|
||||
// 简单缩放 mask 到原图尺寸
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(maskImage, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
// 在 overlay 上先画原图
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
|
||||
// 创建半透明颜色(将 targetColor 的 alpha 设为 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
// 创建一个图像,把 mask 中对应像素设置为 translucent,否则透明
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if (mc == targetColor) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
// 将 colorOverlay 画到 overlay 上
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存文件
|
||||
File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(partMask, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
System.out.println("Saved mask: " + maskOut.getAbsolutePath());
|
||||
System.out.println("Saved overlay: " + overlayOut.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private static Set<String> parseTargets(String arg, List<String> availableLabels) {
|
||||
String a = arg.trim();
|
||||
if (a.equalsIgnoreCase("all")) {
|
||||
return new LinkedHashSet<>(availableLabels);
|
||||
}
|
||||
String[] parts = a.split(",");
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
String t = p.trim();
|
||||
if (!t.isEmpty()) out.add(t);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package com.chuangzhou.vivid2D.ai.face_parsing;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.*;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* VividModelWrapper - 对之前 Segmenter / SegmenterExample 的封装
|
||||
*
|
||||
* 用法示例:
|
||||
* VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("/path/to/modelDir"));
|
||||
* Map<String, VividModelWrapper.ResultFiles> out = wrapper.segmentAndSave(
|
||||
* new File("input.jpg"),
|
||||
* Set.of("eye","face"), // 或 Set.of(all labels...);若想全部传 "all" 可以用 helper parseTargets
|
||||
* Paths.get("outDir")
|
||||
* );
|
||||
* // out contains 每个目标标签对应的 mask+overlay 文件路径
|
||||
* wrapper.close();
|
||||
*/
|
||||
public class VividModelWrapper implements AutoCloseable {
|
||||
|
||||
private final Segmenter segmenter;
|
||||
private final List<String> labels; // index -> name
|
||||
private final Map<String, Integer> palette; // name -> ARGB
|
||||
|
||||
private VividModelWrapper(Segmenter segmenter, List<String> labels, Map<String, Integer> palette) {
|
||||
this.segmenter = segmenter;
|
||||
this.labels = labels;
|
||||
this.palette = palette;
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 LabelPalette.defaultLabels()
|
||||
* 并创建 Segmenter 实例。
|
||||
*/
|
||||
public static VividModelWrapper load(Path modelDir) throws Exception {
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
Segmenter s = new Segmenter(modelDir, labels);
|
||||
Map<String, Integer> palette = LabelPalette.defaultPalette();
|
||||
return new VividModelWrapper(s, labels, palette);
|
||||
}
|
||||
|
||||
public List<String> getLabels() {
|
||||
return Collections.unmodifiableList(labels);
|
||||
}
|
||||
|
||||
public Map<String, Integer> getPalette() {
|
||||
return Collections.unmodifiableMap(palette);
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接返回分割结果(SegmentationResult)
|
||||
*/
|
||||
public SegmentationResult segment(File inputImage) throws Exception {
|
||||
return segmenter.segment(inputImage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。
|
||||
* 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。
|
||||
* <p>
|
||||
* 返回值:Map<labelName, ResultFiles>,ResultFiles 包含 maskFile、overlayFile(两个 PNG)
|
||||
*/
|
||||
public Map<String, ResultFiles> segmentAndSave(File inputImage, Set<String> targets, Path outDir) throws Exception {
|
||||
if (!Files.exists(outDir)) {
|
||||
Files.createDirectories(outDir);
|
||||
}
|
||||
|
||||
SegmentationResult res = segment(inputImage);
|
||||
BufferedImage original = ImageIO.read(inputImage);
|
||||
BufferedImage maskImage = res.getMaskImage();
|
||||
|
||||
int maskW = maskImage.getWidth();
|
||||
int maskH = maskImage.getHeight();
|
||||
|
||||
// 解析 targets
|
||||
Set<String> realTargets = parseTargetsSet(targets);
|
||||
Map<String, ResultFiles> saved = new LinkedHashMap<>();
|
||||
|
||||
for (String target : realTargets) {
|
||||
if (!palette.containsKey(target)) {
|
||||
// 尝试忽略大小写匹配
|
||||
String finalTarget = target;
|
||||
Optional<String> matched = palette.keySet().stream()
|
||||
.filter(k -> k.equalsIgnoreCase(finalTarget))
|
||||
.findFirst();
|
||||
if (matched.isPresent()) target = matched.get();
|
||||
else {
|
||||
System.err.println("Warning: unknown label '" + target + "' - skip.");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
int targetColor = palette.get(target);
|
||||
|
||||
// 1) 生成透明背景的二值掩码(只保留 target 像素)
|
||||
BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < maskH; y++) {
|
||||
for (int x = 0; x < maskW; x++) {
|
||||
int c = maskImage.getRGB(x, y);
|
||||
if (c == targetColor) {
|
||||
partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明
|
||||
} else {
|
||||
partMask.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明)
|
||||
BufferedImage maskResized = partMask;
|
||||
if (original.getWidth() != maskW || original.getHeight() != maskH) {
|
||||
maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g = maskResized.createGraphics();
|
||||
g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
||||
g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null);
|
||||
g.dispose();
|
||||
}
|
||||
|
||||
BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
Graphics2D g2 = overlay.createGraphics();
|
||||
g2.drawImage(original, 0, 0, null);
|
||||
// 半透明颜色(alpha = 0x88)
|
||||
int rgbOnly = (targetColor & 0x00FFFFFF);
|
||||
int translucent = (0x88 << 24) | rgbOnly;
|
||||
BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB);
|
||||
for (int y = 0; y < colorOverlay.getHeight(); y++) {
|
||||
for (int x = 0; x < colorOverlay.getWidth(); x++) {
|
||||
int mc = maskResized.getRGB(x, y);
|
||||
if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) {
|
||||
colorOverlay.setRGB(x, y, translucent);
|
||||
} else {
|
||||
colorOverlay.setRGB(x, y, 0x00000000);
|
||||
}
|
||||
}
|
||||
}
|
||||
g2.drawImage(colorOverlay, 0, 0, null);
|
||||
g2.dispose();
|
||||
|
||||
// 保存
|
||||
String safe = safeFileName(target);
|
||||
File maskOut = outDir.resolve(safe + "_mask.png").toFile();
|
||||
File overlayOut = outDir.resolve(safe + "_overlay.png").toFile();
|
||||
|
||||
ImageIO.write(maskResized, "png", maskOut);
|
||||
ImageIO.write(overlay, "png", overlayOut);
|
||||
|
||||
saved.put(target, new ResultFiles(maskOut, overlayOut));
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
private static String safeFileName(String s) {
|
||||
return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_");
|
||||
}
|
||||
|
||||
private static Set<String> parseTargetsSet(Set<String> in) {
|
||||
if (in == null || in.isEmpty()) return Collections.emptySet();
|
||||
// 若包含单个 "all"
|
||||
if (in.size() == 1) {
|
||||
String only = in.iterator().next();
|
||||
if ("all".equalsIgnoreCase(only.trim())) {
|
||||
// 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette)
|
||||
// For convenience, return a set containing "all" and let caller logic handle it earlier.
|
||||
return Set.of("all");
|
||||
}
|
||||
}
|
||||
// 直接返回 trim 后的小写不变集合(保持用户传入的名字)
|
||||
Set<String> out = new LinkedHashSet<>();
|
||||
for (String s : in) {
|
||||
if (s != null) out.add(s.trim());
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层资源
|
||||
*/
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
segmenter.close();
|
||||
} catch (Exception ignore) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 存放结果文件路径
|
||||
*/
|
||||
public static class ResultFiles {
|
||||
private final File maskFile;
|
||||
private final File overlayFile;
|
||||
|
||||
public ResultFiles(File maskFile, File overlayFile) {
|
||||
this.maskFile = maskFile;
|
||||
this.overlayFile = overlayFile;
|
||||
}
|
||||
|
||||
public File getMaskFile() {
|
||||
return maskFile;
|
||||
}
|
||||
|
||||
public File getOverlayFile() {
|
||||
return overlayFile;
|
||||
}
|
||||
}
|
||||
|
||||
/* ================= helper: 从 modelDir 读取 synset.txt ================= */
|
||||
|
||||
private static Optional<List<String>> loadLabelsFromSynset(Path modelDir) {
|
||||
Path syn = modelDir.resolve("synset.txt");
|
||||
if (Files.exists(syn)) {
|
||||
try {
|
||||
List<String> lines = Files.readAllLines(syn);
|
||||
List<String> cleaned = new ArrayList<>();
|
||||
for (String l : lines) {
|
||||
String s = l.trim();
|
||||
if (!s.isEmpty()) cleaned.add(s);
|
||||
}
|
||||
if (!cleaned.isEmpty()) return Optional.of(cleaned);
|
||||
} catch (IOException ignore) {}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
/* ================= convenience 主方法(快速测试) ================= */
|
||||
public static void main(String[] args) throws Exception {
|
||||
if (args.length < 4) {
|
||||
System.out.println("用法: VividModelWrapper <modelDir> <inputImage> <outDir> <targetsCommaOrAll>");
|
||||
System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face");
|
||||
return;
|
||||
}
|
||||
Path modelDir = Path.of(args[0]);
|
||||
File input = new File(args[1]);
|
||||
Path out = Path.of(args[2]);
|
||||
String targetsArg = args[3];
|
||||
|
||||
List<String> labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels);
|
||||
Set<String> targets;
|
||||
if ("all".equalsIgnoreCase(targetsArg.trim())) {
|
||||
targets = new LinkedHashSet<>(labels);
|
||||
} else {
|
||||
String[] parts = targetsArg.split(",");
|
||||
targets = new LinkedHashSet<>();
|
||||
for (String p : parts) {
|
||||
if (!p.trim().isEmpty()) targets.add(p.trim());
|
||||
}
|
||||
}
|
||||
|
||||
try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) {
|
||||
Map<String, ResultFiles> m = wrapper.segmentAndSave(input, targets, out);
|
||||
m.forEach((k, v) -> {
|
||||
System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
38
src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java
Normal file
38
src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java
Normal file
@@ -0,0 +1,38 @@
|
||||
package com.chuangzhou.vivid2D.test;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* 用来分析人物的脸部信息头发、眼睛、嘴巴、脸部、皮肤、衣服
|
||||
*/
|
||||
public class AI2Test {
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
|
||||
// 使用 AnimeModelWrapper 而不是 VividModelWrapper
|
||||
AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt"));
|
||||
|
||||
// 使用 Anime-Face-Segmentation 的 7 个标签
|
||||
Set<String> animeLabels = Set.of(
|
||||
"background",
|
||||
"hair", // 头发
|
||||
"eye", // 眼睛
|
||||
"mouth", // 嘴巴
|
||||
"face", // 脸部
|
||||
"skin", // 皮肤
|
||||
"clothes" // 衣服
|
||||
);
|
||||
|
||||
wrapper.segmentAndSave(
|
||||
Paths.get("C:\\Users\\Administrator\\Desktop\\b_215609167a3a20ac2075487bd532bbff.jpg").toFile(),
|
||||
animeLabels,
|
||||
Paths.get("C:\\models\\out")
|
||||
);
|
||||
}
|
||||
}
|
||||
28
src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java
Normal file
28
src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java
Normal file
@@ -0,0 +1,28 @@
|
||||
package com.chuangzhou.vivid2D.test;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* 这个ai模型负责分离人物与背景
|
||||
*/
|
||||
public class AI3Test {
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt"));
|
||||
|
||||
Set<String> faceLabels = Set.of("foreground");
|
||||
|
||||
wrapper.segmentAndSave(
|
||||
Paths.get("C:\\Users\\Administrator\\Desktop\\b_7a8349adece17d1e4bebd20cb2387cf6.jpg").toFile(),
|
||||
faceLabels,
|
||||
Paths.get("C:\\models\\out")
|
||||
);
|
||||
}
|
||||
}
|
||||
33
src/main/java/com/chuangzhou/vivid2D/test/AITest.java
Normal file
33
src/main/java/com/chuangzhou/vivid2D/test/AITest.java
Normal file
@@ -0,0 +1,33 @@
|
||||
package com.chuangzhou.vivid2D.test;
|
||||
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* 测试人脸解析模型
|
||||
*/
|
||||
public class AITest {
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8));
|
||||
System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8));
|
||||
VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt"));
|
||||
|
||||
// 使用 BiSeNet 人脸解析模型的 18 个非背景标签
|
||||
Set<String> faceLabels = Set.of(
|
||||
"skin", "nose", "eye_left", "eye_right", "eyebrow_left",
|
||||
"eyebrow_right", "ear_left", "ear_right", "mouth", "lip_upper",
|
||||
"lip_lower", "hair", "hat", "earring", "necklace", "clothes",
|
||||
"facial_hair", "neck"
|
||||
);
|
||||
|
||||
wrapper.segmentAndSave(
|
||||
Paths.get("C:\\Users\\Administrator\\Desktop\\b_f4881214f0d18b6cf848b6736f554821.png").toFile(),
|
||||
faceLabels,
|
||||
Paths.get("C:\\models\\out")
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user