// YoloV6Detector.cs
  1. using Microsoft.ML.OnnxRuntime;
  2. using Microsoft.ML.OnnxRuntime.Tensors;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Drawing;
  7. using System.Drawing.Imaging;
  8. using System.Linq;
  9. namespace YoloTest
  10. {
  11. internal class YoloV6Detector : IDisposable
  12. {
  /// <summary>ONNX Runtime session that executes the detection model.</summary>
  private readonly InferenceSession _session;

  /// <summary>
  /// Class labels. The number of labels bounds how many class scores are read
  /// per anchor, and a detection's ClassId is an index into this array.
  /// </summary>
  public string[] _classNames; // field added later by the original author

  /// <summary>
  /// Creates a detector with explicit threading/optimization settings and,
  /// optionally, the OpenVINO execution provider.
  /// </summary>
  /// <param name="modelPath">Path of the ONNX model file.</param>
  /// <param name="classNames">Class labels; defaults to { "OK", "NG" } when null.</param>
  /// <param name="useGpu">
  /// When true, tries to register the OpenVINO provider with device "CPU" and
  /// falls back to the plain CPU provider if registration throws. Despite the
  /// name, this constructor does not register any GPU provider.
  /// </param>
  /// <exception cref="ArgumentException"><paramref name="modelPath"/> is null, empty or whitespace.</exception>
  public YoloV6Detector(string modelPath, string[] classNames = null, bool useGpu = true)
  {
      if (string.IsNullOrWhiteSpace(modelPath))
      {
          throw new ArgumentException("Model path must not be null or empty.", nameof(modelPath));
      }

      // SessionOptions owns a native handle and is only needed while the
      // session is being created, so it is disposed right afterwards.
      using (var options = new SessionOptions())
      {
          options.InterOpNumThreads = 1; // fewer thread switches
          options.IntraOpNumThreads = Environment.ProcessorCount;
          options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
          options.AddSessionConfigEntry("session.intra_op.allow_spinning", "1");
          options.AddSessionConfigEntry("session.inter_op.allow_spinning", "1");

          if (useGpu)
          {
              try
              {
                  options.AppendExecutionProvider_OpenVINO("CPU");
                  Console.WriteLine("OpenVINO推理运行");
              }
              catch (Exception e)
              {
                  // Provider registration failed: report why and fall back to the CPU provider.
                  Console.WriteLine($"{e.Message}");
                  options.AppendExecutionProvider_CPU(0);
                  Console.WriteLine("CPU推理运行");
              }

              // Alternative kept for reference: DirectML (generic on Windows; AMD/Intel/NVIDIA).
              // options.AppendExecutionProvider("DML", new Dictionary<string, string> { { "device_id", "0" } });
          }

          _session = new InferenceSession(modelPath, options);
      }

      _classNames = classNames ?? new string[] { "OK", "NG" };
  }
  /// <summary>
  /// Creates a detector that runs the model with default session options.
  /// </summary>
  /// <remarks>
  /// A call that passes exactly two arguments binds to this overload, not to
  /// the three-parameter one, so no execution-provider setup happens here.
  /// NOTE(review): a call that omits <paramref name="classNames"/> may be
  /// ambiguous with the three-parameter overload — confirm against callers.
  /// </remarks>
  /// <param name="modelPath">Path of the ONNX model file.</param>
  /// <param name="classNames">
  /// Class labels; defaults to { "OK", "NG" } when null. Adjust the default to
  /// the classes of the actual model.
  /// </param>
  /// <exception cref="ArgumentException"><paramref name="modelPath"/> is null, empty or whitespace.</exception>
  public YoloV6Detector(string modelPath, string[] classNames = null)
  {
      if (string.IsNullOrWhiteSpace(modelPath))
      {
          throw new ArgumentException("Model path must not be null or empty.", nameof(modelPath));
      }

      // SessionOptions owns a native handle; release it once the session exists.
      using (var options = new SessionOptions())
      {
          _session = new InferenceSession(modelPath, options);
      }

      // Initialize the class names.
      _classNames = classNames ?? new string[] { "OK", "NG" };
  }
  /// <summary>Measures the duration of the most recent inference run.</summary>
  private readonly Stopwatch OnceRunTime = new Stopwatch();

  /// <summary>Input name used when the model declares it (or declares no inputs at all).</summary>
  private const string DefaultInputName = "images";

  /// <summary>
  /// Runs the model on one image and returns the detections in the
  /// coordinate space of <paramref name="inputImage"/>.
  /// </summary>
  /// <param name="inputImage">Image to analyse.</param>
  /// <returns>The detections that survive thresholding and NMS; empty when there are none.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="inputImage"/> is null.</exception>
  public List<Detection> Detect(Image inputImage)
  {
      if (inputImage == null)
      {
          throw new ArgumentNullException(nameof(inputImage));
      }

      var inputTensor = PreprocessImage(inputImage);

      // Prefer the conventional "images" input, but accept models whose single
      // input was exported under another name instead of failing inside Run.
      var inputMetadata = _session.InputMetadata;
      string inputName = inputMetadata.ContainsKey(DefaultInputName)
          ? DefaultInputName
          : inputMetadata.Keys.FirstOrDefault() ?? DefaultInputName;

      var inputs = new List<NamedOnnxValue>
      {
          NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
      };

      OnceRunTime.Restart();
      using (var results = _session.Run(inputs))
      {
          OnceRunTime.Stop();
          Console.WriteLine("单次推理耗时{0}", OnceRunTime.ElapsedMilliseconds);

          // Postprocess builds its list eagerly, so the result no longer
          // depends on the output values once they are disposed.
          return Postprocess(results, inputImage.Width, inputImage.Height);
      }
  }
  /// <summary>
  /// Stretches the image to 640x640 and converts it into a planar
  /// [1, 3, 640, 640] float tensor (R plane, then G, then B) with every
  /// channel value divided by 255.
  /// </summary>
  /// <param name="image">Source image; it is not modified.</param>
  /// <returns>The tensor that is fed to the model.</returns>
  private Tensor<float> PreprocessImage(Image image)
  {
      const int targetWidth = 640;
      const int targetHeight = 640;
      const int planeCount = 3;
      const int planeSize = targetHeight * targetWidth;

      using (var scaled = new Bitmap(image, targetWidth, targetHeight))
      {
          var planar = new float[1 * planeCount * planeSize];

          // Locking the bits gives raw memory access, which is far faster than GetPixel.
          BitmapData locked = scaled.LockBits(
              new Rectangle(0, 0, targetWidth, targetHeight),
              ImageLockMode.ReadOnly,
              PixelFormat.Format24bppRgb);
          try
          {
              unsafe
              {
                  byte* firstRow = (byte*)locked.Scan0;
                  int rowBytes = locked.Stride;

                  for (int row = 0; row < targetHeight; row++)
                  {
                      byte* bgr = firstRow + row * rowBytes;
                      int dest = row * targetWidth;

                      // Each pixel is three bytes in B, G, R order.
                      for (int col = 0; col < targetWidth; col++, bgr += 3, dest++)
                      {
                          planar[dest] = bgr[2] / 255.0f;                 // R
                          planar[planeSize + dest] = bgr[1] / 255.0f;     // G
                          planar[2 * planeSize + dest] = bgr[0] / 255.0f; // B
                      }
                  }
              }
          }
          finally
          {
              scaled.UnlockBits(locked);
          }

          return new DenseTensor<float>(planar, new[] { 1, planeCount, targetHeight, targetWidth });
      }
  }
  116. //private Tensor<float> PreprocessImage(Image image)
  117. //{
  118. // const int inputWidth = 640;
  119. // const int inputHeight = 640;
  120. // const int channels = 3;
  121. // using (var resizedImage = new Bitmap(image, inputWidth, inputHeight))
  122. // {
  123. // var tensorData = new float[1 * channels * inputHeight * inputWidth];
  124. // for (int y = 0; y < inputHeight; y++)
  125. // {
  126. // for (int x = 0; x < inputWidth; x++)
  127. // {
  128. // var pixel = resizedImage.GetPixel(x, y);
  129. // tensorData[y * inputWidth + x] = pixel.R / 255.0f;
  130. // tensorData[inputHeight * inputWidth + y * inputWidth + x] = pixel.G / 255.0f;
  131. // tensorData[2 * inputHeight * inputWidth + y * inputWidth + x] = pixel.B / 255.0f;
  132. // }
  133. // }
  134. // return new DenseTensor<float>(tensorData, new[] { 1, channels, inputHeight, inputWidth });
  135. // }
  136. //}
  /// <summary>
  /// Converts the first model output into detections expressed in the
  /// coordinate space of the original image.
  /// </summary>
  /// <remarks>
  /// The output is read as [batch, num_anchors, 4 + 1 + num_classes], where
  /// each anchor row is (cx, cy, w, h, objectness, class scores...). Only
  /// batch item 0 is read.
  /// </remarks>
  /// <param name="outputs">Values returned by the inference session.</param>
  /// <param name="imgWidth">Width of the original image in pixels.</param>
  /// <param name="imgHeight">Height of the original image in pixels.</param>
  /// <returns>Detections that passed the confidence threshold and NMS; empty when there is no output.</returns>
  /// <exception cref="InvalidOperationException">The output tensor does not have the expected layout.</exception>
  private List<Detection> Postprocess(IEnumerable<NamedOnnxValue> outputs, int imgWidth, int imgHeight)
  {
      const int modelInputSize = 640;   // side length the image is resized to in PreprocessImage
      const int objectnessIndex = 4;    // after cx, cy, w, h
      const int classScoreOffset = 5;   // first class score
      const float confidenceThreshold = 0.4f;
      const float nmsThreshold = 0.2f;

      var detections = new List<Detection>();

      // 1. Fetch the model output.
      var output = outputs?.FirstOrDefault();
      if (output == null)
      {
          return detections;
      }

      var outputTensor = output.AsTensor<float>();
      var outputShape = outputTensor.Dimensions;
      if (outputShape.Length != 3 || outputShape[2] <= objectnessIndex)
      {
          throw new InvalidOperationException(
              "Unexpected model output layout; expected [batch, num_anchors, 4 + 1 + num_classes].");
      }

      int numAnchors = outputShape[1];
      int numValuesPerAnchor = outputShape[2];

      // Never read more class scores than the tensor holds, even when more
      // labels were supplied than the model has classes.
      int labelCount = _classNames?.Length ?? 0;
      int numClasses = Math.Min(labelCount, numValuesPerAnchor - classScoreOffset);

      // 2. Decode the anchors that pass the confidence threshold.
      var rawDetections = new List<RawDetection>();
      for (int i = 0; i < numAnchors; i++)
      {
          float objectness = outputTensor[0, i, objectnessIndex];
          if (objectness < confidenceThreshold)
          {
              continue;
          }

          // Pick the most probable class.
          float maxClassProb = 0;
          int maxClassId = 0;
          for (int c = 0; c < numClasses; c++)
          {
              float classProb = outputTensor[0, i, classScoreOffset + c];
              if (classProb > maxClassProb)
              {
                  maxClassProb = classProb;
                  maxClassId = c;
              }
          }

          float confidence = objectness * maxClassProb;
          if (confidence < confidenceThreshold)
          {
              continue;
          }

          // Box is given as centre + size; store it as top-left + size.
          float cx = outputTensor[0, i, 0];
          float cy = outputTensor[0, i, 1];
          float w = outputTensor[0, i, 2];
          float h = outputTensor[0, i, 3];

          rawDetections.Add(new RawDetection
          {
              X = cx - w / 2,
              Y = cy - h / 2,
              Width = w,
              Height = h,
              ClassId = maxClassId,
              Confidence = confidence
          });
      }

      // 3. Non-maximum suppression.
      var nmsDetections = ApplyNMS(rawDetections, nmsThreshold);

      // 4. Map coordinates from the model input size back to the original image.
      float scaleX = (float)imgWidth / modelInputSize;
      float scaleY = (float)imgHeight / modelInputSize;
      foreach (var det in nmsDetections)
      {
          detections.Add(new Detection(
              det.X * scaleX,
              det.Y * scaleY,
              det.Width * scaleX,
              det.Height * scaleY,
              det.ClassId,
              det.Confidence));
      }

      return detections;
  }
  /// <summary>
  /// Helper type: one raw detection as decoded from the model output, before
  /// non-maximum suppression.
  /// </summary>
  private sealed class RawDetection
  {
      /// <summary>Left edge of the box.</summary>
      public float X { get; set; }

      /// <summary>Top edge of the box.</summary>
      public float Y { get; set; }

      /// <summary>Box width.</summary>
      public float Width { get; set; }

      /// <summary>Box height.</summary>
      public float Height { get; set; }

      /// <summary>Index of the highest-scoring class.</summary>
      public int ClassId { get; set; }

      /// <summary>Objectness multiplied by the best class score.</summary>
      public float Confidence { get; set; }
  }
  /// <summary>
  /// Greedy per-class non-maximum suppression: boxes are visited in
  /// descending confidence order, each surviving box is kept, and every
  /// later box of the same class whose IoU with it is not below
  /// <paramref name="iouThreshold"/> is discarded.
  /// </summary>
  /// <param name="detections">Candidate boxes; returned as-is when empty.</param>
  /// <param name="iouThreshold">Overlap at or above which a same-class box is discarded.</param>
  /// <returns>The kept boxes, highest confidence first.</returns>
  private List<RawDetection> ApplyNMS(List<RawDetection> detections, float iouThreshold)
  {
      if (detections.Count == 0)
      {
          return detections;
      }

      // Highest confidence first.
      RawDetection[] ranked = detections.OrderByDescending(d => d.Confidence).ToArray();
      var discarded = new bool[ranked.Length];
      var kept = new List<RawDetection>();

      for (int i = 0; i < ranked.Length; i++)
      {
          if (discarded[i])
          {
              continue;
          }

          RawDetection winner = ranked[i];
          kept.Add(winner);

          // Drop the remaining same-class boxes that overlap the winner too much.
          for (int j = i + 1; j < ranked.Length; j++)
          {
              if (discarded[j] || ranked[j].ClassId != winner.ClassId)
              {
                  continue;
              }

              bool overlapIsLow = CalculateIoU(winner, ranked[j]) < iouThreshold;
              if (!overlapIsLow)
              {
                  discarded[j] = true;
              }
          }
      }

      return kept;
  }
  /// <summary>
  /// Computes the intersection-over-union of two axis-aligned boxes.
  /// </summary>
  /// <param name="a">First box.</param>
  /// <param name="b">Second box.</param>
  /// <returns>Intersection area divided by union area, or 0 when the union is not positive.</returns>
  private float CalculateIoU(RawDetection a, RawDetection b)
  {
      // Edges of the overlap rectangle.
      float left = Math.Max(a.X, b.X);
      float top = Math.Max(a.Y, b.Y);
      float right = Math.Min(a.X + a.Width, b.X + b.Width);
      float bottom = Math.Min(a.Y + a.Height, b.Y + b.Height);

      // A negative extent means the boxes do not overlap on that axis.
      float overlapWidth = Math.Max(0, right - left);
      float overlapHeight = Math.Max(0, bottom - top);
      float overlapArea = overlapWidth * overlapHeight;

      float firstArea = a.Width * a.Height;
      float secondArea = b.Width * b.Height;
      float combinedArea = firstArea + secondArea - overlapArea;

      if (combinedArea > 0)
      {
          return overlapArea / combinedArea;
      }

      return 0;
  }
  /// <summary>Disposes the inference session if one was created.</summary>
  public void Dispose()
  {
      if (_session != null)
      {
          _session.Dispose();
      }
  }
  258. }
  /// <summary>
  /// One detected object: a box given by its top-left corner and size, the
  /// index of the detected class, and a confidence score.
  /// </summary>
  public class Detection
  {
      /// <summary>Creates a detection from its box, class index and confidence.</summary>
      /// <param name="x">Left edge of the box.</param>
      /// <param name="y">Top edge of the box.</param>
      /// <param name="width">Box width.</param>
      /// <param name="height">Box height.</param>
      /// <param name="classId">Index of the detected class.</param>
      /// <param name="confidence">Confidence score of the detection.</param>
      public Detection(float x, float y, float width, float height, int classId, float confidence)
      {
          this.X = x;
          this.Y = y;
          this.Width = width;
          this.Height = height;
          this.ClassId = classId;
          this.Confidence = confidence;
      }

      /// <summary>Left edge of the box.</summary>
      public float X { get; set; }

      /// <summary>Top edge of the box.</summary>
      public float Y { get; set; }

      /// <summary>Box width.</summary>
      public float Width { get; set; }

      /// <summary>Box height.</summary>
      public float Height { get; set; }

      /// <summary>Index of the detected class.</summary>
      public int ClassId { get; set; }

      /// <summary>Confidence score of the detection.</summary>
      public float Confidence { get; set; }
  }
  278. }