Ⅰ. はじめに
タイトルの通り「C#でONNXファイルを利用して手書き数字を認識する方法」です。
Ⅱ. やり方
1. WPF アプリケーションのプロジェクト（名前空間: MNISTTest）を作成する
2. 必要なパッケージを NuGet からインストールする（パッケージ マネージャー コンソールで以下を実行）
Install-Package Microsoft.ML.OnnxRuntime
Install-Package SixLabors.ImageSharp -pre
3. MNIST で学習済みの ONNX モデル（model.onnx）をダウンロードし、C:\model.onnx として保存する
4. サンプルプログラムを書く
MainWindow.xaml
<Window x:Class="MNISTTest.MainWindow"
        xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
        xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
        Height="160" Width="210">
    <!-- x:Class links this markup to the MainWindow code-behind; the xmlns/xmlns:x
         declarations are required for the WPF elements and the x:Name attributes below. -->
    <Grid>
        <!-- 100x100 drawing area; each finished stroke raises StrokeCollected, which runs recognition. -->
        <InkCanvas x:Name="inkCanvas1"
                   StrokeCollected="InkCanvas1_StrokeCollected"
                   Width="100" Height="100"
                   Margin="10,11,0,0"
                   Background="#FFA0A0A0"
                   HorizontalAlignment="Left" VerticalAlignment="Top" />
        <!-- Displays the recognized digit. -->
        <TextBlock x:Name="textBlock1"
                   FontSize="50"
                   Margin="115,11,0,0"
                   TextWrapping="Wrap"
                   Text="0"
                   VerticalAlignment="Top" HorizontalAlignment="Left"
                   Width="27"/>
        <!-- Clears the strokes drawn on inkCanvas1. -->
        <Button Content="Clear"
                Click="Button_Click"
                HorizontalAlignment="Left"
                Margin="115,83,0,0"
                VerticalAlignment="Top"
                Width="73" Height="28"/>
    </Grid>
</Window>
MainWindow.xaml.cs
namespace MNISTTest
{
    // NOTE(review): this snippet shows no using directives. It needs at least:
    //   System, System.Collections.Generic, System.IO, System.Linq, System.Windows,
    //   System.Windows.Controls, System.Windows.Media, System.Windows.Media.Imaging,
    //   Microsoft.ML.OnnxRuntime, Microsoft.ML.OnnxRuntime.Tensors,
    //   SixLabors.ImageSharp.Processing
    // Avoid importing SixLabors.ImageSharp itself: its Image/Point types clash with the WPF
    // ones, which is why ImageSharp types are fully qualified below.

    /// <summary>
    /// Window that lets the user draw a digit on <c>inkCanvas1</c> and shows the digit
    /// predicted by an ONNX model each time a stroke is finished.
    /// </summary>
    public partial class MainWindow : Window
    {
        /// <summary>Width in pixels of the image fed to the model.</summary>
        const int ImgWidth = 28;

        /// <summary>Height in pixels of the image fed to the model.</summary>
        const int ImgHeight = 28;

        /// <summary>Location of the ONNX model file.</summary>
        const string ModelPath = "C:\\model.onnx";

        // Loaded on first prediction and reused afterwards: constructing an InferenceSession
        // loads the whole model, which previously happened on every single stroke.
        private InferenceSession _session;

        public MainWindow()
        {
            InitializeComponent();
        }

        /// <summary>
        /// The inference session for <see cref="ModelPath"/>, created on first access.
        /// </summary>
        private InferenceSession Session => _session ?? (_session = new InferenceSession(ModelPath));

        /// <summary>Releases the ONNX Runtime session when the window is closed.</summary>
        protected override void OnClosed(EventArgs e)
        {
            _session?.Dispose();
            _session = null;
            base.OnClosed(e);
        }

