您当前的位置: 首页 >  深度学习

寒冰屋

暂无认证

  • 0浏览

    0关注

    2286博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

C#中的深度学习:预处理硬币检测数据集

寒冰屋 发布时间:2020-12-15 21:18:18 ,浏览量:0

在这里,我们将预处理硬币数据集,以供以后在监督学习模型中进行训练。在机器学习中预处理数据集通常涉及以下任务:

  • 下载源1.5 MB
  • 清理数据——通过平均周围数据的值或使用其他策略来填补丢失或损坏的数据留下的漏洞。
  • 标准化数据——将值缩放到标准范围内,通常为0到1。具有广泛范围值的数据可能会导致不规则性,因此我们将所有内容都置于一个公共范围内。
  • 一个“热编码”标签——将数据集中的对象的标签或类编码为二进制N维向量,其中N是类的总数。除了与对象类对应的元素设置为1之外,所有数组元素都设置为0。这意味着在每个数组中只有一个元素的值为1。
  • 将输入数据集分为训练集和验证集——训练集用于训练模型,验证集用于通过针对未得到训练的子集评估结果模型(训练后)来检查训练的准确性。在原始数据集上进行训练。

在此示例中,我们将使用Numpy.NET,它基本上是Python中流行的Numpy库的.NET版本。Numpy是一个致力于处理矩阵的库。

为了实现我们的数据集处理器,我们在PreProcessing文件夹中创建Utils和DataSet类。Utils类包含一个静态Normalize方法,如下所示:

public class Utils
   {
       public static NDarray Normalize(string path)
       {
           var colorMode = Settings.Channels == 3 ? "rgb" : "grayscale";
           var img = ImageUtil.LoadImg(path, color_mode: colorMode, target_size: (Settings.ImgWidth, Settings.ImgHeight));
           return ImageUtil.ImageToArray(img) / 255;
       }

   }

在此方法中,我们加载具有给定颜色模式(RGB或灰度)的图像,并将其大小调整为给定的宽度和高度。然后,我们返回包含图像的矩阵,其中每个元素都除以255。将每个元素除以255会将它们归一化,因为图像中任何像素的值都在0到255之间,因此通过将它们除以255,我们可以确保新范围是0到1(含)。

我们还在代码中使用了一个Settings类。此类包含用于在整个应用程序中使用的许多参数的可能值的常量。另一个类DataSet表示我们将用于训练机器学习模型的数据集。这里我们有以下几个领域:

  • _pathToFolder ——包含图像的文件夹的路径。
  • _extList ——要考虑的文件扩展名列表。
  • _labels——_pathToFolder中的图像标签或类别。
  • _objs——图片本身,以表示Numpy.NDarray。
  • _validationSplit ——用于将图像总数划分为验证集和训练集的百分比,在这种情况下,该百分比将定义相对于图像总数的验证集的大小。
  • NumberClasses ——数据集中唯一类的总数。
  • TrainX——训练数据,以表示Numpy.NDarray。
  • TrainY——训练标签,以表示Numpy.NDarray。
  • ValidationX——验证数据,以表示Numpy.NDarray。
  • ValidationY——验证标签,以表示Numpy.NDarray。

这是DataSet类:

