目录

  1. 概述
  2. 数据文件名
  3. 数据集格式转换
    1. 读取数据集
    2. 保存数据集为 csv 文件
    3. 从 csv 文件中导入
  4. 总结
  5. 附录

概述

❓ MNIST 数据集是什么

MNIST 数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含 60,000 个示例的训练集以及 10,000 个示例的测试集

📚 这是一个公开数据集,可以通过访问官网进行下载

数据文件名

官网中存在四个链接文件,各个文件具体含义如下:

文件名含义
train-images-idx3-ubyte.gzTraining set images
train-labels-idx1-ubyte.gzTraining set labels
t10k-images-idx3-ubyte.gzTest set images
t10k-labels-idx1-ubyte.gzTest set labels

数据集格式转换

考虑到数据集的大小,图片最终都是压缩存储的。因此,如果想要获取真正的图片数据,还需要一些读取操作。具体的操作步骤如下:

读取数据集

下面的函数就能够将官网下载下来的数据进行读取:

1
2
3
4
5
6
7
8
9
10
11
def load_mnist(path, kind="train"):
# Load MNIST data from `path`
images_path = os.path.join(path, "%s-images.idx3-ubyte" % kind)
labels_path = os.path.join(path, "%s-labels.idx1-ubyte" % kind)
with open(labels_path, 'rb') as lb:
magic, n = struct.unpack('>II', lb.read(8))
labels = np.fromfile(lb, dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
return images, labels

保存数据集为 csv 文件

下面的代码主要进行持久化工作,能够将读取到文件持久化保存到本地

1
2
3
4
5
def saveFile(images, labels, test_images, test_label):
np.savetxt("../train_img.csv", images, fmt="%i", delimiter=',')
np.savetxt("../train_labels.csv", labels, fmt="%i", delimiter=',')
np.savetxt("../test_img.csv", test_images, fmt="%i", delimiter=',')
np.savetxt("../test_labels.csv", test_label, fmt="%i", delimiter=',')

从 csv 文件中导入

一旦将图片数据持久化到本地之后,后续就不再需要进行解压重新读取数据的操作了,只需要读取当时持久化保存的文件即可,具体的代码如下:

1
2
3
4
5
6
7
8
X_train = np.genfromtxt('train_img.csv',
dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',
dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',
dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',
dtype=int, delimiter=',')

👴 从 CSV 文件中加载 MNIST 数据将会显著花费更长的时间, 因此如果可能的话, 还是建议维持数据集原有的字节格式

总结

✨ 如今,即使是简单的模型也能达到 95% 以上的分类准确率,因此不适合区分强模型和弱模型,所以 MNIST 更像是⼀个健全检查,而不是⼀个基准。

附录

MNIST 数据集
详解 MNIST 数据集