// YoloV6Detector.cs
  1. using Microsoft.ML.OnnxRuntime;
  2. using Microsoft.ML.OnnxRuntime.Tensors;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Drawing;
  7. using System.Drawing.Imaging;
  8. using System.Linq;
  9. namespace YoloTest
  10. {
  /// <summary>
  /// Runs a YOLOv6 ONNX model through ONNX Runtime on <see cref="Image"/> inputs and
  /// returns the detected boxes scaled back to the source image size.
  /// </summary>
  /// <remarks>
  /// The model is fed a 1x3x640x640 RGB tensor normalised to [0, 1]. The first output is
  /// assumed to have the layout [batch, anchors, 4 box values + 1 objectness + class scores].
  /// NOTE(review): a call that passes exactly two arguments binds to the two-parameter
  /// constructor (default session options, no GPU); pass <c>useGpu</c> explicitly to get the
  /// tuned/GPU session. A call with only <c>modelPath</c> is probably ambiguous between the
  /// two overloads - confirm against the callers.
  /// </remarks>
  internal class YoloV6Detector : IDisposable
  {
      // Fixed network input geometry.
      private const int InputWidth = 640;
      private const int InputHeight = 640;
      private const int Channels = 3;

      // Name of the model's input node.
      private const string InputName = "images";

      // Each anchor row starts with cx, cy, w, h, objectness; class scores follow.
      private const int BoxFieldCount = 5;

      private const float ConfidenceThreshold = 0.4f;
      private const float NmsThreshold = 0.2f;

      private readonly InferenceSession _session;
      private readonly Stopwatch _inferenceTimer = new Stopwatch();
      private bool _disposed;

      /// <summary>Class labels; the array index is the class id reported by the model.</summary>
      public string[] _classNames;

      /// <summary>
      /// Creates a detector with tuned threading and graph-optimisation settings and,
      /// optionally, the CUDA execution provider.
      /// </summary>
      /// <param name="modelPath">Path of the ONNX model file.</param>
      /// <param name="classNames">Class labels; defaults to "OK"/"NG" when null.</param>
      /// <param name="useGpu">
      /// When true, tries to register CUDA device 0 and falls back to the CPU provider
      /// if registration throws.
      /// </param>
      public YoloV6Detector(string modelPath, string[] classNames = null, bool useGpu = true)
      {
          // SessionOptions wraps a native handle; it is only needed while the session is built.
          using (var options = new SessionOptions())
          {
              options.InterOpNumThreads = 1; // fewer thread switches between operators
              options.IntraOpNumThreads = Environment.ProcessorCount;
              options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
              options.AddSessionConfigEntry("session.intra_op.allow_spinning", "1");
              options.AddSessionConfigEntry("session.inter_op.allow_spinning", "1");

              if (useGpu)
              {
                  try
                  {
                      options.AppendExecutionProvider_CUDA(0);
                      Console.WriteLine("GPU推理运行");
                  }
                  catch (Exception e)
                  {
                      // CUDA could not be registered: report why and continue on the CPU.
                      Console.WriteLine(e.Message);
                      options.AppendExecutionProvider_CPU(0);
                      Console.WriteLine("CPU推理运行");
                  }
              }

              _session = new InferenceSession(modelPath, options);
          }

          _classNames = classNames ?? DefaultClassNames();
      }

      /// <summary>Creates a detector with default ONNX Runtime session options.</summary>
      /// <param name="modelPath">Path of the ONNX model file.</param>
      /// <param name="classNames">Class labels; defaults to "OK"/"NG" when null.</param>
      public YoloV6Detector(string modelPath, string[] classNames = null)
      {
          using (var options = new SessionOptions())
          {
              _session = new InferenceSession(modelPath, options);
          }

          _classNames = classNames ?? DefaultClassNames();
      }

      /// <summary>Labels used when the caller supplies none; adjust to the model's real classes.</summary>
      private static string[] DefaultClassNames()
      {
          return new[] { "OK", "NG" };
      }

      /// <summary>
      /// Runs the model on <paramref name="inputImage"/> and returns the detections that
      /// survive thresholding and NMS, in source-image pixel coordinates.
      /// </summary>
      /// <param name="inputImage">Image to analyse; it is resized to 640x640 internally.</param>
      /// <returns>Detections (top-left x/y, width, height, class id, confidence).</returns>
      /// <exception cref="ArgumentNullException"><paramref name="inputImage"/> is null.</exception>
      /// <exception cref="ObjectDisposedException">The detector has been disposed.</exception>
      public List<Detection> Detect(Image inputImage)
      {
          if (inputImage == null)
          {
              throw new ArgumentNullException(nameof(inputImage));
          }

          if (_disposed)
          {
              throw new ObjectDisposedException(nameof(YoloV6Detector));
          }

          var inputTensor = PreprocessImage(inputImage);
          var inputs = new List<NamedOnnxValue>
          {
              NamedOnnxValue.CreateFromTensor(InputName, inputTensor)
          };

          // Only the Run call is timed; pre- and post-processing are excluded.
          _inferenceTimer.Restart();
          using (var results = _session.Run(inputs))
          {
              _inferenceTimer.Stop();
              Console.WriteLine("单次运行耗时{0}", _inferenceTimer.ElapsedMilliseconds);
              return Postprocess(results, inputImage.Width, inputImage.Height);
          }
      }

      /// <summary>
      /// Resizes the image to the network input size and converts it to a planar
      /// (channel-first) RGB float tensor with values divided by 255.
      /// </summary>
      private static Tensor<float> PreprocessImage(Image image)
      {
          const int planeSize = InputHeight * InputWidth;
          var tensorData = new float[Channels * planeSize];

          using (var resized = new Bitmap(image, InputWidth, InputHeight))
          {
              // LockBits gives direct access to the pixel buffer, far faster than GetPixel.
              var bitmapData = resized.LockBits(
                  new Rectangle(0, 0, InputWidth, InputHeight),
                  ImageLockMode.ReadOnly,
                  PixelFormat.Format24bppRgb);
              try
              {
                  unsafe
                  {
                      byte* basePtr = (byte*)bitmapData.Scan0;
                      int stride = bitmapData.Stride; // bytes per scan line, may include padding

                      for (int y = 0; y < InputHeight; y++)
                      {
                          byte* row = basePtr + y * stride;
                          int rowOffset = y * InputWidth;

                          for (int x = 0; x < InputWidth; x++)
                          {
                              // 24bpp pixels are stored in B, G, R byte order.
                              byte* pixel = row + x * 3;
                              int idx = rowOffset + x;
                              tensorData[idx] = pixel[2] / 255.0f;                 // R plane
                              tensorData[planeSize + idx] = pixel[1] / 255.0f;     // G plane
                              tensorData[2 * planeSize + idx] = pixel[0] / 255.0f; // B plane
                          }
                      }
                  }
              }
              finally
              {
                  resized.UnlockBits(bitmapData);
              }
          }

          return new DenseTensor<float>(tensorData, new[] { 1, Channels, InputHeight, InputWidth });
      }

      /// <summary>
      /// Decodes the first model output into detections: thresholds by objectness and by
      /// objectness * best class score, applies per-class NMS, then rescales the boxes from
      /// the 640x640 input space to the original image size.
      /// </summary>
      /// <exception cref="InvalidOperationException">
      /// The output tensor is not rank 3 or has fewer than 5 values per anchor.
      /// </exception>
      private List<Detection> Postprocess(IEnumerable<NamedOnnxValue> outputs, int imgWidth, int imgHeight)
      {
          var detections = new List<Detection>();

          // 1. Take the first model output.
          var output = outputs.FirstOrDefault();
          if (output == null)
          {
              return detections;
          }

          var outputTensor = output.AsTensor<float>();
          var outputShape = outputTensor.Dimensions;

          // Expected layout: [batch, num_anchors, 4 + 1 + num_classes].
          if (outputShape.Length != 3 || outputShape[2] < BoxFieldCount)
          {
              throw new InvalidOperationException(
                  "Unexpected model output shape; expected [batch, anchors, 5 + classes].");
          }

          int numAnchors = outputShape[1];
          int numValuesPerAnchor = outputShape[2];

          // Never read past the class scores the model actually emits, even when more
          // class names than model classes were supplied.
          int availableClasses = numValuesPerAnchor - BoxFieldCount;
          int namedClasses = _classNames != null ? _classNames.Length : availableClasses;
          int numClasses = Math.Min(namedClasses, availableClasses);

          // 2. Decode the anchors that pass the confidence thresholds.
          var rawDetections = new List<RawDetection>();
          for (int i = 0; i < numAnchors; i++)
          {
              // Objectness is the fifth value (index 4).
              float objectness = outputTensor[0, i, 4];
              if (objectness < ConfidenceThreshold)
              {
                  continue;
              }

              // Find the best-scoring class.
              float maxClassProb = 0;
              int maxClassId = 0;
              for (int c = 0; c < numClasses; c++)
              {
                  float classProb = outputTensor[0, i, BoxFieldCount + c];
                  if (classProb > maxClassProb)
                  {
                      maxClassProb = classProb;
                      maxClassId = c;
                  }
              }

              float confidence = objectness * maxClassProb;
              if (confidence < ConfidenceThreshold)
              {
                  continue;
              }

              // Box is emitted as centre (cx, cy) plus size (w, h); convert to top-left.
              float cx = outputTensor[0, i, 0];
              float cy = outputTensor[0, i, 1];
              float w = outputTensor[0, i, 2];
              float h = outputTensor[0, i, 3];

              rawDetections.Add(new RawDetection
              {
                  X = cx - w / 2,
                  Y = cy - h / 2,
                  Width = w,
                  Height = h,
                  ClassId = maxClassId,
                  Confidence = confidence
              });
          }

          // 3. Non-maximum suppression.
          var nmsDetections = ApplyNMS(rawDetections, NmsThreshold);

          // 4. Map coordinates from the model input size back to the original image size.
          float scaleX = (float)imgWidth / InputWidth;
          float scaleY = (float)imgHeight / InputHeight;
          foreach (var det in nmsDetections)
          {
              detections.Add(new Detection(
                  det.X * scaleX,
                  det.Y * scaleY,
                  det.Width * scaleX,
                  det.Height * scaleY,
                  det.ClassId,
                  det.Confidence));
          }

          return detections;
      }

      /// <summary>Intermediate detection in model-input coordinates (top-left based).</summary>
      private class RawDetection
      {
          public float X { get; set; }
          public float Y { get; set; }
          public float Width { get; set; }
          public float Height { get; set; }
          public int ClassId { get; set; }
          public float Confidence { get; set; }
      }

      /// <summary>
      /// Greedy per-class non-maximum suppression: walks the boxes in descending confidence
      /// order and drops any later box of the same class whose IoU with a kept box is not
      /// below <paramref name="iouThreshold"/>.
      /// </summary>
      private static List<RawDetection> ApplyNMS(List<RawDetection> detections, float iouThreshold)
      {
          var kept = new List<RawDetection>();
          if (detections.Count == 0)
          {
              return kept;
          }

          var ordered = detections.OrderByDescending(d => d.Confidence).ToList();
          var suppressed = new bool[ordered.Count];

          for (int i = 0; i < ordered.Count; i++)
          {
              if (suppressed[i])
              {
                  continue;
              }

              var best = ordered[i];
              kept.Add(best);

              for (int j = i + 1; j < ordered.Count; j++)
              {
                  // Boxes of other classes never suppress each other.
                  if (suppressed[j] || ordered[j].ClassId != best.ClassId)
                  {
                      continue;
                  }

                  bool survives = CalculateIoU(best, ordered[j]) < iouThreshold;
                  if (!survives)
                  {
                      suppressed[j] = true;
                  }
              }
          }

          return kept;
      }

      /// <summary>Intersection-over-union of two boxes; 0 when the union area is not positive.</summary>
      private static float CalculateIoU(RawDetection a, RawDetection b)
      {
          // Intersection rectangle.
          float x1 = Math.Max(a.X, b.X);
          float y1 = Math.Max(a.Y, b.Y);
          float x2 = Math.Min(a.X + a.Width, b.X + b.Width);
          float y2 = Math.Min(a.Y + a.Height, b.Y + b.Height);
          float intersection = Math.Max(0, x2 - x1) * Math.Max(0, y2 - y1);

          // Union = sum of areas minus the overlap counted twice.
          float areaA = a.Width * a.Height;
          float areaB = b.Width * b.Height;
          float union = areaA + areaB - intersection;

          return union > 0 ? intersection / union : 0;
      }

      /// <summary>Releases the inference session. Safe to call more than once.</summary>
      public void Dispose()
      {
          if (_disposed)
          {
              return;
          }

          _disposed = true;
          _session?.Dispose();
      }
  }
  256. // Detection 类定义
  /// <summary>
  /// A single detection result: a box given by its top-left corner and size,
  /// together with the class id and confidence assigned to it.
  /// </summary>
  public class Detection
  {
      /// <summary>Builds a detection from its box, class id and confidence.</summary>
      /// <param name="x">Value stored in <see cref="X"/>.</param>
      /// <param name="y">Value stored in <see cref="Y"/>.</param>
      /// <param name="width">Value stored in <see cref="Width"/>.</param>
      /// <param name="height">Value stored in <see cref="Height"/>.</param>
      /// <param name="classId">Value stored in <see cref="ClassId"/>.</param>
      /// <param name="confidence">Value stored in <see cref="Confidence"/>.</param>
      public Detection(float x, float y, float width, float height, int classId, float confidence)
      {
          this.X = x;
          this.Y = y;
          this.Width = width;
          this.Height = height;
          this.ClassId = classId;
          this.Confidence = confidence;
      }

      /// <summary>Horizontal position of the box.</summary>
      public float X { get; set; }

      /// <summary>Vertical position of the box.</summary>
      public float Y { get; set; }

      /// <summary>Box width.</summary>
      public float Width { get; set; }

      /// <summary>Box height.</summary>
      public float Height { get; set; }

      /// <summary>Identifier of the detected class.</summary>
      public int ClassId { get; set; }

      /// <summary>Confidence score of the detection.</summary>
      public float Confidence { get; set; }
  }
  275. }