pythonproject/image_classification_dataset.ipynb

973 lines
122 KiB
Plaintext
Raw Permalink Normal View History

2024-06-25 14:15:07 +08:00
{
"cells": [
{
"metadata": {},
"cell_type": "raw",
"source": "MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。",
"id": "58ac648c45d06f50"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:25.877309Z",
"start_time": "2024-05-23T13:55:25.812732Z"
}
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"from d2l import tensorflow as d2l\n",
"from IPython import display\n",
"d2l.use_svg_display()"
],
"id": "4f19e5d16d0a7341",
"outputs": [],
"execution_count": 73
},
{
"metadata": {},
"cell_type": "raw",
"source": [
"3.5.1. 读取数据集\n",
"我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。"
],
"id": "c67a3433075cb8ed"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:32.337258Z",
"start_time": "2024-05-23T13:55:31.440906Z"
}
},
"cell_type": "code",
"source": "mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()",
"id": "e5defd40322a3af2",
"outputs": [],
"execution_count": 74
},
{
"metadata": {},
"cell_type": "raw",
"source": "Fashion-MNIST由10个类别的图像组成 每个类别由训练数据集train dataset中的6000张图像 和测试数据集test dataset中的1000张图像组成。 因此训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。",
"id": "9f2c0b217b6b5e30"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:34.042764Z",
"start_time": "2024-05-23T13:55:34.012956Z"
}
},
"cell_type": "code",
"source": "len(mnist_train[0]), len(mnist_test[0])",
"id": "6129edd7e9c7e9fc",
"outputs": [
{
"data": {
"text/plain": [
"(60000, 10000)"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 75
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成其通道数为1。 为了简洁起见,本书将高度\n",
"h像素、宽度w像素图像的形状记hxw或h,w。"
],
"id": "719bd9e11b3b06a9"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:36.376606Z",
"start_time": "2024-05-23T13:55:36.366558Z"
}
},
"cell_type": "code",
"source": "mnist_train[0][0].shape",
"id": "3ed7450fafac4f36",
"outputs": [
{
"data": {
"text/plain": [
"(28, 28)"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 76
},
{
"metadata": {},
"cell_type": "raw",
"source": "Fashion-MNIST中包含的10个类别分别为t-shirtT恤、trouser裤子、pullover套衫、dress连衣裙、coat外套、sandal凉鞋、shirt衬衫、sneaker运动鞋、bag和ankle boot短靴。 以下函数用于在数字标签索引及其文本名称之间进行转换。",
"id": "8003fdd00380a19b"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:38.313353Z",
"start_time": "2024-05-23T13:55:38.302321Z"
}
},
"cell_type": "code",
"source": [
"def get_fashion_mnist_labels(labels): #@save\n",
" \"\"\"返回Fashion-MNIST数据集的文本标签\"\"\"\n",
" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
" 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
" return [text_labels[int(i)] for i in labels]"
],
"id": "89d7939efe1b7dda",
"outputs": [],
"execution_count": 77
},
{
"metadata": {},
"cell_type": "markdown",
"source": "我们现在可以创建一个函数来可视化这些样本。",
"id": "b197821aad4ea3d9"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:39.686787Z",
"start_time": "2024-05-23T13:55:39.674485Z"
}
},
"cell_type": "code",
"source": [
"def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save\n",
" \"\"\"绘制图像列表\"\"\"\n",
" figsize = (num_cols * scale, num_rows * scale)\n",
" _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)\n",
" axes = axes.flatten()\n",
" for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
" ax.imshow(img.numpy())\n",
" ax.axes.get_xaxis().set_visible(False)\n",
" ax.axes.get_yaxis().set_visible(False)\n",
" if titles:\n",
" ax.set_title(titles[i])\n",
" return axes"
],
"id": "a8769deec41f021d",
"outputs": [],
"execution_count": 78
},
{
"metadata": {},
"cell_type": "markdown",
"source": "以下是训练数据集中前几个样本的图像及其相应的标签。",
"id": "a982d30eb1488615"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:42.522468Z",
"start_time": "2024-05-23T13:55:41.174261Z"
}
},
"cell_type": "code",
"source": [
"X = tf.constant(mnist_train[0][:18])\n",
"y = tf.constant(mnist_train[1][:18])\n",
"show_images(X, 2, 9, titles=get_fashion_mnist_labels(y));"
],
"id": "92f6aa86ebd139f1",
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1350x300 with 18 Axes>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"767.7pt\" height=\"191.304163pt\" viewBox=\"0 0 767.7 191.304163\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n <metadata>\n <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n <cc:Work>\n <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n <dc:date>2024-05-23T21:55:42.349715</dc:date>\n <dc:format>image/svg+xml</dc:format>\n <dc:creator>\n <cc:Agent>\n <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n </cc:Agent>\n </dc:creator>\n </cc:Work>\n </rdf:RDF>\n </metadata>\n <defs>\n <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n </defs>\n <g id=\"figure_1\">\n <g id=\"patch_1\">\n <path d=\"M 0 191.304163 \nL 767.7 191.304163 \nL 767.7 0 \nL 0 0 \nz\n\" style=\"fill: #ffffff\"/>\n </g>\n <g id=\"axes_1\">\n <g id=\"patch_2\">\n <path d=\"M 7.2 93.384163 \nL 78.266038 93.384163 \nL 78.266038 22.318125 \nL 7.2 22.318125 \nz\n\" style=\"fill: #ffffff\"/>\n </g>\n <g clip-path=\"url(#p27b7da0425)\">\n <image xlink:href=\"data:image/png;base64,\niVBORw0KGgoAAAANSUhEUgAAAGMAAABjCAYAAACPO76VAAAJE0lEQVR4nO1dbYxUVxk+986987GzMzvsLuyy3YWFsghUoGgXSrWglEIlNWjAj4SY9IeJxmiM8SvRmuhPTIx/xNRY6y+aRpvUqqlaZQMRIkJcipS1C2VZELaws9/DztedueMPk/Oe56wzaRrCvsT3+fWcee7OvTvvnPe8533POePsdg7WlIAF3MV+AAFBjMEIYgxGEGMwghiDEcQYjODdi5s4/Rs1L3QmQItNljTPrUQtfWVec7dUAS08/6+7+YgsID2DEcQYjCDGYARPuRF8Jay+qz+cP7AN2tu+c5Z46gpo66KnNfdVCFrcoXbSdUAr1ihTY39rThZ6NK9a6sD0emgHNdJv59Og+ZH6/29Yo+cpVHzQZgtxzSMuZpSKx9s1bxsKQIu9dlbVg/QMRhBjMILTKGs7e+hRaK/+0rDmWzNXQRucW6H59VwraEFINvdddFNNflnzeAS7dNRwIa7CxwwVuZBkpAxa0itBO+0VNU9FiqC5Dj6PiYhxzzOzvXWvS1n3qxhucXsLuuwXrj6mecu+t/FZ6t5BcM8hxmAEMQYjLBgzRg5v1/x7+38NF5+YWaf5jfkMaHPlmOZLm+ZB626a0bzVR63Fy2sedzDlMVul9EiTi+NC1RgzbpVaQCuEUWgHIYXvpRBD+YQxThWqGL5m/ILmuUoctFFjXBybwvs3N9G4lEngGPVU50XNj/5sL2jSMxhBjMEIXrhzC7zwqT2nNH/59iOgdcRzmj/deQG04Xyn5mMFnOXOBdTFA8tN3CxmNF8Wy4G2IjapecotgBZ1KOzt8mdAeyg6Bu3rlSWa36pkQBvKd9H7xGZBuzBHWr6Crs91yLu3t9wBrSVGrqm/9Rpo+Sq588IyDNelZzCCGIMRxBiM4OV6YvCC6Qt3tb8F2kSQ0vyNXA9o3YlpzVcnsqCtid3S3PbZf8w+pLmd1phw6X7DQSdofYlxzfc2D4H2o9u7of3J1kHNP568BNrjCUpJDJXxHitjE5rPVJtAK4UUBq+KjYMW1KiAGrGy1J0ejUt/Pv84aNIzGEGMwQhevgPt8XT6Dc1fnMKsbU98SvNNrddBa4tgeGci5VKotzMxCVrSpYzncWOGr5RSHT516e7oFGgfaSL38szXvg5aJY5FqpO9FL5XkhhOpjfT83xlzQBocSf4n1wppVJGJjhiZZQjxrURKytsZg7SF/CzkJ7BCGIMRhBjMILzpP9ZcHibz1Lm9InURbg4X6MwuBhihnMsWKLqIeaSDx0P0nWvs7E2TiFxfxzHqEPf/4bmU09gZvTKrl9C+1iBUjDZCt7/NxM0ngxex3D90V6qZm5M3QRttkKhrl09NEPdjJsHrVijz+1I31rQpGcwghiDERYUl0of69d82XdH4OKH0zc035DAbmu6rbiLYeBQ4QHN81bhZ3l0RnOz6yuFYeF4OQXaH65u0PzY1udAe3bsKWivSFBY/P7EDdAONM+pengpR653dRRn2SPlZZrbLtoMyXt9zEb0GQWrz/V8CDTpGYwgxmAEMQYjNFzE1gjecsxwBqs6NJ9aj74/30kpgIf34b6KZzpOap6tWutgjQUKuSru3eg0qnsDsxtAa7YWlbVEyE9/IDEK2kxIz9rlTYP27bcPat7RhFXI51e+pnlQw5THcEBTgJS1kOKv+TWav7JhKWjSMxhBjMEIC9yU41FhpFapLPiDu43C/q2a//sTuDz/0JYzmu9oxkLXmfyDmptuSCmllnoYrpqz3rEyhqFmGN7qYeY5E6HZc7WG39t5I0TPh1ig6zRC20eMwppSSu0580XNew6+CZr0DEYQYzCCGIMRFmw9hnHCwYqZ45HvdSKWHV1qh0UMLRttTUu8SuPC2ldRO6so2/rpUQwttzSNan4ryIDmO3g/cw9GdxSra+aYEVrjwriR4bUrmQ/4FAZfLmGYby5e6PaaQUv9FtM68Jx1FcE9hxiDEcQYjND4uIoaZkpqQdngd+cBHJ/idfP9bXz1C1+G9o+fO6K5r3CMiFpjRrlGY09vbAa0rJFmGSysAi1lzV9AM1a82GOUOSc5PNkHWqkFx2ET0jMYQYzBCPfkVJ1GMF2TG8etWmGRXMHERqwQthvbvy4FmCU217r+t01u6ngOT0+4Y+yXeCyFW4HNENXOGpuL0ezK5q4E7cnYefSboKX20Prd8uV+0KRnMIIYgxHEGIyw6GOGmXJplLLv+QVWCAc+36t5xMEQPFvBlIOZCn/H2qZ8cZJSGfsy/wRtpEQrQLqiWAU0w2d7dchgmU7V2bN7EDRz8fjrwQ7QpGcwghiDERbfTRmz/EZuqjqNbsI8reEz7X8HzQxllcKw1NxXoZRSrQlyYW+VloPmu/Q8UxXMvprv2ePj3pGZalLzb3UcA+3I5Ic19479AzTpGYwgxmAEMQYjLP6Y8R6RLZEPN1d/KKVU2UqHdBlVuffF8CiLtjY65SdbxZDYfN+yNQ6Zoa29V8Xc/4eKUqdur9Y8qXBhufQMRhBjMMLiuymnfrEFYBW6+pppv4S9WMBum1uD52uY/R0N2lU9mDN5OxP8/GU6+PHOCM7q3Qr9T09+9BxoP+ijVRc/VBtBk57BCGIMRhBjMMLijxm1BttD
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 79
},
{
"metadata": {},
"cell_type": "raw",
"source": [
"读取小批量\n",
"为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下在每次迭代中数据加载器每次都会读取一小批量数据大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。"
],
"id": "a55d50ef614248ce"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:44.833747Z",
"start_time": "2024-05-23T13:55:44.666932Z"
}
},
"cell_type": "code",
"source": [
"batch_size = 256\n",
"train_iter = tf.data.Dataset.from_tensor_slices(\n",
" mnist_train).batch(batch_size).shuffle(len(mnist_train[0]))"
],
"id": "6e7eaf509e5620ae",
"outputs": [],
"execution_count": 80
},
{
"metadata": {},
"cell_type": "markdown",
"source": "我们看一下读取训练数据所需的时间。",
"id": "153d0ad0514e4c3"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:46.690118Z",
"start_time": "2024-05-23T13:55:45.834145Z"
}
},
"cell_type": "code",
"source": [
"timer = d2l.Timer()\n",
"for X, y in train_iter:\n",
" continue\n",
"f'{timer.stop():.2f} sec'"
],
"id": "8449183c346bd7fd",
"outputs": [
{
"data": {
"text/plain": [
"'0.85 sec'"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 81
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"整合所有组件\n",
"现在我们定义load_data_fashion_mnist函数用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外这个函数还接受一个可选参数resize用来将图像大小调整为另一种形状。"
],
"id": "f4548b392f4e0b10"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:48.027866Z",
"start_time": "2024-05-23T13:55:48.017949Z"
}
},
"cell_type": "code",
"source": [
"def load_data_fashion_mnist(batch_size, resize=None): #@save\n",
" \"\"\"下载Fashion-MNIST数据集然后将其加载到内存中\"\"\"\n",
" mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()\n",
" # 将所有数字除以255使所有像素值介于0和1之间在最后添加一个批处理维度\n",
" # 并将标签转换为int32。\n",
" process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,\n",
" tf.cast(y, dtype='int32'))\n",
" resize_fn = lambda X, y: (\n",
" tf.image.resize_with_pad(X, resize, resize) if resize else X, y)\n",
" return (\n",
" tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(\n",
" batch_size).shuffle(len(mnist_train[0])).map(resize_fn),\n",
" tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(\n",
" batch_size).map(resize_fn))"
],
"id": "66f0bc77212a9421",
"outputs": [],
"execution_count": 82
},
{
"metadata": {},
"cell_type": "markdown",
"source": "下面我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。",
"id": "20cd2d13d33077d4"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:52.415863Z",
"start_time": "2024-05-23T13:55:49.893184Z"
}
},
"cell_type": "code",
"source": [
"train_iter, test_iter = load_data_fashion_mnist(256, resize=28)\n",
"for X, y in train_iter:\n",
" print(X.shape, X.dtype, y.shape, y.dtype)\n",
" break"
],
"id": "9eb474b9cd02b604",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(256, 28, 28, 1) <dtype: 'float32'> (256,) <dtype: 'int32'>\n"
]
}
],
"execution_count": 83
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"初始化模型参数\n",
"和之前线性回归的例子一样,这里的每个样本都将用固定长度的向量表示。 原始数据集中的每个样本都是28*28\n",
"的图像。 本节将展平每个图像把它们看作长度为784的向量。 在后面的章节中,我们将讨论能够利用图像空间结构的特征, 但现在我们暂时只把每个像素位置看作一个特征。\n",
"\n",
"回想一下在softmax回归中我们的输出与类别一样多。 因为我们的数据集有10个类别所以网络输出维度为10。 因此,权重将构成一个\n",
"784*10的矩阵 偏置将构成一个\n",
"1*10的行向量。 与线性回归一样我们将使用正态分布初始化我们的权重W偏置初始化为0。"
],
"id": "88be3897e31b3f86"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:54.766578Z",
"start_time": "2024-05-23T13:55:54.751775Z"
}
},
"cell_type": "code",
"source": [
"num_inputs = 784\n",
"num_outputs = 10\n",
"\n",
"W = tf.Variable(tf.random.normal(shape=(num_inputs, num_outputs),\n",
" mean=0, stddev=0.01))\n",
"b = tf.Variable(tf.zeros(num_outputs))"
],
"id": "30076248e5c261c5",
"outputs": [],
"execution_count": 84
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"定义softmax操作\n",
"在实现softmax回归模型之前我们简要回顾一下sum运算符如何沿着张量中的特定维度工作。 如 2.3.6节和 2.3.6.1节所述, 给定一个矩阵X我们可以对所有元素求和默认情况下。 也可以只求同一个轴上的元素即同一列轴0或同一行轴1。 如果X是一个形状为(2, 3)的张量,我们对列进行求和, 则结果将是一个具有形状(3,)的向量。 当调用sum运算符时我们可以指定保持在原始张量的轴数而不折叠求和的维度。 这将产生一个具有形状(1, 3)的二维张量。"
],
"id": "24404467fc764d51"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:56.277161Z",
"start_time": "2024-05-23T13:55:56.267009Z"
}
},
"cell_type": "code",
"source": [
"X = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n",
"tf.reduce_sum(X, 0, keepdims=True), tf.reduce_sum(X, 1, keepdims=True)"
],
"id": "8996e17eac4cd1aa",
"outputs": [
{
"data": {
"text/plain": [
"(<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[5., 7., 9.]], dtype=float32)>,\n",
" <tf.Tensor: shape=(2, 1), dtype=float32, numpy=\n",
" array([[ 6.],\n",
" [15.]], dtype=float32)>)"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 85
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"回想一下实现softmax由三个步骤组成\n",
"\n",
"对每个项求幂使用exp\n",
"\n",
"对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;\n",
"\n",
"将每一行除以其规范化常数确保结果的和为1。\n",
"分母或规范化常数,有时也称为配分函数(其对数称为对数-配分函数)。 该名称来自统计物理学中一个模拟粒子群分布的方程"
],
"id": "d6df09b9c449bf35"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:55:58.885965Z",
"start_time": "2024-05-23T13:55:58.880615Z"
}
},
"cell_type": "code",
"source": [
"def softmax(X):\n",
" X_exp = tf.exp(X)\n",
" partition = tf.reduce_sum(X_exp, 1, keepdims=True)\n",
" return X_exp / partition # 这里应用了广播机制"
],
"id": "2feba1ee37842658",
"outputs": [],
"execution_count": 86
},
{
"metadata": {},
"cell_type": "raw",
"source": "正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外依据概率原理每行总和为1。",
"id": "eca590ff7b061202"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:00.848361Z",
"start_time": "2024-05-23T13:56:00.835263Z"
}
},
"cell_type": "code",
"source": [
"X = tf.random.normal((2, 5), 0, 1)\n",
"X_prob = softmax(X)\n",
"X_prob, tf.reduce_sum(X_prob, 1)"
],
"id": "1abc865b77fc69ef",
"outputs": [
{
"data": {
"text/plain": [
"(<tf.Tensor: shape=(2, 5), dtype=float32, numpy=\n",
" array([[0.08787813, 0.18106402, 0.27193552, 0.3478207 , 0.11130168],\n",
" [0.44081137, 0.18976508, 0.16773126, 0.10921898, 0.09247335]],\n",
" dtype=float32)>,\n",
" <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>)"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 87
},
{
"metadata": {},
"cell_type": "raw",
"source": "注意,虽然这在数学上看起来是正确的,但我们在代码实现中有点草率。 矩阵中的非常大或非常小的元素可能造成数值上溢或下溢,但我们没有采取措施来防止这点。",
"id": "e89ab1dd4ac2b0f0"
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"定义模型\n",
"定义softmax操作后我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意将数据传递到模型之前我们使用reshape函数将每张原始图像展平为向量。"
],
"id": "39285267234028ba"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:03.856065Z",
"start_time": "2024-05-23T13:56:03.850533Z"
}
},
"cell_type": "code",
"source": [
"def net(X):\n",
" return softmax(tf.matmul(tf.reshape(X, (-1, W.shape[0])), W) + b)"
],
"id": "4297616ce7f0d9ec",
"outputs": [],
"execution_count": 88
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"接下来,我们实现 3.4节中引入的交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。\n",
"\n",
"回顾一下,交叉熵采用真实标签的预测概率的负对数似然。 这里我们不使用Python的for循环迭代预测这往往是低效的 而是通过一个运算符选择所有元素。 下面我们创建一个数据样本y_hat其中包含2个样本在3个类别的预测概率 以及它们对应的标签y。 有了y我们知道在第一个样本中第一类是正确的预测 而在第二个样本中,第三类是正确的预测。 然后使用y作为y_hat中概率的索引 我们选择第一个样本中第一个类的概率和第二个样本中第三个类的概率。"
],
"id": "e74a7617b82fe7cd"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:07.408723Z",
"start_time": "2024-05-23T13:56:07.385287Z"
}
},
"cell_type": "code",
"source": [
"y_hat = tf.constant([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])\n",
"y = tf.constant([0, 2])\n",
"tf.boolean_mask(y_hat, tf.one_hot(y, depth=y_hat.shape[-1]))"
],
"id": "b6a18397b5cb8b16",
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.1, 0.5], dtype=float32)>"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 89
},
{
"metadata": {},
"cell_type": "markdown",
"source": "现在我们只需一行代码就可以实现交叉熵损失函数。",
"id": "f0dc345648fa9239"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:09.228410Z",
"start_time": "2024-05-23T13:56:09.212464Z"
}
},
"cell_type": "code",
"source": [
"def cross_entropy(y_hat, y):\n",
" return -tf.math.log(tf.boolean_mask(\n",
" y_hat, tf.one_hot(y, depth=y_hat.shape[-1])))\n",
"\n",
"cross_entropy(y_hat, y)"
],
"id": "515fc3bb49eb7a6e",
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.3025851, 0.6931472], dtype=float32)>"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 90
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"分类精度\n",
"给定预测概率分布y_hat当我们必须输出硬预测hard prediction 我们通常选择预测概率最高的类。 许多应用都要求我们做出选择。如Gmail必须将电子邮件分类为“Primary主要邮件”、 “Social社交邮件”“Updates更新邮件”或“Forums论坛邮件”。 Gmail做分类时可能在内部估计概率但最终它必须在类中选择一个。\n",
"\n",
"当预测与标签分类y一致时即是正确的。 分类精度即正确预测数量与总预测数量之比。 虽然直接优化精度可能很困难(因为精度的计算不可导), 但精度通常是我们最关心的性能衡量标准,我们在训练分类器时几乎总会关注它。\n",
"\n",
"为了计算精度,我们执行以下操作。 首先如果y_hat是矩阵那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0和1的张量。 最后,我们求和会得到正确预测的数量。"
],
"id": "8591a222911a1d49"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:11.093755Z",
"start_time": "2024-05-23T13:56:11.086032Z"
}
},
"cell_type": "code",
"source": [
"def accuracy(y_hat, y): #@save\n",
" \"\"\"计算预测正确的数量\"\"\"\n",
" if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n",
" y_hat = tf.argmax(y_hat, axis=1)\n",
" cmp = tf.cast(y_hat, y.dtype) == y\n",
" return float(tf.reduce_sum(tf.cast(cmp, y.dtype)))"
],
"id": "9a4d9c33b0f605e9",
"outputs": [],
"execution_count": 91
},
{
"metadata": {},
"cell_type": "markdown",
"source": "我们将继续使用之前定义的变量y_hat和y分别作为预测的概率分布和标签。 可以看到第一个样本的预测类别是2该行的最大元素为0.6索引为2这与实际标签0不一致。 第二个样本的预测类别是2该行的最大元素为0.5索引为2这与实际标签2一致。 因此这两个样本的分类精度率为0.5。",
"id": "e1b3e1a7eac06018"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:12.750879Z",
"start_time": "2024-05-23T13:56:12.738448Z"
}
},
"cell_type": "code",
"source": "accuracy(y_hat, y) / len(y)",
"id": "a345cbf14b5a2e2e",
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 92
},
{
"metadata": {},
"cell_type": "markdown",
"source": "同样对于任意数据迭代器data_iter可访问的数据集 我们可以评估在任意模型net的精度。",
"id": "cbf4d36f92baba8"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:14.529734Z",
"start_time": "2024-05-23T13:56:14.524295Z"
}
},
"cell_type": "code",
"source": [
"def evaluate_accuracy(net, data_iter): #@save\n",
" \"\"\"计算在指定数据集上模型的精度\"\"\"\n",
" metric = Accumulator(2) # 正确预测数、预测总数\n",
" for X, y in data_iter:\n",
" metric.add(accuracy(net(X), y), d2l.size(y))\n",
" return metric[0] / metric[1]"
],
"id": "16498ba0b0cd9b7d",
"outputs": [],
"execution_count": 93
},
{
"metadata": {},
"cell_type": "markdown",
"source": "这里定义一个实用程序类Accumulator用于对多个变量进行累加。 在上面的evaluate_accuracy函数中 我们在Accumulator实例中创建了2个变量 分别用于存储正确预测的数量和预测的总数量。 当我们遍历数据集时,两者都将随着时间的推移而累加。",
"id": "a9cbccb99a341622"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:15.962946Z",
"start_time": "2024-05-23T13:56:15.955628Z"
}
},
"cell_type": "code",
"source": [
"class Accumulator: #@save\n",
" \"\"\"在n个变量上累加\"\"\"\n",
" def __init__(self, n):\n",
" self.data = [0.0] * n\n",
"\n",
" def add(self, *args):\n",
" self.data = [a + float(b) for a, b in zip(self.data, args)]\n",
"\n",
" def reset(self):\n",
" self.data = [0.0] * len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]"
],
"id": "3262f7c8ccf1a36a",
"outputs": [],
"execution_count": 94
},
{
"metadata": {},
"cell_type": "markdown",
"source": "由于我们使用随机权重初始化net模型 因此该模型的精度应接近于随机猜测。 例如在有10个类别情况下的精度为0.1。",
"id": "a27b1d7ccbe16630"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:18.214919Z",
"start_time": "2024-05-23T13:56:17.970672Z"
}
},
"cell_type": "code",
"source": "evaluate_accuracy(net, test_iter)",
"id": "1ee7f5cab9450434",
"outputs": [
{
"data": {
"text/plain": [
"0.0766"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 95
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"训练\n",
"在我们看过 3.2节中的线性回归实现, softmax回归的训练过程代码应该看起来非常眼熟。 在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意updater是更新模型参数的常用函数它接受批量大小作为参数。 它可以是d2l.sgd函数也可以是框架的内置优化函数。"
],
"id": "35eab5640406d171"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:20.124699Z",
"start_time": "2024-05-23T13:56:20.112917Z"
}
},
"cell_type": "code",
"source": [
"def train_epoch_ch3(net, train_iter, loss, updater): #@save\n",
" \"\"\"训练模型一个迭代周期定义见第3章\"\"\"\n",
" # 训练损失总和、训练准确度总和、样本数\n",
" metric = Accumulator(3)\n",
" for X, y in train_iter:\n",
" # 计算梯度并更新参数\n",
" with tf.GradientTape() as tape:\n",
" y_hat = net(X)\n",
" # Keras内置的损失接受的是标签预测这不同于用户在本书中的实现。\n",
" # 本书的实现接受(预测,标签),例如我们上面实现的“交叉熵”\n",
" if isinstance(loss, tf.keras.losses.Loss):\n",
" l = loss(y, y_hat)\n",
" else:\n",
" l = loss(y_hat, y)\n",
" if isinstance(updater, tf.keras.optimizers.Optimizer):\n",
" params = net.trainable_variables\n",
" grads = tape.gradient(l, params)\n",
" updater.apply_gradients(zip(grads, params))\n",
" else:\n",
" updater(X.shape[0], tape.gradient(l, updater.params))\n",
" # Keras的loss默认返回一个批量的平均损失\n",
" l_sum = l * float(tf.size(y)) if isinstance(\n",
" loss, tf.keras.losses.Loss) else tf.reduce_sum(l)\n",
" metric.add(l_sum, accuracy(y_hat, y), tf.size(y))\n",
" # 返回训练损失和训练精度\n",
" return metric[0] / metric[2], metric[1] / metric[2]"
],
"id": "fba5d39e709f5567",
"outputs": [],
"execution_count": 96
},
{
"metadata": {},
"cell_type": "markdown",
"source": "在展示训练函数的实现之前我们定义一个在动画中绘制数据的实用程序类Animator 它能够简化本书其余部分的代码。",
"id": "e971a63d39f5fc4f"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:21.994572Z",
"start_time": "2024-05-23T13:56:21.979102Z"
}
},
"cell_type": "code",
"source": [
"class Animator: #@save\n",
" \"\"\"在动画中绘制数据\"\"\"\n",
" def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
" ylim=None, xscale='linear', yscale='linear',\n",
" fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
" figsize=(3.5, 2.5)):\n",
" # 增量地绘制多条线\n",
" if legend is None:\n",
" legend = []\n",
" d2l.use_svg_display()\n",
" self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)\n",
" if nrows * ncols == 1:\n",
" self.axes = [self.axes, ]\n",
" # 使用lambda函数捕获参数\n",
" self.config_axes = lambda: d2l.set_axes(\n",
" self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n",
" self.X, self.Y, self.fmts = None, None, fmts\n",
"\n",
" def add(self, x, y):\n",
" # 向图表中添加多个数据点\n",
" if not hasattr(y, \"__len__\"):\n",
" y = [y]\n",
" n = len(y)\n",
" if not hasattr(x, \"__len__\"):\n",
" x = [x] * n\n",
" if not self.X:\n",
" self.X = [[] for _ in range(n)]\n",
" if not self.Y:\n",
" self.Y = [[] for _ in range(n)]\n",
" for i, (a, b) in enumerate(zip(x, y)):\n",
" if a is not None and b is not None:\n",
" self.X[i].append(a)\n",
" self.Y[i].append(b)\n",
" self.axes[0].cla()\n",
" for x, y, fmt in zip(self.X, self.Y, self.fmts):\n",
" self.axes[0].plot(x, y, fmt)\n",
" self.config_axes()\n",
" display.display(self.fig)\n",
" display.clear_output(wait=True)"
],
"id": "b860c7019aa787f5",
"outputs": [],
"execution_count": 97
},
{
"metadata": {},
"cell_type": "markdown",
"source": "接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期由num_epochs指定。 在每个迭代周期结束时利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。\n",
"id": "74368ee5f4a0c274"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:23.626251Z",
"start_time": "2024-05-23T13:56:23.615965Z"
}
},
"cell_type": "code",
"source": [
"def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save\n",
" \"\"\"训练模型定义见第3章\"\"\"\n",
" animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],\n",
" legend=['train loss', 'train acc', 'test acc'])\n",
" for epoch in range(num_epochs):\n",
" train_metrics = train_epoch_ch3(net, train_iter, loss, updater)\n",
" test_acc = evaluate_accuracy(net, test_iter)\n",
" animator.add(epoch + 1, train_metrics + (test_acc,))\n",
" train_loss, train_acc = train_metrics\n",
" assert train_loss < 0.5, train_loss\n",
" assert train_acc <= 1 and train_acc > 0.7, train_acc\n",
" assert test_acc <= 1 and test_acc > 0.7, test_acc"
],
"id": "824036eccf6943c2",
"outputs": [],
"execution_count": 98
},
{
"metadata": {},
"cell_type": "markdown",
"source": "作为一个从零开始的实现,我们使用 3.2节中定义的 小批量随机梯度下降来优化模型的损失函数设置学习率为0.1。",
"id": "78b1baf42d9dad20"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-23T13:56:25.196135Z",
"start_time": "2024-05-23T13:56:25.188277Z"
}
},
"cell_type": "code",
"source": [
"class Updater(): #@save\n",
" \"\"\"用小批量随机梯度下降法更新参数\"\"\"\n",
" def __init__(self, params, lr):\n",
" self.params = params\n",
" self.lr = lr\n",
"\n",
" def __call__(self, batch_size, grads):\n",
" d2l.sgd(self.params, grads, self.lr, batch_size)\n",
"\n",
"updater = Updater([W, b], lr=0.1)"
],
"id": "c353ba4d9117d412",
"outputs": [],
"execution_count": 99
},
{
"metadata": {},
"cell_type": "markdown",
"source": "现在我们训练模型10个迭代周期。 请注意迭代周期num_epochs和学习率lr都是可调节的超参数。 通过更改它们的值,我们可以提高模型的分类精度。",
"id": "e09156cf6f1f9ad4"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"num_epochs = 10\n",
"train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)"
],
"id": "d62c9cdb2b662990",
"execution_count": 100,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"预测\n",
"现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。"
],
"id": "1bf141339d14be4e"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"def predict_ch3(net, test_iter, n=6): #@save\n",
" \"\"\"预测标签定义见第3章\"\"\"\n",
" for X, y in test_iter:\n",
" break\n",
" trues = d2l.get_fashion_mnist_labels(y)\n",
" preds = d2l.get_fashion_mnist_labels(tf.argmax(net(X), axis=1))\n",
" titles = [true +'\\n' + pred for true, pred in zip(trues, preds)]\n",
" d2l.show_images(\n",
" tf.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])\n",
"\n",
"predict_ch3(net, test_iter)"
],
"id": "82e8b2104d47c1ea"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}