快活林资源网 Design By www.csstdc.com
使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。
下面是我所使用的代码
class SequenceData(Sequence): def __init__(self, path, batch_size=32): self.path = path self.batch_size = batch_size f = open(path) self.datas = f.readlines() self.L = len(self.datas) self.index = random.sample(range(self.L), self.L) #返回长度,通过len(<你的实例>)调用 def __len__(self): return self.L - self.batch_size #即通过索引获取a[0],a[1]这种 def __getitem__(self, idx): batch_indexs = self.index[idx:(idx+self.batch_size)] batch_datas = [self.datas[k] for k in batch_indexs] img1s,img2s,audios,labels = self.data_generation(batch_datas) return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels}) def data_generation(self, batch_datas): #预处理操作 return img1s,img2s,audios,labels
然后在代码里通过fit_generation函数调用并训练
这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。
D = SequenceData('train.csv') model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), epochs=2, workers=20, #callbacks=[checkpoint], use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))
同样的,也可以在测试的时候使用
model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练
我就废话不多说了,大家还是直接看代码吧~
#coding=utf-8 ''' Created on 2018-7-10 ''' import keras import math import os import cv2 import numpy as np from keras.models import Sequential from keras.layers import Dense class DataGenerator(keras.utils.Sequence): def __init__(self, datas, batch_size=1, shuffle=True): self.batch_size = batch_size self.datas = datas self.indexes = np.arange(len(self.datas)) self.shuffle = shuffle def __len__(self): #计算每一个epoch的迭代次数 return math.ceil(len(self.datas) / float(self.batch_size)) def __getitem__(self, index): #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了 # 生成batch_size个索引 batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size] # 根据索引获取datas集合中的数据 batch_datas = [self.datas[k] for k in batch_indexs] # 生成数据 X, y = self.data_generation(batch_datas) return X, y def on_epoch_end(self): #在每一次epoch结束是否需要进行一次随机,重新随机一下index if self.shuffle == True: np.random.shuffle(self.indexes) def data_generation(self, batch_datas): images = [] labels = [] # 生成数据 for i, data in enumerate(batch_datas): #x_train数据 image = cv2.imread(data) image = list(image) images.append(image) #y_train数据 right = data.rfind("\\",0) left = data.rfind("\\",0,right)+1 class_name = data[left:right] if class_name=="dog": labels.append([0,1]) else: labels.append([1,0]) #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3] return np.array(images), np.array(labels) # 读取样本名称,然后根据样本名称去读取数据 class_num = 0 train_datas = [] for file in os.listdir("D:/xxx"): file_path = os.path.join("D:/xxx", file) if os.path.isdir(file_path): class_num = class_num + 1 for sub_file in os.listdir(file_path): train_datas.append(os.path.join(file_path, sub_file)) # 数据生成器 training_generator = DataGenerator(train_datas) #构建网络 model = Sequential() model.add(Dense(units=64, activation='relu', input_dim=784)) model.add(Dense(units=2, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy']) model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
快活林资源网 Design By www.csstdc.com
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
快活林资源网 Design By www.csstdc.com
暂无评论...
更新日志
2025年01月10日
2025年01月10日
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]