備忘録

備忘録

C#でONNXファイルを利用して手書き数字を認識する方法

Ⅰ. はじめに

タイトルの通り「C#でONNXファイルを利用して手書き数字を認識する方法」です。

Ⅱ. やり方

1. WPFとして新規プロジェクトを作成する

.NET Core / .NET Framework どちらでもOKです。

2. 必要なパッケージをNuGetからインストールする
Install-Package Microsoft.ML.OnnxRuntime
Install-Package SixLabors.ImageSharp -pre
3. MNIST を学習済みの ONNX ファイルをダウンロードし、model.onx をC:\model.onx に保存する

https://github.com/onnx/models/tree/master/mnist

4. サンプルプログラムを書く

MainWindow.xaml

<Window Height="160" Width="210">
  <Grid>
    <InkCanvas x:Name="inkCanvas1" StrokeCollected="InkCanvas1_StrokeCollected" Width="100" Height="100" Margin="10,11,0,0" Background="#FFA0A0A0" HorizontalAlignment="Left" VerticalAlignment="Top" />
    <TextBlock x:Name="textBlock1" FontSize="50" Margin="115,11,0,0" TextWrapping="Wrap" Text="0" VerticalAlignment="Top" HorizontalAlignment="Left" Width="27"/>
    <Button Content="Clear" Click="Button_Click" HorizontalAlignment="Left" Margin="115,83,0,0" VerticalAlignment="Top" Width="73" Height="28"/>
  </Grid>
</Window>

MainWindow.xaml.cs

namespace MNISTTest
{
  public partial class MainWindow : Window
  {
    const int ImgWidth = 28;
    const int ImgHeight = 28;

    public MainWindow()
    {
      InitializeComponent();
    }

    static float[] GetFloatArrayFromImage(byte[] imgBytes)
    {
      var floats = new float[ImgWidth * ImgHeight];

      using (var img = SixLabors.ImageSharp.Image.Load(imgBytes))
      {
        // img.Save(File.Create("out.jpg"), new JpegEncoder() { Quality = 80 });
        img.Mutate(x =>
          x.Resize(ImgWidth, ImgHeight)
          .BinaryThreshold(0.5f)); // 2値化する
        
        for (var x = 0; x < img.Width; x++)
        {
          for (var y = 0; y < img.Height; y++)
          {
            floats[x + y * img.Width] = (img[x, y].R == 255) ? 0 : 1;
          }
        }
      }

      return floats;
    }

    private byte[] GetImageBytes()
    {
      var bounds = VisualTreeHelper.GetDescendantBounds(inkCanvas1);
      var width = (int)inkCanvas1.Width;
      var height = (int)inkCanvas1.Height;
      var rtb = new RenderTargetBitmap(width, height, 96d, 96d, PixelFormats.Pbgra32);
      var drawingVisual = new DrawingVisual();

      using (var ctx = drawingVisual.RenderOpen())
      {
        var vb = new VisualBrush(inkCanvas1);
        ctx.DrawRectangle(vb, null, new Rect(new Point(bounds.X, bounds.Y), new Point(width, height)));
      }

      rtb.Render(drawingVisual);
      var encoder = new BmpBitmapEncoder();
      encoder.Frames.Add(BitmapFrame.Create(rtb));
      using (var ms = new MemoryStream())
      {
        encoder.Save(ms);
        return ms.ToArray();
      }
    }

    private int Check(float[] dimensions)
    {
      using (var session = new InferenceSession("C:\\model.onnx"))
      {
        var inputNodeName = session.InputMetadata.First().Key;
        var innodedims = session.InputMetadata.First().Value.Dimensions;
        var inputTensor = new DenseTensor<float>(dimensions, innodedims);

        var namedOnnxValues = new List<NamedOnnxValue>
        {
          NamedOnnxValue.CreateFromTensor(inputNodeName, inputTensor)
        };

        using (var results = session.Run(namedOnnxValues))
        {
          var resultScores = results.First().AsTensor<float>().ToArray();
          return Array.IndexOf(resultScores, resultScores.Max());
        }
      }
    }

    private void InkCanvas1_StrokeCollected(object sender, InkCanvasStrokeCollectedEventArgs e)
    {
      var imgBytes = GetImageBytes();
      var dimensions = GetFloatArrayFromImage(imgBytes);
      var result = Check(dimensions);

      textBlock1.Text = result.ToString();
    }

    private void Button_Click(object sender, RoutedEventArgs e)
    {
      inkCanvas1.Strokes.Clear();
    }
  }
}

実行結果

f:id:kagasu:20190521161401g:plain