# Python source listing: 51 lines (1.8 KiB).
"""Train a small CNN classifier on Fashion-MNIST and save it to disk.

When run as a script this loads the dataset via ``tf.keras.datasets``, scales
and reshapes the images, trains with early stopping and a decaying
learning-rate schedule, prints test-set loss/accuracy, and writes the model to
``fashion_mnist_model.h5``.
"""
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

IMAGE_SHAPE = (28, 28, 1)  # height, width, channels fed to the first Conv2D
NUM_CLASSES = 10           # width of the softmax output layer
INITIAL_LR = 1e-3          # Adam learning rate at epoch 0
MODEL_PATH = 'fashion_mnist_model.h5'


def _preprocess_images(images):
    """Scale pixel values into [0, 1] and add a trailing single-channel axis.

    Casting to float32 before dividing avoids the float64 array that
    ``images / 255.0`` would produce, halving memory for the dataset.
    """
    scaled = images.astype('float32') / 255.0
    return scaled.reshape((scaled.shape[0], *IMAGE_SHAPE))


def load_and_preprocess():
    """Load Fashion-MNIST and return ``(X_train, y_train), (X_test, y_test)``.

    Images are scaled and reshaped for the CNN; labels are returned unchanged
    (the model is trained with sparse_categorical_crossentropy, which takes
    integer class labels directly).
    """
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

    X_train = _preprocess_images(X_train)
    X_test = _preprocess_images(X_test)

    print(f'Training data shape: {X_train.shape}, Training labels shape: {y_train.shape}')
    print(f'Test data shape: {X_test.shape}, Test labels shape: {y_test.shape}')
    return (X_train, y_train), (X_test, y_test)


def build_model():
    """Return an uncompiled two-conv-block CNN ending in a softmax classifier."""
    return Sequential([
        tf.keras.Input(shape=IMAGE_SHAPE),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(NUM_CLASSES, activation='softmax'),  # softmax over the class scores
    ])


def train(model, X_train, y_train, *, max_epochs=100, batch_size=32):
    """Compile *model* and fit it on the training data; return the Keras History.

    The last 20% of the training data is held out for validation. Training
    stops once ``val_loss`` has not improved for 3 epochs, and the weights from
    the best epoch are restored so later evaluation/saving uses them rather
    than the final (worse) epoch's weights.
    """
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True)
    # Decay the learning rate 10x every 20 epochs. A growing schedule such as
    # 1e-3 * 10**(epoch / 20) is a learning-rate range test, not a training
    # schedule: it drives the rate to 1e2 by epoch 100.
    lr_schedule = tf.keras.callbacks.LearningRateScheduler(
        lambda epoch: INITIAL_LR * 10 ** (-epoch / 20))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=INITIAL_LR),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model.fit(
        X_train, y_train,
        epochs=max_epochs,
        batch_size=batch_size,
        validation_split=0.2,
        callbacks=[early_stopping, lr_schedule],
    )


def main():
    """Run the full pipeline: load, build, train, evaluate, and save."""
    (X_train, y_train), (X_test, y_test) = load_and_preprocess()
    model = build_model()
    history = train(model, X_train, y_train)

    loss, accuracy = model.evaluate(X_test, y_test)
    print(f'Test Loss: {loss}')
    print(f'Test Accuracy: {accuracy}')

    # The path (and therefore the legacy HDF5 format) is kept so existing
    # consumers of this file keep working.
    model.save(MODEL_PATH)
    return history


if __name__ == '__main__':
    main()
