利用Tensorflow提高Iris ML模型的精度

我是Python和ML的初学者。我正在练习这个Iris数据集,以使用张量流2.0创建一个ML模型。

我解析了csv,并使用该数据集训练模型。在我创建模型的过程中,我能够得到90 %的训练精度和91 %的验证精度。

import tensorflow as tf
import numpy as np
from sklearn import preprocessing

csv_data = np.loadtxt('iris_training.csv',delimiter=',')
target_all = csv_data[:,-1]

csv_data = csv_data[:,0:-1]

# Shuffling the input
shuffled_indices = np.arange(csv_data.shape[0])
np.random.shuffle(shuffled_indices)

shuffled_inputs = csv_data[shuffled_indices]
shuffled_targets = target_all[shuffled_indices]

# Standardize the Inputs
shuffled_inputs = preprocessing.scale(shuffled_inputs)

# Split date into train , validation and test
total_count = shuffled_inputs.shape[0]
train_data_count = int(0.8*total_count)
validation_data_count = int(0.1*total_count)
test_data_count = total_count - train_data_count - validation_data_count

train_inputs = shuffled_inputs[:train_data_count]
train_targets = shuffled_targets[:train_data_count]

validation_inputs = shuffled_inputs[train_data_count:train_data_count+validation_data_count]
validation_targets = shuffled_targets[train_data_count:train_data_count+validation_data_count]

test_inputs = shuffled_inputs[train_data_count+validation_data_count:]
test_targets = shuffled_targets[train_data_count+validation_data_count:]


print(len(train_inputs))
print(len(validation_inputs))
print(len(test_inputs))

# Model Creation
input_size = 4
hidden_layer_size = 100
output_size = 3

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(hidden_layer_size, input_dim=input_size,     activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(hidden_layer_size, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(output_size, activation=tf.nn.softmax))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',metrics=['accuracy'])

model.fit(train_inputs,train_targets, epochs=10, validation_data=(validation_inputs, validation_targets), verbose=2)

prediction = model.predict(test_inputs)

如果我的代码中有什么地方可以提高我的模型在这个简单的虹膜数据集的准确性,请指点我。

用于训练我的模型的文件。Iris Csv

解决方案:

至于你的模型,你可以尝试做以下工作 超参数调整,

  • 将学习率设置为一个较低的值
  • 增加 时代
  • 添加 更多的训练数据集 因为你有一个小的数据集。

神经网络大放异彩 当有 数据量大 的培训。

您还可以添加 更多层次 在模型中添加 辍学,以避免过度适应以及使用 不同的激活功能.

这些都是影响模型性能的常见因素。

给TA打赏
共{{data.count}}人
人已打赏
解决方案

在Flutter Dart中从字符串中获取日期并转换为Date对象。

2022-5-12 7:00:21

解决方案

Javascript购物清单声明让

2022-5-12 7:00:24

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索