public class DataSet
    {
        private string _pathToFolder;
        private string[] _extList;
        private List _labels;
        private List _objs;
        private double _validationSplit;
        public int NumberClasses { get; set; }
        public NDarray TrainX { get; set; }
        public NDarray ValidationX { get; set; }
        public NDarray TrainY { get; set; }
        public NDarray ValidationY { get; set; }

        public DataSet(string pathToFolder, string[] extList, int numberClasses, double validationSplit)
        {
            _pathToFolder = pathToFolder;
            _extList = extList;
            NumberClasses = numberClasses;
            _labels = new List();
            _objs = new List();
            _validationSplit = validationSplit;
        }

        public void LoadDataSet()
        {
            // Process the list of files found in the directory.
            string[] fileEntries = Directory.GetFiles(_pathToFolder);
            foreach (string fileName in fileEntries)
                if (IsRequiredExtFile(fileName))
                    ProcessFile(fileName);

            MapToClassRange();
            GetTrainValidationData();
        }

        private bool IsRequiredExtFile(string fileName)
        {
            foreach (var ext in _extList)
            {
                if (fileName.Contains("." + ext))
                {
                    return true;
                }
            }

            return false;
        }

        private void MapToClassRange()
        {
            HashSet uniqueLabels = _labels.ToHashSet();
            var uniqueLabelList = uniqueLabels.ToList();
            uniqueLabelList.Sort();

            _labels = _labels.Select(x => uniqueLabelList.IndexOf(x)).ToList();
        }

        private NDarray OneHotEncoding(List labels)
        {
            var npLabels = np.array(labels.ToArray()).reshape(-1);
            return Util.ToCategorical(npLabels, num_classes: NumberClasses);
        }

        private void ProcessFile(string path)
        {
            _objs.Add(Utils.Normalize(path));
            ProcessLabel(Path.GetFileName(path));
        }

        private void ProcessLabel(string filename)
        {
            _labels.Add(int.Parse(ExtractClassFromFileName(filename)));
        }

        private string ExtractClassFromFileName(string filename)
        {
            return filename.Split('_')[0].Replace("class", "");
        }

        private void GetTrainValidationData()
        {
            var listIndices = Enumerable.Range(0, _labels.Count).ToList();
            var toValidate = _objs.Count * _validationSplit;
            var random = new Random();
            var xValResult = new List();
            var yValResult = new List();
            var xTrainResult = new List();
            var yTrainResult = new List();

            // Split validation data
            for (var i = 0; i < toValidate; i++)
            {
                var randomIndex = random.Next(0, listIndices.Count);
                var indexVal = listIndices[randomIndex];
                xValResult.Add(_objs[indexVal]);
                yValResult.Add(_labels[indexVal]);
                listIndices.RemoveAt(randomIndex);
            }

            // Split rest (training data)
            listIndices.ForEach(indexVal => 
            { 
                xTrainResult.Add(_objs[indexVal]);
                yTrainResult.Add(_labels[indexVal]);
            });

            TrainY = OneHotEncoding(yTrainResult);
            ValidationY = OneHotEncoding(yValResult);
            TrainX = np.array(xTrainResult);
            ValidationX = np.array(xValResult);
        }
    }

这是每种方法的说明:

  • LoadDataSet()——我们调用该类的主要方法以将数据集加载到_pathToFolder中。它调用下面列出的其他方法来执行此操作。
  • IsRequiredExtFile(filename)——检查给定文件是否包含至少一个应为此数据集处理的扩展名(列于_extList中)。
  • MapToClassRange() ——获取数据集中唯一标签的列表。
  • ProcessFile(path)——使用Utils.Normalize方法对图像进行规范化并调用ProcessLabel方法。
  • ProcessLabel(filename)——将ExtractClassFromFileName方法的结果添加为标签。
  • ExtractClassFromFileName(filename)——从图像的文件名中提取类。
  • GetTrainValidationData() ——将数据集分为训练和验证子数据集。

在本系列中,我们将使用https://cvl.tuwien.ac.at/research/cvl-databases/coin-image-dataset/上的硬币图像数据集。

要加载数据集,我们可以在控制台应用程序的主类中包括以下内容:

var numberClasses = 60;
var fileExt = new string[] { ".png" };
var dataSetFilePath = @"C:/Users/arnal/Downloads/coin_dataset";
var dataSet = new PreProcessing.DataSet(dataSetFilePath, fileExt, numberClasses, 0.2);
dataSet.LoadDataSet();

现在,我们的数据可以输入到机器学习模型中。在接下来的文章会介绍监督机器学习的基础知识,以及训练和验证阶段包括哪些内容。它是为没有AI经验的读者准备的。

关注
打赏
1665926880
查看更多评论
立即登录/注册

微信扫码登录

0.0975s