如何使用C#实现Padim算法的训练和推理
目录
说明
项目背景
算法实现
预处理模块——图像预处理
主要模块——训练:Resnet层信息提取
主要模块——信息处理,计算Anomaly Map
主要模块——评估
主要模块——评估:门限值的确定
主要模块——推理
写在最后
项目下载链接
说明
作者:来瓶霸王防脱发
项目地址:
https://github.com/IntptrMax/PadimSharp
原文地址:
https://blog.csdn.net/qq_30270773/article/details/143029865
项目背景
缺陷检测(Anomaly Detection)算法是一个区分正常类别与异常类别的二分类问题,但在工业场景中大多数数据都为良品,不良数据难以获取,更难枚举,所以训练一个全监督的模型是不切实际的。因此,异常检测模型通常以单类别学习的模式。Padim算法是一种十分优秀的缺陷检测算法,直接上图可以看一下这个算法的效果。
良品图片
不良品图片
检测效果
C#是一种十分受欢迎的编程语言,这种编程语言在工业场景下使用也是十分广泛的。在一些AI领域,会在Python下将模型转化为onnx形式,通过onnxruntime加载使用,进行推理。但是在onnx形式下进行训练十分困难。很多C#开发者不太熟悉Python环境,或者某些条件下希望在纯粹的C#环境下进行深度学习的训练和使用。这个还是有一定的困难的。
目前搜索了Github和CSDN排名靠前的几十条数据,还没有Padim算法在除Python平台下的训练+推理的相关项目或资料。本文就是在C#平台实现了Padim的训练+推理过程,应该在相关领域也算是独一份了。
算法实现
Padim算法的“训练”过程其实并没有涉及到真正的训练,而是使用Resnet18算法提取关键信息加以处理,在推理时再次使用,因此“训练”过程速度非常快,这也是这个算法的优点之一。Padim算法的具体实现还请参考相关论文:PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization
https://arxiv.org/abs/2011.08785
如果论文看起来困难,还有一些大佬对该算法在Python平台下的解读,也可以参考:PaDiM 原理与代码解析
https://blog.csdn.net/ooooocj/article/details/127601035
预处理模块——图像预处理
图像预处理使用的方法比较常规,使用了缩放等方式,此处并没有使用LetterBox,也可以达到预期效果:
var transformers = torchvision.transforms.Compose([
torchvision.transforms.Resize(resizeHeight,resizeWidth),
torchvision.transforms.CenterCrop(cropHeight,cropWidth),
torchvision.transforms.Normalize(means, stdevs)]);
主要模块——训练:Resnet层信息提取
使用Resnet模型进行推理,并提取Layer1、Layer2、Layer3层的信息,并进行了拼接(EmbeddingConcat)。注意:这里提取时使用了钩子,钩子在使用时会有资源释放,因此这里使用了比较迂回的方式记录结果
实现代码如下:
public List<(string, Tensor)> Forward(Tensor input)
{
List<(string, Tensor)> outputs = new List<(string, Tensor)>();
List<TempTensor> tempTensors = new List<TempTensor>();
foreach (var named_module in model.named_children())
{
string name = named_module.name;
if (name == "layer1" || name == "layer2" || name == "layer3")
{
((Sequential)named_module.module).register_forward_hook((Module, input, output) =>
{
tempTensors.Add(new TempTensor
{
Data = output.data<float>().ToArray(),
Name = name,
Shape = output.shape,
});
return null;
});
}
}
model.forward(input);
var layer1output = tempTensors.Find(a => a.Name == "layer1");
var layer2output = tempTensors.Find(a => a.Name == "layer2");
var layer3output = tempTensors.Find(a => a.Name == "layer3");
Tensor l1 = torch.tensor(layer1output.Data, layer1output.Shape, device: input.device);
Tensor l2 = torch.tensor(layer2output.Data, layer2output.Shape, device: input.device);
Tensor l3 = torch.tensor(layer3output.Data, layer3output.Shape, device: input.device);
outputs.Add(new("layer1", l1));
outputs.Add(new("layer2", l2));
outputs.Add(new("layer3", l3));
GC.Collect();
return outputs;
}
private Tensor EmbeddingConcat(Tensor[] features)
{
var embeddings = features[0];
for (int i = 1; i < features.Length; i++)
{
var layerEmbedding = features[i];
layerEmbedding = torch.nn.functional.interpolate(layerEmbedding, size: [embeddings.shape[2], embeddings.shape[2]], mode: InterpolationMode.Nearest);
embeddings = torch.cat([embeddings, layerEmbedding], 1);
}
return embeddings;
}
主要模块——信息处理,计算Anomaly Map
这一块主要对信息进行处理,获取矩阵的mean和cov(协方差矩阵),代码如下:
public Tensor ComputeAnomalyMapInternal(Tensor embedding, Tensor mean, Tensor covariance)
{
var scoreMap = ComputeDistance(embedding, mean, covariance);
var upSampledScoreMap = UpSample(scoreMap);
var smoothedAnomalyMap = SmoothAnomalyMap(upSampledScoreMap);
return smoothedAnomalyMap;
}
public Tensor ComputeAnomalyMap(List<(string, Tensor)> outputs, Tensor mean, Tensor covariance, Tensor idx)
{
Tensor embedding = GetEmbedding(outputs);
var embeddingVectors = torch.index_select(embedding, 1, idx);
return ComputeAnomalyMapInternal(embeddingVectors, mean, covariance);
}
主要模块——评估
与训练过程开始部分相似,也是获取图像的Embeddings,然后利用之前获取的Cov和mean计算马氏距离,以此评估图像的异常情况。马氏距离的计算方法如下:
private Tensor ComputeDistance(Tensor embedding, Tensor mean, Tensor covariance)
{
long batch = embedding.shape[0];
long channel = embedding.shape[1];
long height = embedding.shape[2];
long width = embedding.shape[3];
Tensor inv_covariance = covariance.permute(2, 0, 1).inverse();
var embedding_reshaped = embedding.reshape(batch, channel, height * width);
var delta = (embedding_reshaped - mean).permute(2, 0, 1);
var distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0);
distances = distances.reshape(batch, 1, height, width);
distances = distances.clamp(0).sqrt();
return distances;
}
主要模块——评估:门限值的确定
这里需要确定图像的评估门限和像素值的评估门限。如果在评估时有负向样本,这个值会更准确,如果只有正向样本也是可以的。在Python下有个precision_recall_curve包,可以计算相关参数,但是在C#下时没有的,因此在此处仍旧只能造轮子,代码如下:
private (float[] precisions, float[] recalls, float[] thresholds) _precision_recall_curve_compute_single_class(Tensor yTrue, Tensor yScores, int pos_label = 1)
{
var (fps, tps, thresholds) = BinaryClfCurve(yScores, yTrue, pos_label);
var precision = tps / (tps + fps);
var recall = tps / tps[-1];
var lastInd = torch.where(tps == tps[-1])[0][0].ToInt32();
int[] sl = new int[lastInd + 1];
for (int i = 0; i < sl.Length; i++)
{
sl[i] = i;
}
var reversedPrecision = precision[sl].flip(0);
var reversedRecall = recall[sl].flip(0);
var reversedThresholds = thresholds[sl].flip(0);
precision = torch.cat(new Tensor[] { reversedPrecision, torch.ones(1, dtype: precision.dtype, device: precision.device) });
recall = torch.cat(new Tensor[] { reversedRecall, torch.zeros(1, dtype: recall.dtype, device: recall.device) });
return (precision.data<float>().ToArray(), recall.data<float>().ToArray(), reversedThresholds.data<float>().ToArray());
}
private (Tensor fps, Tensor tps, Tensor thresholds) BinaryClfCurve(Tensor preds, Tensor target, int posLabel = 1)
{
using (torch.no_grad())
{
if (preds.ndim > target.ndim)
{
preds = preds[TensorIndex.Ellipsis, 0];
}
var descScoreIndices = torch.argsort(preds, descending: true);
preds = preds[descScoreIndices];
target = target[descScoreIndices];
Tensor weight = torch.tensor(1.0f);
var distinctValueIndices = torch.nonzero(preds[1..] - preds[..^1]).squeeze();
var thresholdIdxs = torch.cat(new Tensor[] { distinctValueIndices, torch.tensor(new long[] { target.shape[0] - 1 }, device: preds.device) });
target = (target == posLabel).to_type(ScalarType.Int64);
var tps = torch.cumsum(target * weight, dim: 0)[thresholdIdxs];
Tensor fps = 1 + thresholdIdxs - tps;
return (fps, tps, preds[thresholdIdxs]);
}
}
主要模块——推理
这个过程与上面过程也十分相似,正向计算出图像的Anomaly Map后,取出这个张量中最大的值,与图像的门限值进行比较,即可评估图像是否是良品。然后对这个张量中每个元素与像素门限值做对比,即可得到按像素的异常区域,以便绘制Mask和热力图。
Tensor orgImg = tensors["orgImage"].clone().to(device);
Tensor t = anomaly_map > pixel_threshold;
anomaly_map = (anomaly_map * t).squeeze(0);
anomaly_map = torchvision.transforms.functional.resize(anomaly_map, (int)orgImg.size(2), (int)orgImg.size(1));
Tensor heatmapNormalized = (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min());
Tensor coloredHeatmap = torch.zeros([3, (int)orgImg.size(2), (int)orgImg.size(1)],device:anomaly_map.device);
coloredHeatmap[0] = heatmapNormalized.squeeze(0);
float alpha = 0.3f;
Tensor blendedImage = (1 - alpha) * (orgImg / 255.0f) + alpha * coloredHeatmap;
var imageTensor = blendedImage.clamp(0, 1).mul(255).to(ScalarType.Byte);
torchvision.io.write_jpeg(imageTensor.cpu(), "result.jpg");
写在最后
使用C#开发深度学习项目,尤其是训练的项目,是一个十分困难的过程。或者说除了Python平台,训练都十分困难。C#进行深度学习训练这个方向在国内基本很少有人开展,所以能查得到的资料很少。本人十分喜爱C#这门语言,又十分喜爱深度学习,因此仅半年一直在这方面努力。遇到了很多困难,也收获了很多。
这条路走的不容易,希望能有更多人能加入进来,一起开发,一起学习。
我在Github上已经将完整的代码发布了,项目地址为:
https://github.com/IntptrMax/PadimSharp
,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。
项目下载链接
https://download.csdn.net/download/qq_30270773/89897710