        /// <summary>
        /// Converts an encoded image into the flat pixel buffer passed to the model.
        /// </summary>
        /// <param name="imgBytes">Encoded image bytes (this class passes a BMP).</param>
        /// <returns>
        /// <c>ImgWidth * ImgHeight</c> values in row-major order: 0 for pixels whose red channel
        /// is 255 after binarization (white), 1 for all others (black).
        /// </returns>
        static float[] GetFloatArrayFromImage(byte[] imgBytes)
        {
            var floats = new float[ImgWidth * ImgHeight];

            // Load as Rgba32 explicitly: released ImageSharp versions return a non-generic Image
            // from Image.Load(byte[]), which has no per-pixel indexer.
            using (var img = SixLabors.ImageSharp.Image.Load<SixLabors.ImageSharp.PixelFormats.Rgba32>(imgBytes))
            {
                // Shrink to the model input size, then binarize so each pixel is pure black or white.
                img.Mutate(ctx => ctx.Resize(ImgWidth, ImgHeight).BinaryThreshold(0.5f));

                // Rows in the outer loop so writes into the row-major buffer are sequential.
                for (var y = 0; y < img.Height; y++)
                {
                    for (var x = 0; x < img.Width; x++)
                    {
                        floats[x + y * img.Width] = img[x, y].R == 255 ? 0f : 1f;
                    }
                }
            }

            return floats;
        }

        /// <summary>
        /// Renders <c>inkCanvas1</c> (background and strokes) at 96 DPI into a
        /// <c>Width</c> x <c>Height</c> bitmap and returns it BMP-encoded.
        /// </summary>
        /// <remarks>
        /// Assumes <c>inkCanvas1.Width</c>/<c>Height</c> are set explicitly (they are NaN otherwise);
        /// the XAML sets both to 100.
        /// </remarks>
        private byte[] GetImageBytes()
        {
            var bounds = VisualTreeHelper.GetDescendantBounds(inkCanvas1);
            var width = (int)inkCanvas1.Width;
            var height = (int)inkCanvas1.Height;

            var rtb = new RenderTargetBitmap(width, height, 96d, 96d, PixelFormats.Pbgra32);
            var drawingVisual = new DrawingVisual();
            using (var ctx = drawingVisual.RenderOpen())
            {
                // Painting through a VisualBrush presumably keeps the canvas's Margin/offset in the
                // window from shifting the captured image (rendering inkCanvas1 directly would).
                var vb = new VisualBrush(inkCanvas1);
                ctx.DrawRectangle(vb, null, new Rect(new Point(bounds.X, bounds.Y), new Point(width, height)));
            }
            rtb.Render(drawingVisual);

            var encoder = new BmpBitmapEncoder();
            encoder.Frames.Add(BitmapFrame.Create(rtb));
            using (var ms = new MemoryStream())
            {
                encoder.Save(ms);
                return ms.ToArray();
            }
        }

        /// <summary>
        /// Runs the model on a preprocessed image and returns the index of the highest score.
        /// </summary>
        /// <param name="pixels">Pixel buffer produced by <see cref="GetFloatArrayFromImage"/>.</param>
        /// <returns>
        /// Index of the largest value in the model's first output; for a 10-class MNIST model this
        /// is the predicted digit.
        /// </returns>
        private int Check(float[] pixels)
        {
            var input = Session.InputMetadata.First();

            // Symbolic dimensions (e.g. a variable batch size) are reported as -1; a single image
            // is fed, so treat them as 1. Fixed dimensions pass through unchanged.
            var inputDims = input.Value.Dimensions.Select(d => d < 0 ? 1 : d).ToArray();
            var inputTensor = new DenseTensor<float>(pixels, inputDims);
            var namedOnnxValues = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor(input.Key, inputTensor)
            };

            using (var results = Session.Run(namedOnnxValues))
            {
                var resultScores = results.First().AsTensor<float>().ToArray();
                return Array.IndexOf(resultScores, resultScores.Max());
            }
        }

        /// <summary>
        /// Re-classifies the drawing each time a stroke is completed and shows the result.
        /// </summary>
        private void InkCanvas1_StrokeCollected(object sender, InkCanvasStrokeCollectedEventArgs e)
        {
            var imgBytes = GetImageBytes();
            var pixels = GetFloatArrayFromImage(imgBytes);
            var result = Check(pixels);
            textBlock1.Text = result.ToString();
        }

        /// <summary>Removes every stroke from the canvas.</summary>
        private void Button_Click(object sender, RoutedEventArgs e)
        {
            inkCanvas1.Strokes.Clear();
        }
    }
}
実行結果